Skip to content

Commit

Permalink
chore: refactor authorization
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Feb 21, 2025
1 parent 51a74f6 commit 130c4fa
Show file tree
Hide file tree
Showing 13 changed files with 582 additions and 184 deletions.
120 changes: 21 additions & 99 deletions pkg/api/authz/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ package authz
import (
"context"
"net/http"
"slices"
"strings"

"github.com/obot-platform/nah/pkg/router"
"github.com/obot-platform/obot/pkg/alias"
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
"github.com/obot-platform/obot/pkg/system"
Expand All @@ -23,117 +20,42 @@ func getValidUserIDs(user user.Info) []string {
return keys
}

func (a *Authorizer) assistantIsAuthorized(ctx context.Context, agentID string, validUserIDs []string) bool {
for _, userID := range validUserIDs {
var access v1.AgentAuthorizationList
err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{
"spec.userID": userID,
"spec.agentID": agentID,
})
if err == nil && len(access.Items) == 1 {
return true
}
}
return false
}

func (a *Authorizer) threadIsAuthorized(ctx context.Context, agentID, projectID, threadID string, user user.Info) bool {
var thread v1.Thread
if err := a.storage.Get(ctx, router.Key(system.DefaultNamespace, threadID), &thread); err != nil {
return false
}
if thread.Spec.AgentName != agentID {
return false
}
if thread.Spec.ParentThreadName != strings.Replace(projectID, system.ProjectPrefix, system.ThreadPrefix, 1) {
return false
}
if thread.Spec.UserUID != user.GetUID() {
return false
}
return true
}

func (a *Authorizer) projectIsAuthorized(ctx context.Context, agentID, projectID string, validUserIDs []string) bool {
var (
thread v1.Thread
threadID = strings.Replace(projectID, system.ProjectPrefix, system.ThreadPrefix, 1)
)
if err := a.storage.Get(ctx, router.Key(system.DefaultNamespace, threadID), &thread); err != nil {
return false
}
if !thread.Spec.Project {
return false
}
if thread.Spec.AgentName != agentID {
return false
}
if slices.Contains(validUserIDs, thread.Spec.UserUID) {
return true
}

for _, userID := range validUserIDs {
var access v1.ThreadAuthorizationList
err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{
"spec.userID": userID,
"spec.threadID": threadID,
"spec.accepted": "true",
})
if err == nil && len(access.Items) == 1 {
return true
}
}
return false
}

