Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for OAuth PATs #1387

Merged
merged 11 commits into from
Jan 30, 2025
4 changes: 2 additions & 2 deletions apiclient/types/oauthapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ type OAuthAppManifest struct {
// This field is optional for HubSpot OAuth apps.
OptionalScope string `json:"optionalScope,omitempty"`
// This field is required, it correlates to the integration name in the gptscript oauth cred tool
Integration string `json:"integration,omitempty"`
Alias string `json:"alias,omitempty"`
// Global indicates if the OAuth app is globally applied to all agents.
Global *bool `json:"global,omitempty"`
Global bool `json:"global,omitempty"`
// This field is only used by Salesforce
InstanceURL string `json:"instanceURL,omitempty"`
}
Expand Down
5 changes: 0 additions & 5 deletions apiclient/types/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 16 additions & 9 deletions pkg/api/handlers/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,12 +734,15 @@ func (a *AgentHandler) EnsureCredentialForKnowledgeSource(req api.Context) error
return req.WriteCreated(resp)
}

credentialTools, err := v1.CredentialTools(req.Context(), req.Storage, req.Namespace(), ref)
if err != nil {
return err
var toolReference v1.ToolReference
if err := req.Get(&toolReference, ref); err != nil {
return fmt.Errorf("failed to get tool reference %v", ref)
}
if toolReference.Status.Tool == nil {
return types.NewErrHttp(http.StatusTooEarly, "tool reference is not ready yet")
}

if len(credentialTools) == 0 {
if len(toolReference.Status.Tool.Credentials) == 0 {
// The only way to get here is if the controller hasn't set the field yet.
if agent.Status.AuthStatus == nil {
agent.Status.AuthStatus = make(map[string]types.OAuthAppLoginAuthStatus)
Expand All @@ -754,6 +757,10 @@ func (a *AgentHandler) EnsureCredentialForKnowledgeSource(req api.Context) error
return req.WriteCreated(resp)
}

if _, ok := toolReference.Status.Tool.Metadata["oauth"]; !ok {
return types.NewErrBadRequest("tool reference %q does not have oauth metadata", ref)
}

oauthLogin := &v1.OAuthAppLogin{
ObjectMeta: metav1.ObjectMeta{
Name: system.OAuthAppLoginPrefix + agent.Name + ref,
Expand All @@ -762,15 +769,15 @@ func (a *AgentHandler) EnsureCredentialForKnowledgeSource(req api.Context) error
Spec: v1.OAuthAppLoginSpec{
CredentialContext: agent.Name,
ToolReference: ref,
OAuthApps: agent.Spec.Manifest.OAuthApps,
OAuthApps: []string{toolReference.Status.Tool.Metadata["oauth"]},
},
}

if err = req.Delete(oauthLogin); err != nil {
if err := req.Delete(oauthLogin); err != nil {
return err
}

oauthLogin, err = wait.For(req.Context(), req.Storage, oauthLogin, func(obj *v1.OAuthAppLogin) (bool, error) {
oauthLogin, err := wait.For(req.Context(), req.Storage, oauthLogin, func(obj *v1.OAuthAppLogin) (bool, error) {
return obj.Status.External.Authenticated || obj.Status.External.Error != "" || obj.Status.External.URL != "", nil
}, wait.Option{
Create: true,
Expand Down Expand Up @@ -918,7 +925,7 @@ func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.I

var toolRef v1.ToolReference
for _, tool := range tools {
if strings.ContainsAny(tool, "./") {
if render.IsExternalTool(tool) {
prg, err := gClient.LoadFile(ctx, tool)
if err != nil {
return nil, err
Expand Down Expand Up @@ -965,7 +972,7 @@ func removeToolCredentials(ctx context.Context, client kclient.Client, gClient *
credentialNames []string
)
for _, tool := range tools {
if strings.ContainsAny(tool, "./") {
if render.IsExternalTool(tool) {
prg, err := gClient.LoadFile(ctx, tool)
if err != nil {
errs = append(errs, err)
Expand Down
18 changes: 11 additions & 7 deletions pkg/api/handlers/workflows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"strings"

"github.com/gptscript-ai/go-gptscript"
Expand Down Expand Up @@ -368,12 +369,15 @@ func (a *WorkflowHandler) EnsureCredentialForKnowledgeSource(req api.Context) er
return req.WriteCreated(resp)
}

credentialTools, err := v1.CredentialTools(req.Context(), req.Storage, req.Namespace(), ref)
if err != nil {
return err
var toolReference v1.ToolReference
if err := req.Get(&toolReference, ref); err != nil {
return fmt.Errorf("failed to get tool reference %v", ref)
}
if toolReference.Status.Tool == nil {
return types.NewErrHttp(http.StatusTooEarly, "tool reference is not ready yet")
}

if len(credentialTools) == 0 {
if len(toolReference.Status.Tool.Credentials) == 0 {
// The only way to get here is if the controller hasn't set the field yet.
if wf.Status.AuthStatus == nil {
wf.Status.AuthStatus = make(map[string]types.OAuthAppLoginAuthStatus)
Expand All @@ -397,15 +401,15 @@ func (a *WorkflowHandler) EnsureCredentialForKnowledgeSource(req api.Context) er
Spec: v1.OAuthAppLoginSpec{
CredentialContext: wf.Name,
ToolReference: ref,
OAuthApps: wf.Spec.Manifest.OAuthApps,
OAuthApps: []string{toolReference.Status.Tool.Metadata["oauth"]},
},
}

if err = req.Delete(oauthLogin); err != nil {
if err := req.Delete(oauthLogin); err != nil {
return err
}

oauthLogin, err = wait.For(req.Context(), req.Storage, oauthLogin, func(obj *v1.OAuthAppLogin) (bool, error) {
oauthLogin, err := wait.For(req.Context(), req.Storage, oauthLogin, func(obj *v1.OAuthAppLogin) (bool, error) {
return obj.Status.External.Authenticated || obj.Status.External.Error != "" || obj.Status.External.URL != "", nil
}, wait.Option{
Create: true,
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/handlers/toolinfo/toolinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package toolinfo
import (
"context"
"fmt"
"strings"

"github.com/gptscript-ai/go-gptscript"
"github.com/obot-platform/nah/pkg/router"
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/controller/creds"
"github.com/obot-platform/obot/pkg/render"
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
apierror "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/sets"
Expand Down Expand Up @@ -57,7 +57,7 @@ func (h *Handler) SetToolInfoStatus(req router.Request, resp router.Response) (e
credNames []string
)
for _, tool := range tools {
if strings.ContainsAny(tool, "/.") {
if render.IsExternalTool(tool) {
credNames, err = h.credentialNamesForNonToolReferences(req.Ctx, tool)
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions pkg/gateway/server/oauth_apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ func (s *Server) createOAuthApp(apiContext api.Context) error {
var existingApps v1.OAuthAppList
if err := apiContext.Storage.List(apiContext.Context(), &existingApps, &kclient.ListOptions{
FieldSelector: fields.SelectorFromSet(selectors.RemoveEmpty(map[string]string{
"spec.manifest.integration": appManifest.Integration,
"spec.manifest.alias": appManifest.Alias,
})),
Namespace: apiContext.Namespace(),
}); err != nil {
return err
}

if len(existingApps.Items) > 0 {
return types2.NewErrHttp(http.StatusConflict, fmt.Sprintf("OAuth app with integration %s already exists", appManifest.Integration))
return types2.NewErrHttp(http.StatusConflict, fmt.Sprintf("OAuth app with alias %s already exists", appManifest.Alias))
}

app := v1.OAuthApp{
Expand Down
22 changes: 13 additions & 9 deletions pkg/gateway/types/oauth_apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,23 @@ type OAuthAppTypeConfig struct {

func ValidateAndSetDefaultsOAuthAppManifest(r *types.OAuthAppManifest, create bool) error {
var errs []error
if r.Integration == "" {
errs = append(errs, fmt.Errorf("missing integration"))
} else if !alphaNumericRegexp.MatchString(r.Integration) {
errs = append(errs, fmt.Errorf("integration name can only contain alphanumeric characters and hyphens: %s", r.Integration))
if r.Alias == "" {
errs = append(errs, fmt.Errorf("missing alias"))
} else if !alphaNumericRegexp.MatchString(r.Alias) {
errs = append(errs, fmt.Errorf("alias name can only contain alphanumeric characters and hyphens: %s", r.Alias))
}

switch r.Type {
case types.OAuthAppTypeAtlassian:
r.AuthURL = AtlassianAuthorizeURL
r.TokenURL = AtlassianTokenURL
case types.OAuthAppTypeMicrosoft365:
r.AuthURL = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", r.TenantID)
r.TokenURL = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", r.TenantID)
tenantID := r.TenantID
if tenantID == "" {
tenantID = "common"
}
r.AuthURL = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", tenantID)
r.TokenURL = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", tenantID)
case types.OAuthAppTypeSlack:
r.AuthURL = SlackAuthorizeURL
r.TokenURL = SlackTokenURL
Expand Down Expand Up @@ -163,16 +167,16 @@ func MergeOAuthAppManifests(r, other types.OAuthAppManifest) types.OAuthAppManif
if other.Name != "" {
retVal.Name = other.Name
}
if other.Integration != "" {
retVal.Integration = other.Integration
if other.Alias != "" {
retVal.Alias = other.Alias
}
if other.AppID != "" {
retVal.AppID = other.AppID
}
if other.OptionalScope != "" {
retVal.OptionalScope = other.OptionalScope
}
if other.Global != nil {
if other.Global {
retVal.Global = other.Global
}

Expand Down
27 changes: 14 additions & 13 deletions pkg/render/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU
}

if opts.Thread != nil {
for _, tool := range opts.Thread.Spec.Manifest.Tools {
if !added && tool == knowledgeToolName {
for _, t := range opts.Thread.Spec.Manifest.Tools {
if !added && t == knowledgeToolName {
continue
}
name, tools, err := Tool(ctx, db, agent.Namespace, tool)
name, tools, err := tool(ctx, db, agent.Namespace, t)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -107,17 +107,18 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU
}
}

for _, tool := range agent.Spec.Manifest.Tools {
if !added && tool == knowledgeToolName {
for _, t := range agent.Spec.Manifest.Tools {
if !added && t == knowledgeToolName {
continue
}
name, tools, err := Tool(ctx, db, agent.Namespace, tool)
name, tools, err := tool(ctx, db, agent.Namespace, t)
if err != nil {
return nil, nil, err
}
if name != "" {
mainTool.Tools = append(mainTool.Tools, name)
}

otherTools = append(otherTools, tools...)
}

Expand Down Expand Up @@ -161,30 +162,30 @@ func OAuthAppEnv(ctx context.Context, db kclient.Client, oauthAppNames []string,
activeIntegrations := map[string]v1.OAuthApp{}
for _, name := range slices.Sorted(maps.Keys(apps)) {
app := apps[name]
if app.Spec.Manifest.Global == nil || !*app.Spec.Manifest.Global || app.Spec.Manifest.ClientID == "" || app.Spec.Manifest.ClientSecret == "" || app.Spec.Manifest.Integration == "" {
if !app.Spec.Manifest.Global || app.Spec.Manifest.ClientID == "" || app.Spec.Manifest.ClientSecret == "" || app.Spec.Manifest.Alias == "" {
continue
}
activeIntegrations[app.Spec.Manifest.Integration] = app
activeIntegrations[app.Spec.Manifest.Alias] = app
}

for _, appRef := range oauthAppNames {
app, ok := apps[appRef]
if !ok {
return nil, fmt.Errorf("oauth app %s not found", appRef)
}
if app.Spec.Manifest.Integration == "" {
if app.Spec.Manifest.Alias == "" {
return nil, fmt.Errorf("oauth app %s has no integration name", app.Name)
}
if app.Spec.Manifest.ClientID == "" || app.Spec.Manifest.ClientSecret == "" {
return nil, fmt.Errorf("oauth app %s has no client id or secret", app.Name)
}

activeIntegrations[app.Spec.Manifest.Integration] = app
activeIntegrations[app.Spec.Manifest.Alias] = app
}

for _, integration := range slices.Sorted(maps.Keys(activeIntegrations)) {
app := activeIntegrations[integration]
integrationEnv := strings.ReplaceAll(strings.ToUpper(app.Spec.Manifest.Integration), "-", "_")
integrationEnv := strings.ReplaceAll(strings.ToUpper(app.Spec.Manifest.Alias), "-", "_")

extraEnv = append(extraEnv,
fmt.Sprintf("GPTSCRIPT_OAUTH_%s_AUTH_URL=%s", integrationEnv, app.AuthorizeURL(serverURL)),
Expand Down Expand Up @@ -351,8 +352,8 @@ func oauthAppsByName(ctx context.Context, c kclient.Client, namespace string) (m
}

for _, app := range apps.Items {
if app.Spec.Manifest.Integration != "" {
result[app.Spec.Manifest.Integration] = app
if app.Spec.Manifest.Alias != "" {
result[app.Spec.Manifest.Alias] = app
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/render/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ END INSTRUCTIONS: TOOL %q`, tool.Spec.Manifest.Name, tool.Spec.Manifest.Context,
return toolDefs, nil
}

func Tool(ctx context.Context, c client.Client, ns, name string) (_ string, toolDefs []gptscript.ToolDef, _ error) {
func tool(ctx context.Context, c client.Client, ns, name string) (string, []gptscript.ToolDef, error) {
if !system.IsToolID(name) {
name, err := ResolveToolReference(ctx, c, types.ToolReferenceTypeTool, ns, name)
name, err := resolveToolReferenceWithMetadata(ctx, c, types.ToolReferenceTypeTool, ns, name)
return name, nil, err
}

Expand Down
6 changes: 6 additions & 0 deletions pkg/render/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ func IsExternalTool(tool string) bool {
}

func ResolveToolReference(ctx context.Context, c kclient.Client, toolRefType types.ToolReferenceType, ns, name string) (string, error) {
name, err := resolveToolReferenceWithMetadata(ctx, c, toolRefType, ns, name)
return name, err
}

func resolveToolReferenceWithMetadata(ctx context.Context, c kclient.Client, toolRefType types.ToolReferenceType, ns, name string) (string, error) {
if IsExternalTool(name) {
return name, nil
}
Expand All @@ -39,6 +44,7 @@ func ResolveToolReference(ctx context.Context, c kclient.Client, toolRefType typ
} else if err != nil {
return "", err
}

if toolRefType != "" && tool.Spec.Type != toolRefType {
return name, fmt.Errorf("tool reference %s is not of type %s", name, toolRefType)
}
Expand Down
14 changes: 7 additions & 7 deletions pkg/storage/apis/obot.obot.ai/v1/oauthapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type OAuthApp struct {
}

func (r *OAuthApp) GetAliasName() string {
return r.Spec.Manifest.Integration
return r.Spec.Manifest.Alias
}

func (r *OAuthApp) SetAssigned(bool) {}
Expand All @@ -48,32 +48,32 @@ func (r *OAuthApp) Has(field string) bool {
func (r *OAuthApp) Get(field string) string {
if r != nil {
switch field {
case "spec.manifest.integration":
return r.Spec.Manifest.Integration
case "spec.manifest.alias":
return r.Spec.Manifest.Alias
}
}

return ""
}

func (r *OAuthApp) FieldNames() []string {
return []string{"spec.manifest.integration"}
return []string{"spec.manifest.alias"}
}

func (r *OAuthApp) RedirectURL(baseURL string) string {
return fmt.Sprintf("%s/api/app-oauth/callback/%s", baseURL, r.Spec.Manifest.Integration)
return fmt.Sprintf("%s/api/app-oauth/callback/%s", baseURL, r.Spec.Manifest.Alias)
}

func OAuthAppGetTokenURL(baseURL string) string {
return fmt.Sprintf("%s/api/app-oauth/get-token", baseURL)
}

func (r *OAuthApp) AuthorizeURL(baseURL string) string {
return fmt.Sprintf("%s/api/app-oauth/authorize/%s", baseURL, r.Spec.Manifest.Integration)
return fmt.Sprintf("%s/api/app-oauth/authorize/%s", baseURL, r.Spec.Manifest.Alias)
}

func (r *OAuthApp) RefreshURL(baseURL string) string {
return fmt.Sprintf("%s/api/app-oauth/refresh/%s", baseURL, r.Spec.Manifest.Integration)
return fmt.Sprintf("%s/api/app-oauth/refresh/%s", baseURL, r.Spec.Manifest.Alias)
}

func (r *OAuthApp) DeleteRefs() []Ref {
Expand Down
Loading