diff --git a/pkg/api/authz/assistant.go b/pkg/api/authz/assistant.go index 23ee008d2..74c42e452 100644 --- a/pkg/api/authz/assistant.go +++ b/pkg/api/authz/assistant.go @@ -3,10 +3,7 @@ package authz import ( "context" "net/http" - "slices" - "strings" - "github.com/obot-platform/nah/pkg/router" "github.com/obot-platform/obot/pkg/alias" v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" "github.com/obot-platform/obot/pkg/system" @@ -23,117 +20,42 @@ func getValidUserIDs(user user.Info) []string { return keys } -func (a *Authorizer) assistantIsAuthorized(ctx context.Context, agentID string, validUserIDs []string) bool { - for _, userID := range validUserIDs { - var access v1.AgentAuthorizationList - err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{ - "spec.userID": userID, - "spec.agentID": agentID, - }) - if err == nil && len(access.Items) == 1 { - return true - } - } - return false -} - -func (a *Authorizer) threadIsAuthorized(ctx context.Context, agentID, projectID, threadID string, user user.Info) bool { - var thread v1.Thread - if err := a.storage.Get(ctx, router.Key(system.DefaultNamespace, threadID), &thread); err != nil { - return false - } - if thread.Spec.AgentName != agentID { - return false - } - if thread.Spec.ParentThreadName != strings.Replace(projectID, system.ProjectPrefix, system.ThreadPrefix, 1) { - return false - } - if thread.Spec.UserUID != user.GetUID() { - return false - } - return true -} - -func (a *Authorizer) projectIsAuthorized(ctx context.Context, agentID, projectID string, validUserIDs []string) bool { - var ( - thread v1.Thread - threadID = strings.Replace(projectID, system.ProjectPrefix, system.ThreadPrefix, 1) - ) - if err := a.storage.Get(ctx, router.Key(system.DefaultNamespace, threadID), &thread); err != nil { - return false - } - if !thread.Spec.Project { - return false - } - if thread.Spec.AgentName != agentID { - return false - } - if slices.Contains(validUserIDs, thread.Spec.UserUID) { - return true - } - - for _, userID := range validUserIDs { - var access v1.ThreadAuthorizationList - err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{ - "spec.userID": userID, - "spec.threadID": threadID, - "spec.accepted": "true", - }) - if err == nil && len(access.Items) == 1 { - return true - } - } - return false -} - -func (a *Authorizer) authorizeAssistant(req *http.Request, user user.Info) bool { - if !strings.HasPrefix(req.URL.Path, "/api/assistants/") { - return false - } - - paths := strings.Split(req.URL.Path, "/") - if paths[3] == "" { - return false - } - - // Must be authenticated - if !slices.Contains(user.GetGroups(), AuthenticatedGroup) { - return false +func (a *Authorizer) checkAssistant(req *http.Request, resources *Resources, user user.Info) (bool, error) { + if resources.AssistantID == "" { + return true, nil } var ( - agentID = paths[3] + agentID = resources.AssistantID validUserIDs = getValidUserIDs(user) + agent v1.Agent ) if !system.IsAgentID(agentID) { - var agent v1.Agent if err := alias.Get(req.Context(), a.storage, &agent, "", agentID); err != nil { - return false + return false, err } agentID = agent.Name } if !a.assistantIsAuthorized(req.Context(), agentID, validUserIDs) { - return false + return false, nil } - if len(paths) <= 5 || paths[4] != "projects" { - return true - } - - // Emails are authorized only here, so reverse the list - slices.Reverse(validUserIDs) - - var projectID = paths[5] - if !a.projectIsAuthorized(req.Context(), agentID, projectID, validUserIDs) { - return false - } + resources.Authorizated.Assistant = &agent + return true, nil +} - if len(paths) <= 7 || paths[6] != "threads" { - return true +func (a *Authorizer) assistantIsAuthorized(ctx context.Context, agentID string, validUserIDs []string) bool { + for _, userID := range validUserIDs { + var access v1.AgentAuthorizationList + err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{ + "spec.userID": userID, + "spec.agentID": agentID, + }) + if err == nil && len(access.Items) == 1 { + return true + } } - - var threadID = paths[7] - return a.threadIsAuthorized(req.Context(), agentID, projectID, threadID, user) + return false } diff --git a/pkg/api/authz/authz.go b/pkg/api/authz/authz.go index 6f8df0613..c4d077a78 100644 --- a/pkg/api/authz/authz.go +++ b/pkg/api/authz/authz.go @@ -56,6 +56,8 @@ var staticRules = map[string][]string{ "GET /api/auth-providers", "GET /api/auth-providers/{id}", + + "GET /o/{id}", }, AuthenticatedGroup: { "/api/oauth/redirect/{namespace}/{name}", @@ -81,15 +83,23 @@ var devModeRules = map[string][]string{ } type Authorizer struct { - rules []rule - storage kclient.Client + rules []rule + storage kclient.Client + resourcesMux *http.ServeMux } func NewAuthorizer(storage kclient.Client, devMode bool) *Authorizer { - return &Authorizer{ - rules: defaultRules(devMode), - storage: storage, + a := &Authorizer{ + rules: defaultRules(devMode), + storage: storage, + resourcesMux: http.NewServeMux(), + } + + for _, resource := range resources { + a.resourcesMux.HandleFunc(resource, a.evaluateResources) } + + return a } func (a *Authorizer) Authorize(req *http.Request, user user.Info) bool { @@ -102,19 +112,7 @@ func (a *Authorizer) Authorize(req *http.Request, user user.Info) bool { } } - if authorizeThread(req, user) { - return true - } - - if a.authorizeThreadFileDownload(req, user) { - return true - } - - if a.authorizeAssistant(req, user) { - return true - } - - return authorizeUI(req, user) + return a.authorizeResource(req, user) } type rule struct { diff --git a/pkg/api/authz/pendingauthorization.go b/pkg/api/authz/pendingauthorization.go new file mode 100644 index 000000000..df7695b25 --- /dev/null +++ b/pkg/api/authz/pendingauthorization.go @@ -0,0 +1,33 @@ +package authz + +import ( + "net/http" + + "github.com/obot-platform/nah/pkg/router" + v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" + "github.com/obot-platform/obot/pkg/system" + "k8s.io/apiserver/pkg/authentication/user" +) + +func (a *Authorizer) checkPendingAuthorization(req *http.Request, resources *Resources, user user.Info) (bool, error) { + if resources.PendingAuthorizationID == "" { + return true, nil + } + + var ( + threadAuth v1.ThreadAuthorization + ) + + if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, resources.PendingAuthorizationID), &threadAuth); err != nil { + return false, err + } + + for _, uid := range getValidUserIDs(user) { + if threadAuth.Spec.UserID == uid { + resources.Authorizated.PendingAuthorization = &threadAuth + return true, nil + } + } + + return true, nil +} diff --git a/pkg/api/authz/project.go b/pkg/api/authz/project.go new file mode 100644 index 000000000..f0878a663 --- /dev/null +++ b/pkg/api/authz/project.go @@ -0,0 +1,71 @@ +package authz + +import ( + "context" + "net/http" + "slices" + "strings" + + "github.com/obot-platform/nah/pkg/router" + v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" + "github.com/obot-platform/obot/pkg/system" + "k8s.io/apiserver/pkg/authentication/user" + kclient "sigs.k8s.io/controller-runtime/pkg/client" +) + +func (a *Authorizer) checkProject(req *http.Request, resources *Resources, user user.Info) (bool, error) { + if resources.ProjectID == "" { + return true, nil + } + + var ( + agentID string + validUserIDs = getValidUserIDs(user) + thread v1.Thread + projectThreadID = strings.Replace(resources.ProjectID, system.ProjectPrefix, system.ThreadPrefix, 1) + ) + + if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, projectThreadID), &thread); err != nil { + return false, err + } + + if resources.Authorizated.Assistant != nil { + agentID = resources.Authorizated.Assistant.Name + } + + if !a.projectIsAuthorized(req.Context(), agentID, &thread, validUserIDs) { + return false, nil + } + + resources.Authorizated.Project = &thread + return true, nil +} + +func (a *Authorizer) projectIsAuthorized(ctx context.Context, agentID string, thread *v1.Thread, validUserIDs []string) bool { + if !thread.Spec.Project { + return false + } + if agentID != "" { + // If agent is available, make sure it's related + if thread.Spec.AgentName != agentID { + return false + } + } + + if slices.Contains(validUserIDs, thread.Spec.UserUID) { + return true + } + + for _, userID := range validUserIDs { + var access v1.ThreadAuthorizationList + err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{ + "spec.userID": userID, + "spec.threadID": thread.Name, + "spec.accepted": "true", + }) + if err == nil && len(access.Items) == 1 { + return true + } + } + return false +} diff --git a/pkg/api/authz/resources.go b/pkg/api/authz/resources.go new file mode 100644 index 000000000..54cfc0d37 --- /dev/null +++ b/pkg/api/authz/resources.go @@ -0,0 +1,239 @@ +package authz + +import ( + "bytes" + "context" + "io" + "net/http" + + v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apiserver/pkg/authentication/user" +) + +var resources = []string{ + "GET /api/assistants", + "GET /api/assistants/{assistant_id}", + "GET /api/assistants/{assistant_id}/pending-authorizations", + "PUT /api/assistants/{assistant_id}/pending-authorizations/{pending_authorization_id}", + "DELETE /api/assistants/{assistant_id}/pending-authorizations/{pending_authorization_id}", + "GET /api/assistants/{assistant_id}/projects", + "POST /api/assistants/{assistant_id}/projects", + "GET /api/assistants/{assistant_id}/projects/{project_id}", + "PUT /api/assistants/{assistant_id}/projects/{project_id}", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/authorizations", + "PUT /api/assistants/{assistant_id}/projects/{project_id}/authorizations", + "GET /api/assistants/{assistant_id}/projects/{project_id}/credentials", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/credentials/{credential_id}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/env", + "PUT /api/assistants/{assistant_id}/projects/{project_id}/env", + "GET /api/assistants/{assistant_id}/projects/{project_id}/file/{file...}", + "POST /api/assistants/{assistant_id}/projects/{project_id}/file/{file...}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/files", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/files/{file...}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/knowledge", + "POST /api/assistants/{assistant_id}/projects/{project_id}/knowledge/{file}", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/knowledge/{file...}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/shell", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tables", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tables/{table_name}/rows", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tasks", + "POST /api/assistants/{assistant_id}/projects/{project_id}/tasks", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tasks/{id}/runs/{run_id}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}", + "PUT /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/events", + "POST /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/events", + "POST /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/run", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/runs", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/runs/{run_id}", + "POST /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/runs/{run_id}/abort", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/runs/{run_id}/events", + "POST /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/runs/{run_id}/events", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/runs/{run_id}/file/{file...}", + "POST /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/runs/{run_id}/file/{file...}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/runs/{run_id}/files", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/tasks/{task_id}/runs/{run_id}/files/{file...}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/templates", + "POST /api/assistants/{assistant_id}/projects/{project_id}/templates", + "GET /api/assistants/{assistant_id}/projects/{project_id}/templates/{template_id}", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/templates/{template_id}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/threads", + "POST /api/assistants/{assistant_id}/projects/{project_id}/threads", + "PUT /api/assistants/{assistant_id}/projects/{project_id}/threads/{thread_id}", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/threads/{thread_id}", + "POST /api/assistants/{assistant_id}/projects/{project_id}/threads/{thread_id}/abort", + "GET /api/assistants/{assistant_id}/projects/{project_id}/threads/{thread_id}/events", + "POST /api/assistants/{assistant_id}/projects/{project_id}/threads/{thread_id}/invoke", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tools", + "PUT /api/assistants/{assistant_id}/projects/{project_id}/tools", + "POST /api/assistants/{assistant_id}/projects/{project_id}/tools", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tools/{tool_id}", + "PUT /api/assistants/{assistant_id}/projects/{project_id}/tools/{tool_id}", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/tools/{tool_id}", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tools/{tool_id}/authenticate", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/tools/{tool_id}/custom", + "DELETE /api/assistants/{assistant_id}/projects/{project_id}/tools/{tool_id}/deauthenticate", + "GET /api/assistants/{assistant_id}/projects/{project_id}/tools/{tool_id}/env", + "PUT /api/assistants/{assistant_id}/projects/{project_id}/tools/{tool_id}/env", + "POST /api/assistants/{assistant_id}/projects/{project_id}/tools/{tool_id}/test", + "GET /api/projects", + "POST /api/prompt", + "GET /api/templates", + "GET /api/templates/{template_id}", + "POST /api/templates/{template_id}/projects", + "PUT /api/threads/{thread_id}", + "DELETE /api/threads/{thread_id}", + "POST /api/threads/{thread_id}/abort", + "GET /api/threads/{thread_id}/events", + "DELETE /api/threads/{thread_id}/files/{file...}", + "GET /api/threads/{thread_id}/files/{file...}", + "POST /api/threads/{thread_id}/files/{file...}", + "GET /api/threads/{thread_id}/files", + "DELETE /api/threads/{thread_id}/knowledge-files/{file...}", + "POST /api/threads/{thread_id}/knowledge-files/{file}", + "GET /api/threads/{thread_id}/knowledge-files", + "GET /api/threads/{thread_id}/tables/{table}/rows", + "GET /api/threads/{thread_id}/tables", + "GET /api/threads/{thread_id}/tasks", + "POST /api/threads/{thread_id}/tasks", + "GET /api/threads/{thread_id}/tasks/{task_id}", + "PUT /api/threads/{thread_id}/tasks/{task_id}", + "POST /api/threads/{thread_id}/tasks/{task_id}/run", + "GET /api/threads/{thread_id}/tasks/{task_id}/runs", + "GET /api/threads/{thread_id}/tasks/{task_id}/runs/{run_id}", + "GET /api/threads/{thread_id}", + "GET /api/threads/{thread_id}/workflows", + "GET /api/threads/{thread_id}/workflows/{workflow_id}/executions", + "GET /{ui}/projects/{id}", +} + +type Resources struct { + AssistantID string + ProjectID string + ThreadID string + TemplateID string + TaskID string + RunID string + WorkflowID string + PendingAuthorizationID string + Authorizated ResourcesAuthorized +} + +type ResourcesAuthorized struct { + Assistant *v1.Agent + Project *v1.Thread + Thread *v1.Thread + Template *v1.ThreadTemplate + Task *v1.Workflow + Run *v1.WorkflowExecution + Workflow *v1.Workflow + PendingAuthorization *v1.ThreadAuthorization +} + +func handleError(rw http.ResponseWriter, err error) { + if apierrors.IsNotFound(err) { + http.Error(rw, err.Error(), http.StatusNotFound) + } else if err != nil { + http.Error(rw, err.Error(), http.StatusForbidden) + } else { + rw.WriteHeader(http.StatusForbidden) + } +} + +type userKey struct{} + +func (a *Authorizer) evaluateResources(rw http.ResponseWriter, req *http.Request) { + user, ok := req.Context().Value(userKey{}).(user.Info) + if !ok { + return + } + + resources := Resources{ + AssistantID: req.PathValue("assistant_id"), + ProjectID: req.PathValue("project_id"), + ThreadID: req.PathValue("thread_id"), + TemplateID: req.PathValue("template_id"), + TaskID: req.PathValue("task_id"), + RunID: req.PathValue("run_id"), + WorkflowID: req.PathValue("workflow_id"), + PendingAuthorizationID: req.PathValue("pending_authorization_id"), + } + + if ok, err := a.checkAssistant(req, &resources, user); !ok || err != nil { + handleError(rw, err) + return + } + + if ok, err := a.checkProject(req, &resources, user); !ok || err != nil { + handleError(rw, err) + return + } + + if ok, err := a.checkThread(req, &resources, user); !ok || err != nil { + handleError(rw, err) + return + } + + if ok, err := a.checkTemplate(req, &resources, user); !ok || err != nil { + handleError(rw, err) + return + } + + if ok, err := a.checkTask(req, &resources, user); !ok || err != nil { + handleError(rw, err) + return + } + + if ok, err := a.checkRun(req, &resources, user); !ok || err != nil { + handleError(rw, err) + return + } + + if ok, err := a.checkWorkflow(req, &resources, user); !ok || err != nil { + handleError(rw, err) + return + } + + if ok, err := a.checkPendingAuthorization(req, &resources, user); !ok || err != nil { + handleError(rw, err) + return + } + + if ok, err := a.checkUI(req, &resources, user); !ok || err != nil { + handleError(rw, err) + return + } + + rw.WriteHeader(http.StatusAccepted) +} + +type responseWriter struct { + io.Writer + code int +} + +func (r *responseWriter) Header() http.Header { + return http.Header{} +} + +func (r *responseWriter) WriteHeader(statusCode int) { + r.code = statusCode +} + +func (a *Authorizer) authorizeResource(req *http.Request, user user.Info) bool { + h, pattern := a.resourcesMux.Handler(req) + if pattern == "" { + return false + } + + buffer := bytes.NewBuffer(nil) + rw := responseWriter{ + Writer: buffer, + } + + h.ServeHTTP(&rw, req.WithContext(context.WithValue(req.Context(), userKey{}, user))) + return rw.code == http.StatusAccepted +} diff --git a/pkg/api/authz/run.go b/pkg/api/authz/run.go new file mode 100644 index 000000000..2c1cd056d --- /dev/null +++ b/pkg/api/authz/run.go @@ -0,0 +1,35 @@ +package authz + +import ( + "net/http" + + "github.com/obot-platform/nah/pkg/router" + v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" + "github.com/obot-platform/obot/pkg/system" + "k8s.io/apiserver/pkg/authentication/user" +) + +func (a *Authorizer) checkRun(req *http.Request, resources *Resources, _ user.Info) (bool, error) { + if resources.RunID == "" { + return true, nil + } + + if resources.Authorizated.Task == nil { + return false, nil + } + + var ( + wfe v1.WorkflowExecution + ) + + if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, resources.RunID), &wfe); err != nil { + return false, err + } + + if resources.Authorizated.Task.Name != wfe.Spec.WorkflowName { + return false, nil + } + + resources.Authorizated.Run = &wfe + return true, nil +} diff --git a/pkg/api/authz/task.go b/pkg/api/authz/task.go new file mode 100644 index 000000000..bd2f56c91 --- /dev/null +++ b/pkg/api/authz/task.go @@ -0,0 +1,35 @@ +package authz + +import ( + "net/http" + + "github.com/obot-platform/nah/pkg/router" + v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" + "github.com/obot-platform/obot/pkg/system" + "k8s.io/apiserver/pkg/authentication/user" +) + +func (a *Authorizer) checkTask(req *http.Request, resources *Resources, _ user.Info) (bool, error) { + if resources.TaskID == "" { + return true, nil + } + + if resources.Authorizated.Project == nil { + return false, nil + } + + var ( + workflow v1.Workflow + ) + + if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, resources.TaskID), &workflow); err != nil { + return false, err + } + + if resources.Authorizated.Project.Name != workflow.Spec.ThreadName { + return false, nil + } + + resources.Authorizated.Task = &workflow + return true, nil +} diff --git a/pkg/api/authz/template.go b/pkg/api/authz/template.go new file mode 100644 index 000000000..a9683859e --- /dev/null +++ b/pkg/api/authz/template.go @@ -0,0 +1,69 @@ +package authz + +import ( + "context" + "net/http" + "slices" + + "github.com/obot-platform/nah/pkg/router" + v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" + "github.com/obot-platform/obot/pkg/system" + "k8s.io/apiserver/pkg/authentication/user" + kclient "sigs.k8s.io/controller-runtime/pkg/client" +) + +func (a *Authorizer) checkTemplate(req *http.Request, resources *Resources, user user.Info) (bool, error) { + if resources.TemplateID == "" { + return true, nil + } + + var ( + agentID string + validUserIDs = getValidUserIDs(user) + template v1.ThreadTemplate + ) + + if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, resources.TemplateID), &template); err != nil { + return false, err + } + + if resources.Authorizated.Project != nil { + return resources.Authorizated.Project.Name == template.Spec.ProjectThreadName, nil + } + + if resources.Authorizated.Assistant != nil { + agentID = resources.Authorizated.Assistant.Name + } + + if !a.templateIsAuthorized(req.Context(), agentID, &template, validUserIDs) { + return false, nil + } + + resources.Authorizated.Template = &template + return true, nil +} + +func (a *Authorizer) templateIsAuthorized(ctx context.Context, agentID string, template *v1.ThreadTemplate, validUserIDs []string) bool { + if agentID != "" { + // If agent is available, make sure it's related + if template.Status.AgentName != agentID { + return false + } + } + + if slices.Contains(validUserIDs, template.Spec.UserID) { + return true + } + + for _, userID := range validUserIDs { + var access v1.ThreadTemplateAuthorizationList + err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{ + "spec.userID": userID, + "spec.templateID": template.Name, + }) + if err == nil && len(access.Items) == 1 { + return true + } + } + return false +} diff --git a/pkg/api/authz/thread.go b/pkg/api/authz/thread.go index e42a58885..a9580b232 100644 --- a/pkg/api/authz/thread.go +++ b/pkg/api/authz/thread.go @@ -2,7 +2,6 @@ package authz import ( "net/http" - "strings" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/obot-platform/nah/pkg/router" @@ -11,70 +10,37 @@ import ( "k8s.io/apiserver/pkg/authentication/user" ) -func authorizeThread(req *http.Request, user user.Info) bool { - thread := types.FirstSet(user.GetExtra()["obot:threadID"]...) - agent := types.FirstSet(user.GetExtra()["obot:agentID"]...) - if thread == "" || agent == "" { - return false - } - if req.Method == "GET" && strings.HasPrefix(req.URL.Path, "/api/threads/"+thread+"/") { - return true - } - if req.Method == "POST" && strings.HasPrefix(req.URL.Path, "/api/threads/"+thread+"/tasks/") { - return true - } - - return false -} - -func (a *Authorizer) authorizeThreadFileDownload(req *http.Request, user user.Info) bool { - if req.Method != http.MethodGet { - return false - } - - if !strings.HasPrefix(req.URL.Path, "/api/threads/") { - return false - } - - parts := strings.Split(req.URL.Path, "/") - if len(parts) < 6 { - return false - } - if parts[0] != "" || - parts[1] != "api" || - parts[2] != "threads" || - parts[4] != "files" { - return false +func (a *Authorizer) checkThread(req *http.Request, resources *Resources, user user.Info) (bool, error) { + if resources.ThreadID == "" { + return true, nil } var ( - id = parts[3] thread v1.Thread ) - if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, id), &thread); err != nil { - return false - } - if thread.Spec.UserUID == user.GetUID() { - return true + if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, resources.ThreadID), &thread); err != nil { + return false, err } - if thread.Spec.WorkflowName == "" { - return false + if thread.Spec.Project { + return false, nil } - var workflow v1.Workflow - if err := a.storage.Get(req.Context(), router.Key(thread.Namespace, thread.Spec.WorkflowName), &workflow); err != nil { - return false - } + if resources.Authorizated.Project == nil { + threadID := types.FirstSet(user.GetExtra()["obot:threadID"]...) + agentID := types.FirstSet(user.GetExtra()["obot:agentID"]...) + if threadID == "" || agentID == "" { + return false, nil + } - if workflow.Spec.ThreadName == "" { - return false + return threadID == thread.Name && thread.Spec.AgentName == agentID, nil } - if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, workflow.Spec.ThreadName), &thread); err != nil { - return false + if resources.Authorizated.Project.Name != thread.Spec.ParentThreadName { + return false, nil } - return thread.Spec.UserUID == user.GetUID() + resources.Authorizated.Thread = &thread + return true, nil } diff --git a/pkg/api/authz/ui.go b/pkg/api/authz/ui.go index 2483e985a..f003466ce 100644 --- a/pkg/api/authz/ui.go +++ b/pkg/api/authz/ui.go @@ -2,23 +2,15 @@ package authz import ( "net/http" - "strings" "k8s.io/apiserver/pkg/authentication/user" ) -func authorizeUI(req *http.Request, _ user.Info) bool { - if req.Method != http.MethodGet { - return false +func (a *Authorizer) checkUI(req *http.Request, _ *Resources, _ user.Info) (bool, error) { + var ui = req.PathValue("ui") + if ui == "" { + return true, nil } - if strings.HasPrefix(req.URL.Path, "/api") { - return false - } - - parts := strings.Split(req.URL.Path, "/") - if len(parts) > 2 && parts[2] == "projects" { - return true - } - - return false + // Ensure the URL does not start with /api + return ui != "api", nil } diff --git a/pkg/api/authz/workflow.go b/pkg/api/authz/workflow.go new file mode 100644 index 000000000..9b83d3464 --- /dev/null +++ b/pkg/api/authz/workflow.go @@ -0,0 +1,35 @@ +package authz + +import ( + "net/http" + + "github.com/obot-platform/nah/pkg/router" + v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" + "github.com/obot-platform/obot/pkg/system" + "k8s.io/apiserver/pkg/authentication/user" +) + +func (a *Authorizer) checkWorkflow(req *http.Request, resources *Resources, _ user.Info) (bool, error) { + if resources.WorkflowID == "" { + return true, nil + } + + if resources.Authorizated.Thread == nil { + return false, nil + } + + var ( + workflow v1.Workflow + ) + + if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, resources.WorkflowID), &workflow); err != nil { + return false, err + } + + if resources.Authorizated.Thread.Name != workflow.Spec.ThreadName { + return false, nil + } + + resources.Authorizated.Workflow = &workflow + return true, nil +} diff --git a/ui/user/src/routes/o/+page.svelte b/ui/user/src/routes/o/[id]/+page.svelte similarity index 84% rename from ui/user/src/routes/o/+page.svelte rename to ui/user/src/routes/o/[id]/+page.svelte index ea5133158..39ef72009 100644 --- a/ui/user/src/routes/o/+page.svelte +++ b/ui/user/src/routes/o/[id]/+page.svelte @@ -2,12 +2,14 @@ import New from '$lib/components/New.svelte'; import { onMount } from 'svelte'; import { assistants, darkMode } from '$lib/stores'; + import { page } from '$app/stores'; + import { get } from 'svelte/store'; let dialog: ReturnType; onMount(async () => { await assistants.load(); - dialog?.show(); + dialog?.show(get(page).params.id); }); diff --git a/ui/user/src/routes/o/[id]/+page.ts b/ui/user/src/routes/o/[id]/+page.ts new file mode 100644 index 000000000..df46931d7 --- /dev/null +++ b/ui/user/src/routes/o/[id]/+page.ts @@ -0,0 +1 @@ +export const prerender = 'auto';