From fc9770e6041e24d3c9baccac25409c0e2f2244f6 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Tue, 21 Jan 2025 20:30:33 -0500 Subject: [PATCH] feat: add support for OAuth PATs Instead of configuring and using OAuth, an app can specify that it supports using personal access tokens. If this is the case, then Obot will pass an extra environment variable to the oauth credential tool to indicate which integrations support tokens. If the oauth2 credential tool should prompt the user for a token instead of using OAuth, then Obot will not pass the environment variables that feed the URLs to the tool. A side effect of this change is that OAuth apps no longer default to global. Signed-off-by: Donnie Adams --- apiclient/types/oauthapp.go | 2 +- apiclient/types/zz_generated.deepcopy.go | 5 --- pkg/api/handlers/agent.go | 4 +- .../handlers/oauthapp/oauthapplogin.go | 2 +- pkg/controller/handlers/toolinfo/toolinfo.go | 4 +- pkg/gateway/types/oauth_apps.go | 2 +- pkg/render/render.go | 39 ++++++++++++++----- pkg/render/tool.go | 14 +++---- pkg/render/workflow.go | 24 ++++++++---- 9 files changed, 60 insertions(+), 36 deletions(-) diff --git a/apiclient/types/oauthapp.go b/apiclient/types/oauthapp.go index ebdab3715..f2d06b7d2 100644 --- a/apiclient/types/oauthapp.go +++ b/apiclient/types/oauthapp.go @@ -37,7 +37,7 @@ type OAuthAppManifest struct { // This field is required, it correlates to the integration name in the gptscript oauth cred tool Integration string `json:"integration,omitempty"` // Global indicates if the OAuth app is globally applied to all agents. - Global *bool `json:"global,omitempty"` + Global bool `json:"global,omitempty"` // This field is only used by Salesforce InstanceURL string `json:"instanceURL,omitempty"` } diff --git a/apiclient/types/zz_generated.deepcopy.go b/apiclient/types/zz_generated.deepcopy.go index 74cf02295..a90a56f84 100644 --- a/apiclient/types/zz_generated.deepcopy.go +++ b/apiclient/types/zz_generated.deepcopy.go @@ -1195,11 +1195,6 @@ func (in *OAuthAppLoginAuthStatus) DeepCopy() *OAuthAppLoginAuthStatus { func (in *OAuthAppManifest) DeepCopyInto(out *OAuthAppManifest) { *out = *in in.Metadata.DeepCopyInto(&out.Metadata) - if in.Global != nil { - in, out := &in.Global, &out.Global - *out = new(bool) - **out = **in - } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new OAuthAppManifest. diff --git a/pkg/api/handlers/agent.go b/pkg/api/handlers/agent.go index b29ab7255..4ddc595a6 100644 --- a/pkg/api/handlers/agent.go +++ b/pkg/api/handlers/agent.go @@ -918,7 +918,7 @@ func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.I var toolRef v1.ToolReference for _, tool := range tools { - if strings.ContainsAny(tool, "./") { + if render.IsExternalTool(tool) { prg, err := gClient.LoadFile(ctx, tool) if err != nil { return nil, err @@ -965,7 +965,7 @@ func removeToolCredentials(ctx context.Context, client kclient.Client, gClient * credentialNames []string ) for _, tool := range tools { - if strings.ContainsAny(tool, "./") { + if render.IsExternalTool(tool) { prg, err := gClient.LoadFile(ctx, tool) if err != nil { errs = append(errs, err) diff --git a/pkg/controller/handlers/oauthapp/oauthapplogin.go b/pkg/controller/handlers/oauthapp/oauthapplogin.go index 76b75a9d7..d5f8ac60b 100644 --- a/pkg/controller/handlers/oauthapp/oauthapplogin.go +++ b/pkg/controller/handlers/oauthapp/oauthapplogin.go @@ -53,7 +53,7 @@ func (h *LoginHandler) RunTool(req router.Request, _ router.Response) error { return err } - oauthAppEnv, err := render.OAuthAppEnv(req.Ctx, req.Client, login.Spec.OAuthApps, login.Namespace, h.serverURL) + oauthAppEnv, err := render.OAuthAppEnv(req.Ctx, req.Client, login.Spec.OAuthApps, login.Namespace, h.serverURL, login.Spec.PATSupportedIntegrations) if err != nil { return err } diff --git a/pkg/controller/handlers/toolinfo/toolinfo.go b/pkg/controller/handlers/toolinfo/toolinfo.go index cb9e199c9..282efbc4f 100644 --- a/pkg/controller/handlers/toolinfo/toolinfo.go +++ b/pkg/controller/handlers/toolinfo/toolinfo.go @@ -3,12 +3,12 @@ package toolinfo import ( "context" "fmt" - "strings" "github.com/gptscript-ai/go-gptscript" "github.com/obot-platform/nah/pkg/router" "github.com/obot-platform/obot/apiclient/types" "github.com/obot-platform/obot/pkg/controller/creds" + "github.com/obot-platform/obot/pkg/render" v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" apierror "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/util/sets" @@ -57,7 +57,7 @@ func (h *Handler) SetToolInfoStatus(req router.Request, resp router.Response) (e credNames []string ) for _, tool := range tools { - if strings.ContainsAny(tool, "/.") { + if render.IsExternalTool(tool) { credNames, err = h.credentialNamesForNonToolReferences(req.Ctx, tool) if err != nil { return err diff --git a/pkg/gateway/types/oauth_apps.go b/pkg/gateway/types/oauth_apps.go index 89cafbbb0..c8735d68f 100644 --- a/pkg/gateway/types/oauth_apps.go +++ b/pkg/gateway/types/oauth_apps.go @@ -166,7 +166,7 @@ func MergeOAuthAppManifests(r, other types.OAuthAppManifest) types.OAuthAppManif if other.OptionalScope != "" { retVal.OptionalScope = other.OptionalScope } - if other.Global != nil { + if other.Global { retVal.Global = other.Global } diff --git a/pkg/render/render.go b/pkg/render/render.go index 88524af96..deb648ae8 100644 --- a/pkg/render/render.go +++ b/pkg/render/render.go @@ -74,11 +74,11 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU } if opts.Thread != nil { - for _, tool := range opts.Thread.Spec.Manifest.Tools { - if !added && tool == knowledgeToolName { + for _, t := range opts.Thread.Spec.Manifest.Tools { + if !added && t == knowledgeToolName { continue } - name, tools, err := Tool(ctx, db, agent.Namespace, tool) + name, _, tools, err := tool(ctx, db, agent.Namespace, t) if err != nil { return nil, nil, err } @@ -99,18 +99,33 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU } } - for _, tool := range agent.Spec.Manifest.Tools { - if !added && tool == knowledgeToolName { + patSupportedIntegrations := make(map[string]struct{}) + for _, t := range agent.Spec.Manifest.Tools { + if !added && t == knowledgeToolName { continue } - name, tools, err := Tool(ctx, db, agent.Namespace, tool) + name, metadata, tools, err := tool(ctx, db, agent.Namespace, t) if err != nil { return nil, nil, err } if name != "" { mainTool.Tools = append(mainTool.Tools, name) } - otherTools = append(otherTools, tools...) + + if metadata["oauthPATSupported"] == "true" { + if integration := metadata["oauth"]; integration != "" { + patSupportedIntegrations[integration] = struct{}{} + } + } + + for _, t := range tools { + if t.MetaData["oauthPATSupported"] == "true" { + if integration := t.MetaData["oauth"]; integration != "" { + patSupportedIntegrations[integration] = struct{}{} + } + } + otherTools = append(otherTools, t) + } } for _, tool := range agent.Spec.SystemTools { @@ -134,7 +149,7 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU return nil, nil, err } - oauthEnv, err := OAuthAppEnv(ctx, db, agent.Spec.Manifest.OAuthApps, agent.Namespace, oauthServerURL) + oauthEnv, err := OAuthAppEnv(ctx, db, agent.Spec.Manifest.OAuthApps, agent.Namespace, oauthServerURL, slices.Collect(maps.Keys(patSupportedIntegrations))) if err != nil { return nil, nil, err } @@ -144,7 +159,7 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU return append([]gptscript.ToolDef{mainTool}, otherTools...), extraEnv, nil } -func OAuthAppEnv(ctx context.Context, db kclient.Client, oauthAppNames []string, namespace, serverURL string) (extraEnv []string, _ error) { +func OAuthAppEnv(ctx context.Context, db kclient.Client, oauthAppNames []string, namespace, serverURL string, patIntegrations []string) (extraEnv []string, _ error) { apps, err := oauthAppsByName(ctx, db, namespace) if err != nil { return nil, err @@ -153,7 +168,7 @@ func OAuthAppEnv(ctx context.Context, db kclient.Client, oauthAppNames []string, activeIntegrations := map[string]v1.OAuthApp{} for _, name := range slices.Sorted(maps.Keys(apps)) { app := apps[name] - if app.Spec.Manifest.Global == nil || !*app.Spec.Manifest.Global || app.Spec.Manifest.ClientID == "" || app.Spec.Manifest.ClientSecret == "" || app.Spec.Manifest.Integration == "" { + if !app.Spec.Manifest.Global || app.Spec.Manifest.ClientID == "" || app.Spec.Manifest.ClientSecret == "" || app.Spec.Manifest.Integration == "" { continue } activeIntegrations[app.Spec.Manifest.Integration] = app @@ -184,6 +199,10 @@ func OAuthAppEnv(ctx context.Context, db kclient.Client, oauthAppNames []string, fmt.Sprintf("GPTSCRIPT_OAUTH_%s_TOKEN_URL=%s", integrationEnv, v1.OAuthAppGetTokenURL(serverURL))) } + if len(patIntegrations) > 0 { + extraEnv = append(extraEnv, fmt.Sprintf("GPTSCRIPT_OAUTH_PAT_INTEGRATIONS=%s", strings.Join(patIntegrations, ","))) + } + return extraEnv, nil } diff --git a/pkg/render/tool.go b/pkg/render/tool.go index d6fa680be..afea1e086 100644 --- a/pkg/render/tool.go +++ b/pkg/render/tool.go @@ -93,25 +93,25 @@ END INSTRUCTIONS: TOOL %q`, tool.Spec.Manifest.Name, tool.Spec.Manifest.Context, return toolDefs, nil } -func Tool(ctx context.Context, c client.Client, ns, name string) (_ string, toolDefs []gptscript.ToolDef, _ error) { +func tool(ctx context.Context, c client.Client, ns, name string) (string, map[string]string, []gptscript.ToolDef, error) { if !system.IsToolID(name) { - name, err := ResolveToolReference(ctx, c, types.ToolReferenceTypeTool, ns, name) - return name, nil, err + name, metadata, err := resolveToolReferenceWithMetadata(ctx, c, types.ToolReferenceTypeTool, ns, name) + return name, metadata, nil, err } var tool v1.Tool if err := c.Get(ctx, router.Key(ns, name), &tool); err != nil { - return name, nil, err + return name, nil, nil, err } toolDefs, err := CustomTool(ctx, c, tool) if err != nil { - return "", nil, err + return "", nil, nil, err } if len(toolDefs) == 0 { - return "", toolDefs, nil + return "", nil, toolDefs, nil } - return toolDefs[0].Name, toolDefs, nil + return toolDefs[0].Name, nil, toolDefs, nil } diff --git a/pkg/render/workflow.go b/pkg/render/workflow.go index 854bc5732..396091a66 100644 --- a/pkg/render/workflow.go +++ b/pkg/render/workflow.go @@ -29,26 +29,36 @@ func IsExternalTool(tool string) bool { } func ResolveToolReference(ctx context.Context, c kclient.Client, toolRefType types.ToolReferenceType, ns, name string) (string, error) { + name, _, err := resolveToolReferenceWithMetadata(ctx, c, toolRefType, ns, name) + return name, err +} + +func resolveToolReferenceWithMetadata(ctx context.Context, c kclient.Client, toolRefType types.ToolReferenceType, ns, name string) (string, map[string]string, error) { if IsExternalTool(name) { - return name, nil + return name, nil, nil } var tool v1.ToolReference if err := c.Get(ctx, router.Key(ns, name), &tool); apierror.IsNotFound(err) { - return name, nil + return name, nil, nil } else if err != nil { - return "", err + return "", nil, err + } + + var metadata map[string]string + if tool.Status.Tool != nil { + metadata = tool.Status.Tool.Metadata } if toolRefType != "" && tool.Spec.Type != toolRefType { - return name, fmt.Errorf("tool reference %s is not of type %s", name, toolRefType) + return name, metadata, fmt.Errorf("tool reference %s is not of type %s", name, toolRefType) } if tool.Status.Reference == "" { - return "", fmt.Errorf("tool reference %s has no reference", name) + return "", nil, fmt.Errorf("tool reference %s has no reference", name) } if toolRefType == types.ToolReferenceTypeTool { - return fmt.Sprintf("%s as %s", tool.Status.Reference, name), nil + return fmt.Sprintf("%s as %s", tool.Status.Reference, name), metadata, nil } - return tool.Status.Reference, nil + return tool.Status.Reference, metadata, nil } func Workflow(ctx context.Context, c kclient.Client, wf *v1.Workflow, opts WorkflowOptions) (*v1.Agent, error) {