Skip to content

Commit

Permalink
feat: add support for OAuth PATs (#1387)
Browse files Browse the repository at this point in the history
feat: add support for OAuth prompt tokens

Instead of configuring and using OAuth, an app can specify that it
supports using tokens. If this is the case, then the tool's credential
should list some number of "prompt_tokens" and, optionally,
"prompt_vars" and the user will be prompted for those. If the oauth2
credential tool should prompt the user for a token instead of using
OAuth, then Obot will not pass the environment variables that feed
the URLs to the tool.

A side effect of this change is that OAuth apps no longer default to
global.

Signed-off-by: Donnie Adams <[email protected]>
Co-authored-by: Ivy <[email protected]>
  • Loading branch information
thedadams and ivyjeong13 authored Jan 30, 2025
1 parent b350b55 commit 8761cb8
Show file tree
Hide file tree
Showing 31 changed files with 575 additions and 361 deletions.
4 changes: 2 additions & 2 deletions apiclient/types/oauthapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ type OAuthAppManifest struct {
// This field is optional for HubSpot OAuth apps.
OptionalScope string `json:"optionalScope,omitempty"`
// This field is required, it correlates to the integration name in the gptscript oauth cred tool
Integration string `json:"integration,omitempty"`
Alias string `json:"alias,omitempty"`
// Global indicates if the OAuth app is globally applied to all agents.
Global *bool `json:"global,omitempty"`
Global bool `json:"global,omitempty"`
// This field is only used by Salesforce
InstanceURL string `json:"instanceURL,omitempty"`
}
Expand Down
5 changes: 0 additions & 5 deletions apiclient/types/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 16 additions & 9 deletions pkg/api/handlers/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,12 +734,15 @@ func (a *AgentHandler) EnsureCredentialForKnowledgeSource(req api.Context) error
return req.WriteCreated(resp)
}

