diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index d0481ec8..a3454dd5 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -276,7 +276,10 @@ func (r *GPTScript) listModels(ctx context.Context, gptScript *gptscript.GPTScri if err != nil { return err } - fmt.Println(strings.Join(models, "\n")) + + for _, model := range models { + fmt.Println(model.ID) + } return nil } diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index dfb1771a..5a7229a1 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -10,6 +10,7 @@ import ( "slices" "strings" + openai2 "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/config" @@ -275,7 +276,7 @@ func (g *GPTScript) ListTools(_ context.Context, prg types.Program) []types.Tool return prg.TopLevelTools() } -func (g *GPTScript) ListModels(ctx context.Context, providers ...string) ([]string, error) { +func (g *GPTScript) ListModels(ctx context.Context, providers ...string) ([]openai2.Model, error) { return g.Registry.ListModels(ctx, providers...) } diff --git a/pkg/llm/registry.go b/pkg/llm/registry.go index 09fe1dce..d53d96b9 100644 --- a/pkg/llm/registry.go +++ b/pkg/llm/registry.go @@ -8,6 +8,7 @@ import ( "sync" "github.com/google/uuid" + openai2 "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/env" "github.com/gptscript-ai/gptscript/pkg/openai" "github.com/gptscript-ai/gptscript/pkg/remote" @@ -16,7 +17,7 @@ import ( type Client interface { Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) - ListModels(ctx context.Context, providers ...string) (result []string, _ error) + ListModels(ctx context.Context, providers ...string) (result []openai2.Model, _ error) Supports(ctx context.Context, modelName string) (bool, error) } @@ -38,7 +39,7 @@ func (r *Registry) AddClient(client Client) error { return nil } -func (r *Registry) ListModels(ctx context.Context, providers ...string) (result []string, _ error) { +func (r *Registry) ListModels(ctx context.Context, providers ...string) (result []openai2.Model, _ error) { for _, v := range r.clients { models, err := v.ListModels(ctx, providers...) if err != nil { @@ -46,7 +47,9 @@ func (r *Registry) ListModels(ctx context.Context, providers ...string) (result } result = append(result, models...) } - sort.Strings(result) + sort.Slice(result, func(i, j int) bool { + return result[i].ID < result[j].ID + }) return result, nil } diff --git a/pkg/openai/client.go b/pkg/openai/client.go index dea234a9..db911962 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -157,10 +157,15 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) { return false, InvalidAuthError{} } - return slices.Contains(models, modelName), nil + for _, model := range models { + if model.ID == modelName { + return true, nil + } + } + return false, nil } -func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) { +func (c *Client) ListModels(ctx context.Context, providers ...string) ([]openai.Model, error) { // Only serve if providers is empty or "" is in the list if len(providers) != 0 && !slices.Contains(providers, "") { return nil, nil @@ -179,11 +184,10 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result [] if err != nil { return nil, err } - for _, model := range models.Models { - result = append(result, model.ID) - } - sort.Strings(result) - return result, nil + sort.Slice(models.Models, func(i, j int) bool { + return models.Models[i].ID < models.Models[j].ID + }) + return models.Models, nil } func (c *Client) cacheKey(request openai.ChatCompletionRequest) any { diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 5542372b..93f612ef 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -9,6 +9,7 @@ import ( "strings" "sync" + openai2 "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" @@ -62,7 +63,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques return client.Call(ctx, messageRequest, env, status) } -func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) { +func (c *Client) ListModels(ctx context.Context, providers ...string) (result []openai2.Model, _ error) { for _, provider := range providers { client, err := c.load(ctx, provider) if err != nil { @@ -72,12 +73,16 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result [] if err != nil { return nil, err } - for _, model := range models { - result = append(result, model+" from "+provider) + for i := range models { + models[i].ID = fmt.Sprintf("%s from %s", models[i].ID, provider) } + + result = append(result, models...) } - sort.Strings(result) + sort.Slice(result, func(i, j int) bool { + return result[i].ID < result[j].ID + }) return } diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 73bf5d58..801227a1 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -145,7 +145,7 @@ func (s *server) listModels(w http.ResponseWriter, r *http.Request) { return } - writeResponse(logger, w, map[string]any{"stdout": strings.Join(out, "\n")}) + writeResponse(logger, w, map[string]any{"stdout": out}) } // execHandler is a general handler for executing tools with gptscript. This is mainly responsible for parsing the request body.