func (a *Authorizer) authorizeAssistant(req *http.Request, user user.Info) bool {
if !strings.HasPrefix(req.URL.Path, "/api/assistants/") {
return false
}

paths := strings.Split(req.URL.Path, "/")
if paths[3] == "" {
return false
}

// Must be authenticated
if !slices.Contains(user.GetGroups(), AuthenticatedGroup) {
return false
func (a *Authorizer) checkAssistant(req *http.Request, resources *Resources, user user.Info) (bool, error) {
if resources.AssistantID == "" {
return true, nil
}

var (
agentID = paths[3]
agentID = resources.AssistantID
validUserIDs = getValidUserIDs(user)
agent v1.Agent
)

if !system.IsAgentID(agentID) {
var agent v1.Agent
if err := alias.Get(req.Context(), a.storage, &agent, "", agentID); err != nil {
return false
return false, err
}
agentID = agent.Name
}

if !a.assistantIsAuthorized(req.Context(), agentID, validUserIDs) {
return false
return false, nil
}

if len(paths) <= 5 || paths[4] != "projects" {
return true
}

// Emails are authorized only here, so reverse the list
slices.Reverse(validUserIDs)

var projectID = paths[5]
if !a.projectIsAuthorized(req.Context(), agentID, projectID, validUserIDs) {
return false
}
resources.Authorizated.Assistant = &agent
return true, nil
}

if len(paths) <= 7 || paths[6] != "threads" {
return true
func (a *Authorizer) assistantIsAuthorized(ctx context.Context, agentID string, validUserIDs []string) bool {
for _, userID := range validUserIDs {
var access v1.AgentAuthorizationList
err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{
"spec.userID": userID,
"spec.agentID": agentID,
})
if err == nil && len(access.Items) == 1 {
return true
}
}

var threadID = paths[7]
return a.threadIsAuthorized(req.Context(), agentID, projectID, threadID, user)
return false
}
34 changes: 16 additions & 18 deletions pkg/api/authz/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ var staticRules = map[string][]string{

"GET /api/auth-providers",
"GET /api/auth-providers/{id}",

"GET /o/{id}",
},
AuthenticatedGroup: {
"/api/oauth/redirect/{namespace}/{name}",
Expand All @@ -81,15 +83,23 @@ var devModeRules = map[string][]string{
}

type Authorizer struct {
rules []rule
storage kclient.Client
rules []rule
storage kclient.Client
resourcesMux *http.ServeMux
}

func NewAuthorizer(storage kclient.Client, devMode bool) *Authorizer {
return &Authorizer{
rules: defaultRules(devMode),
storage: storage,
a := &Authorizer{
rules: defaultRules(devMode),
storage: storage,
resourcesMux: http.NewServeMux(),
}

for _, resource := range resources {
a.resourcesMux.HandleFunc(resource, a.evaluateResources)
}

return a
}

func (a *Authorizer) Authorize(req *http.Request, user user.Info) bool {
Expand All @@ -102,19 +112,7 @@ func (a *Authorizer) Authorize(req *http.Request, user user.Info) bool {
}
}

if authorizeThread(req, user) {
return true
}

if a.authorizeThreadFileDownload(req, user) {
return true
}

if a.authorizeAssistant(req, user) {
return true
}

return authorizeUI(req, user)
return a.authorizeResource(req, user)
}

type rule struct {
Expand Down
33 changes: 33 additions & 0 deletions pkg/api/authz/pendingauthorization.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package authz

import (
"net/http"

"github.com/obot-platform/nah/pkg/router"
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
"github.com/obot-platform/obot/pkg/system"
"k8s.io/apiserver/pkg/authentication/user"
)

func (a *Authorizer) checkPendingAuthorization(req *http.Request, resources *Resources, user user.Info) (bool, error) {
if resources.PendingAuthorizationID == "" {
return true, nil
}

var (
threadAuth v1.ThreadAuthorization
)

if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, resources.PendingAuthorizationID), &threadAuth); err != nil {
return false, err
}

for _, uid := range getValidUserIDs(user) {
if threadAuth.Spec.UserID == uid {
resources.Authorizated.PendingAuthorization = &threadAuth
return true, nil
}
}

return true, nil
}
71 changes: 71 additions & 0 deletions pkg/api/authz/project.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package authz

import (
"context"
"net/http"
"slices"
"strings"

"github.com/obot-platform/nah/pkg/router"
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
"github.com/obot-platform/obot/pkg/system"
"k8s.io/apiserver/pkg/authentication/user"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)

func (a *Authorizer) checkProject(req *http.Request, resources *Resources, user user.Info) (bool, error) {
if resources.ProjectID == "" {
return true, nil
}

var (
agentID string
validUserIDs = getValidUserIDs(user)
thread v1.Thread
projectThreadID = strings.Replace(resources.ProjectID, system.ProjectPrefix, system.ThreadPrefix, 1)
)

if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, projectThreadID), &thread); err != nil {
return false, err
}

if resources.Authorizated.Assistant != nil {
agentID = resources.Authorizated.Assistant.Name
}

if !a.projectIsAuthorized(req.Context(), agentID, &thread, validUserIDs) {
return false, nil
}

resources.Authorizated.Project = &thread
return true, nil
}

func (a *Authorizer) projectIsAuthorized(ctx context.Context, agentID string, thread *v1.Thread, validUserIDs []string) bool {
if !thread.Spec.Project {
return false
}
if agentID != "" {
// If agent is available, make sure it's related
if thread.Spec.AgentName != agentID {
return false
}
}

if slices.Contains(validUserIDs, thread.Spec.UserUID) {
return true
}

for _, userID := range validUserIDs {
var access v1.ThreadAuthorizationList
err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{
"spec.userID": userID,
"spec.threadID": thread.Name,
"spec.accepted": "true",
})
if err == nil && len(access.Items) == 1 {
return true
}
}
return false
}
Loading

0 comments on commit 130c4fa

Please sign in to comment.