credentialTools, err := v1.CredentialTools(req.Context(), req.Storage, req.Namespace(), ref)
if err != nil {
return err
var toolReference v1.ToolReference
if err := req.Get(&toolReference, ref); err != nil {
return fmt.Errorf("failed to get tool reference %v", ref)
}
if toolReference.Status.Tool == nil {
return types.NewErrHttp(http.StatusTooEarly, "tool reference is not ready yet")
}

if len(credentialTools) == 0 {
if len(toolReference.Status.Tool.Credentials) == 0 {
// The only way to get here is if the controller hasn't set the field yet.
if agent.Status.AuthStatus == nil {
agent.Status.AuthStatus = make(map[string]types.OAuthAppLoginAuthStatus)
Expand All @@ -754,6 +757,10 @@ func (a *AgentHandler) EnsureCredentialForKnowledgeSource(req api.Context) error
return req.WriteCreated(resp)
}

if _, ok := toolReference.Status.Tool.Metadata["oauth"]; !ok {
return types.NewErrBadRequest("tool reference %q does not have oauth metadata", ref)
}

oauthLogin := &v1.OAuthAppLogin{
ObjectMeta: metav1.ObjectMeta{
Name: system.OAuthAppLoginPrefix + agent.Name + ref,
Expand All @@ -762,15 +769,15 @@ func (a *AgentHandler) EnsureCredentialForKnowledgeSource(req api.Context) error
Spec: v1.OAuthAppLoginSpec{
CredentialContext: agent.Name,
ToolReference: ref,
OAuthApps: agent.Spec.Manifest.OAuthApps,
OAuthApps: []string{toolReference.Status.Tool.Metadata["oauth"]},
},
}

if err = req.Delete(oauthLogin); err != nil {
if err := req.Delete(oauthLogin); err != nil {
return err
}

oauthLogin, err = wait.For(req.Context(), req.Storage, oauthLogin, func(obj *v1.OAuthAppLogin) (bool, error) {
oauthLogin, err := wait.For(req.Context(), req.Storage, oauthLogin, func(obj *v1.OAuthAppLogin) (bool, error) {
return obj.Status.External.Authenticated || obj.Status.External.Error != "" || obj.Status.External.URL != "", nil
}, wait.Option{
Create: true,
Expand Down Expand Up @@ -918,7 +925,7 @@ func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.I

var toolRef v1.ToolReference
for _, tool := range tools {
if strings.ContainsAny(tool, "./") {
if render.IsExternalTool(tool) {
prg, err := gClient.LoadFile(ctx, tool)
if err != nil {
return nil, err
Expand Down Expand Up @@ -965,7 +972,7 @@ func removeToolCredentials(ctx context.Context, client kclient.Client, gClient *
credentialNames []string
)
for _, tool := range tools {
if strings.ContainsAny(tool, "./") {
if render.IsExternalTool(tool) {
prg, err := gClient.LoadFile(ctx, tool)
if err != nil {
errs = append(errs, err)
Expand Down
18 changes: 11 additions & 7 deletions pkg/api/handlers/workflows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"strings"

"github.com/gptscript-ai/go-gptscript"
Expand Down Expand Up @@ -368,12 +369,15 @@ func (a *WorkflowHandler) EnsureCredentialForKnowledgeSource(req api.Context) er
return req.WriteCreated(resp)
}

credentialTools, err := v1.CredentialTools(req.Context(), req.Storage, req.Namespace(), ref)
if err != nil {
return err
var toolReference v1.ToolReference
if err := req.Get(&toolReference, ref); err != nil {
return fmt.Errorf("failed to get tool reference %v", ref)
}
if toolReference.Status.Tool == nil {
return types.NewErrHttp(http.StatusTooEarly, "tool reference is not ready yet")
}

if len(credentialTools) == 0 {
if len(toolReference.Status.Tool.Credentials) == 0 {
// The only way to get here is if the controller hasn't set the field yet.
if wf.Status.AuthStatus == nil {
wf.Status.AuthStatus = make(map[string]types.OAuthAppLoginAuthStatus)
Expand All @@ -397,15 +401,15 @@ func (a *WorkflowHandler) EnsureCredentialForKnowledgeSource(req api.Context) er
Spec: v1.OAuthAppLoginSpec{
CredentialContext: wf.Name,
ToolReference: ref,
OAuthApps: wf.Spec.Manifest.OAuthApps,
OAuthApps: []string{toolReference.Status.Tool.Metadata["oauth"]},
},
}

if err = req.Delete(oauthLogin); err != nil {
if err := req.Delete(oauthLogin); err != nil {
return err
}

oauthLogin, err = wait.For(req.Context(), req.Storage, oauthLogin, func(obj *v1.OAuthAppLogin) (bool, error) {
oauthLogin, err := wait.For(req.Context(), req.Storage, oauthLogin, func(obj *v1.OAuthAppLogin) (bool, error) {
return obj.Status.External.Authenticated || obj.Status.External.Error != "" || obj.Status.External.URL != "", nil
}, wait.Option{
Create: true,
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/handlers/toolinfo/toolinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package toolinfo
import (
"context"
"fmt"
"strings"

"github.com/gptscript-ai/go-gptscript"
"github.com/obot-platform/nah/pkg/router"
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/controller/creds"
"github.com/obot-platform/obot/pkg/render"
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
apierror "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/sets"
Expand Down Expand Up @@ -57,7 +57,7 @@ func (h *Handler) SetToolInfoStatus(req router.Request, resp router.Response) (e
credNames []string
)
for _, tool := range tools {
if strings.ContainsAny(tool, "/.") {
if render.IsExternalTool(tool) {
credNames, err = h.credentialNamesForNonToolReferences(req.Ctx, tool)
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions pkg/gateway/server/oauth_apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ func (s *Server) createOAuthApp(apiContext api.Context) error {
var existingApps v1.OAuthAppList
if err := apiContext.Storage.List(apiContext.Context(), &existingApps, &kclient.ListOptions{
FieldSelector: fields.SelectorFromSet(selectors.RemoveEmpty(map[string]string{
"spec.manifest.integration": appManifest.Integration,
"spec.manifest.alias": appManifest.Alias,
})),
Namespace: apiContext.Namespace(),
}); err != nil {
return err
}

if len(existingApps.Items) > 0 {
return types2.NewErrHttp(http.StatusConflict, fmt.Sprintf("OAuth app with integration %s already exists", appManifest.Integration))
return types2.NewErrHttp(http.StatusConflict, fmt.Sprintf("OAuth app with alias %s already exists", appManifest.Alias))
}

app := v1.OAuthApp{
Expand Down
22 changes: 13 additions & 9 deletions pkg/gateway/types/oauth_apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,23 @@ type OAuthAppTypeConfig struct {

func ValidateAndSetDefaultsOAuthAppManifest(r *types.OAuthAppManifest, create bool) error {
var errs []error
if r.Integration == "" {
errs = append(errs, fmt.Errorf("missing integration"))
} else if !alphaNumericRegexp.MatchString(r.Integration) {
errs = append(errs, fmt.Errorf("integration name can only contain alphanumeric characters and hyphens: %s", r.Integration))
if r.Alias == "" {
errs = append(errs, fmt.Errorf("missing alias"))
} else if !alphaNumericRegexp.MatchString(r.Alias) {
errs = append(errs, fmt.Errorf("alias name can only contain alphanumeric characters and hyphens: %s", r.Alias))
}

switch r.Type {
case types.OAuthAppTypeAtlassian:
r.AuthURL = AtlassianAuthorizeURL
r.TokenURL = AtlassianTokenURL
case types.OAuthAppTypeMicrosoft365:
r.AuthURL = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", r.TenantID)
r.TokenURL = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", r.TenantID)
tenantID := r.TenantID
if tenantID == "" {
tenantID = "common"
}
r.AuthURL = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", tenantID)
r.TokenURL = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", tenantID)
case types.OAuthAppTypeSlack:
r.AuthURL = SlackAuthorizeURL
r.TokenURL = SlackTokenURL
Expand Down Expand Up @@ -163,16 +167,16 @@ func MergeOAuthAppManifests(r, other types.OAuthAppManifest) types.OAuthAppManif
if other.Name != "" {
retVal.Name = other.Name
}
if other.Integration != "" {
retVal.Integration = other.Integration
if other.Alias != "" {
retVal.Alias = other.Alias
}
if other.AppID != "" {
retVal.AppID = other.AppID
}
if other.OptionalScope != "" {
retVal.OptionalScope = other.OptionalScope
}
if other.Global != nil {
if other.Global {
retVal.Global = other.Global
}

Expand Down
27 changes: 14 additions & 13 deletions pkg/render/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU
}

if opts.Thread != nil {
for _, tool := range opts.Thread.Spec.Manifest.Tools {
if !added && tool == knowledgeToolName {
for _, t := range opts.Thread.Spec.Manifest.Tools {
if !added && t == knowledgeToolName {
continue
}
name, tools, err := Tool(ctx, db, agent.Namespace, tool)
name, tools, err := tool(ctx, db, agent.Namespace, t)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -107,17 +107,18 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU
}
}

for _, tool := range agent.Spec.Manifest.Tools {
if !added && tool == knowledgeToolName {
for _, t := range agent.Spec.Manifest.Tools {
if !added && t == knowledgeToolName {
continue
}
name, tools, err := Tool(ctx, db, agent.Namespace, tool)
name, tools, err := tool(ctx, db, agent.Namespace, t)
if err != nil {
return nil, nil, err
}
if name != "" {
mainTool.Tools = append(mainTool.Tools, name)
}

otherTools = append(otherTools, tools...)
}

Expand Down Expand Up @@ -161,30 +162,30 @@ func OAuthAppEnv(ctx context.Context, db kclient.Client, oauthAppNames []string,
activeIntegrations := map[string]v1.OAuthApp{}
for _, name := range slices.Sorted(maps.Keys(apps)) {
app := apps[name]
if app.Spec.Manifest.Global == nil || !*app.Spec.Manifest.Global || app.Spec.Manifest.ClientID == "" || app.Spec.Manifest.ClientSecret == "" || app.Spec.Manifest.Integration == "" {
if !app.Spec.Manifest.Global || app.Spec.Manifest.ClientID == "" || app.Spec.Manifest.ClientSecret == "" || app.Spec.Manifest.Alias == "" {
continue
}
activeIntegrations[app.Spec.Manifest.Integration] = app
activeIntegrations[app.Spec.Manifest.Alias] = app
}

for _, appRef := range oauthAppNames {
app, ok := apps[appRef]
if !ok {
return nil, fmt.Errorf("oauth app %s not found", appRef)
}
if app.Spec.Manifest.Integration == "" {
if app.Spec.Manifest.Alias == "" {
return nil, fmt.Errorf("oauth app %s has no integration name", app.Name)
}
if app.Spec.Manifest.ClientID == "" || app.Spec.Manifest.ClientSecret == "" {
return nil, fmt.Errorf("oauth app %s has no client id or secret", app.Name)
}

activeIntegrations[app.Spec.Manifest.Integration] = app
activeIntegrations[app.Spec.Manifest.Alias] = app
}

for _, integration := range slices.Sorted(maps.Keys(activeIntegrations)) {
app := activeIntegrations[integration]
integrationEnv := strings.ReplaceAll(strings.ToUpper(app.Spec.Manifest.Integration), "-", "_")
integrationEnv := strings.ReplaceAll(strings.ToUpper(app.Spec.Manifest.Alias), "-", "_")

extraEnv = append(extraEnv,
fmt.Sprintf("GPTSCRIPT_OAUTH_%s_AUTH_URL=%s", integrationEnv, app.AuthorizeURL(serverURL)),
Expand Down Expand Up @@ -351,8 +352,8 @@ func oauthAppsByName(ctx context.Context, c kclient.Client, namespace string) (m
}

for _, app := range apps.Items {
if app.Spec.Manifest.Integration != "" {
result[app.Spec.Manifest.Integration] = app
if app.Spec.Manifest.Alias != "" {
result[app.Spec.Manifest.Alias] = app
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/render/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ END INSTRUCTIONS: TOOL %q`, tool.Spec.Manifest.Name, tool.Spec.Manifest.Context,
return toolDefs, nil
}

func Tool(ctx context.Context, c client.Client, ns, name string) (_ string, toolDefs []gptscript.ToolDef, _ error) {
func tool(ctx context.Context, c client.Client, ns, name string) (string, []gptscript.ToolDef, error) {
if !system.IsToolID(name) {
name, err := ResolveToolReference(ctx, c, types.ToolReferenceTypeTool, ns, name)
name, err := resolveToolReferenceWithMetadata(ctx, c, types.ToolReferenceTypeTool, ns, name)
return name, nil, err
}

Expand Down
6 changes: 6 additions & 0 deletions pkg/render/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ func IsExternalTool(tool string) bool {
}

func ResolveToolReference(ctx context.Context, c kclient.Client, toolRefType types.ToolReferenceType, ns, name string) (string, error) {
name, err := resolveToolReferenceWithMetadata(ctx, c, toolRefType, ns, name)
return name, err
}

func resolveToolReferenceWithMetadata(ctx context.Context, c kclient.Client, toolRefType types.ToolReferenceType, ns, name string) (string, error) {
if IsExternalTool(name) {
return name, nil
}
Expand All @@ -39,6 +44,7 @@ func ResolveToolReference(ctx context.Context, c kclient.Client, toolRefType typ
} else if err != nil {
return "", err
}

if toolRefType != "" && tool.Spec.Type != toolRefType {
return name, fmt.Errorf("tool reference %s is not of type %s", name, toolRefType)
}
Expand Down
14 changes: 7 additions & 7 deletions pkg/storage/apis/obot.obot.ai/v1/oauthapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type OAuthApp struct {
}

func (r *OAuthApp) GetAliasName() string {
return r.Spec.Manifest.Integration
return r.Spec.Manifest.Alias
}

func (r *OAuthApp) SetAssigned(bool) {}
Expand All @@ -48,32 +48,32 @@ func (r *OAuthApp) Has(field string) bool {
func (r *OAuthApp) Get(field string) string {
if r != nil {
switch field {
case "spec.manifest.integration":
return r.Spec.Manifest.Integration
case "spec.manifest.alias":
return r.Spec.Manifest.Alias
}
}

return ""
}

func (r *OAuthApp) FieldNames() []string {
return []string{"spec.manifest.integration"}
return []string{"spec.manifest.alias"}
}

func (r *OAuthApp) RedirectURL(baseURL string) string {
return fmt.Sprintf("%s/api/app-oauth/callback/%s", baseURL, r.Spec.Manifest.Integration)
return fmt.Sprintf("%s/api/app-oauth/callback/%s", baseURL, r.Spec.Manifest.Alias)
}

func OAuthAppGetTokenURL(baseURL string) string {
return fmt.Sprintf("%s/api/app-oauth/get-token", baseURL)
}

func (r *OAuthApp) AuthorizeURL(baseURL string) string {
return fmt.Sprintf("%s/api/app-oauth/authorize/%s", baseURL, r.Spec.Manifest.Integration)
return fmt.Sprintf("%s/api/app-oauth/authorize/%s", baseURL, r.Spec.Manifest.Alias)
}

func (r *OAuthApp) RefreshURL(baseURL string) string {
return fmt.Sprintf("%s/api/app-oauth/refresh/%s", baseURL, r.Spec.Manifest.Integration)
return fmt.Sprintf("%s/api/app-oauth/refresh/%s", baseURL, r.Spec.Manifest.Alias)
}

func (r *OAuthApp) DeleteRefs() []Ref {
Expand Down
Loading

0 comments on commit 8761cb8

Please sign in to comment.