Skip to content

Commit 130c4fa

Browse files
chore: refactor authorization
1 parent 51a74f6 commit 130c4fa

File tree

13 files changed

+582
-184
lines changed

13 files changed

+582
-184
lines changed

pkg/api/authz/assistant.go

Lines changed: 21 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@ package authz
33
import (
44
"context"
55
"net/http"
6-
"slices"
7-
"strings"
86

9-
"github.com/obot-platform/nah/pkg/router"
107
"github.com/obot-platform/obot/pkg/alias"
118
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
129
"github.com/obot-platform/obot/pkg/system"
@@ -23,117 +20,42 @@ func getValidUserIDs(user user.Info) []string {
2320
return keys
2421
}
2522

26-
func (a *Authorizer) assistantIsAuthorized(ctx context.Context, agentID string, validUserIDs []string) bool {
27-
for _, userID := range validUserIDs {
28-
var access v1.AgentAuthorizationList
29-
err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{
30-
"spec.userID": userID,
31-
"spec.agentID": agentID,
32-
})
33-
if err == nil && len(access.Items) == 1 {
34-
return true
35-
}
36-
}
37-
return false
38-
}
39-
40-
func (a *Authorizer) threadIsAuthorized(ctx context.Context, agentID, projectID, threadID string, user user.Info) bool {
41-
var thread v1.Thread
42-
if err := a.storage.Get(ctx, router.Key(system.DefaultNamespace, threadID), &thread); err != nil {
43-
return false
44-
}
45-
if thread.Spec.AgentName != agentID {
46-
return false
47-
}
48-
if thread.Spec.ParentThreadName != strings.Replace(projectID, system.ProjectPrefix, system.ThreadPrefix, 1) {
49-
return false
50-
}
51-
if thread.Spec.UserUID != user.GetUID() {
52-
return false
53-
}
54-
return true
55-
}
56-
57-
func (a *Authorizer) projectIsAuthorized(ctx context.Context, agentID, projectID string, validUserIDs []string) bool {
58-
var (
59-
thread v1.Thread
60-
threadID = strings.Replace(projectID, system.ProjectPrefix, system.ThreadPrefix, 1)
61-
)
62-
if err := a.storage.Get(ctx, router.Key(system.DefaultNamespace, threadID), &thread); err != nil {
63-
return false
64-
}
65-
if !thread.Spec.Project {
66-
return false
67-
}
68-
if thread.Spec.AgentName != agentID {
69-
return false
70-
}
71-
if slices.Contains(validUserIDs, thread.Spec.UserUID) {
72-
return true
73-
}
74-
75-
for _, userID := range validUserIDs {
76-
var access v1.ThreadAuthorizationList
77-
err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{
78-
"spec.userID": userID,
79-
"spec.threadID": threadID,
80-
"spec.accepted": "true",
81-
})
82-
if err == nil && len(access.Items) == 1 {
83-
return true
84-
}
85-
}
86-
return false
87-
}
88-
89-
func (a *Authorizer) authorizeAssistant(req *http.Request, user user.Info) bool {
90-
if !strings.HasPrefix(req.URL.Path, "/api/assistants/") {
91-
return false
92-
}
93-
94-
paths := strings.Split(req.URL.Path, "/")
95-
if paths[3] == "" {
96-
return false
97-
}
98-
99-
// Must be authenticated
100-
if !slices.Contains(user.GetGroups(), AuthenticatedGroup) {
101-
return false
23+
func (a *Authorizer) checkAssistant(req *http.Request, resources *Resources, user user.Info) (bool, error) {
24+
if resources.AssistantID == "" {
25+
return true, nil
10226
}
10327

10428
var (
105-
agentID = paths[3]
29+
agentID = resources.AssistantID
10630
validUserIDs = getValidUserIDs(user)
31+
agent v1.Agent
10732
)
10833

10934
if !system.IsAgentID(agentID) {
110-
var agent v1.Agent
11135
if err := alias.Get(req.Context(), a.storage, &agent, "", agentID); err != nil {
112-
return false
36+
return false, err
11337
}
11438
agentID = agent.Name
11539
}
11640

11741
if !a.assistantIsAuthorized(req.Context(), agentID, validUserIDs) {
118-
return false
42+
return false, nil
11943
}
12044

121-
if len(paths) <= 5 || paths[4] != "projects" {
122-
return true
123-
}
124-
125-
// Emails are authorized only here, so reverse the list
126-
slices.Reverse(validUserIDs)
127-
128-
var projectID = paths[5]
129-
if !a.projectIsAuthorized(req.Context(), agentID, projectID, validUserIDs) {
130-
return false
131-
}
45+
resources.Authorizated.Assistant = &agent
46+
return true, nil
47+
}
13248

133-
if len(paths) <= 7 || paths[6] != "threads" {
134-
return true
49+
func (a *Authorizer) assistantIsAuthorized(ctx context.Context, agentID string, validUserIDs []string) bool {
50+
for _, userID := range validUserIDs {
51+
var access v1.AgentAuthorizationList
52+
err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{
53+
"spec.userID": userID,
54+
"spec.agentID": agentID,
55+
})
56+
if err == nil && len(access.Items) == 1 {
57+
return true
58+
}
13559
}
136-
137-
var threadID = paths[7]
138-
return a.threadIsAuthorized(req.Context(), agentID, projectID, threadID, user)
60+
return false
13961
}

pkg/api/authz/authz.go

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ var staticRules = map[string][]string{
5656

5757
"GET /api/auth-providers",
5858
"GET /api/auth-providers/{id}",
59+
60+
"GET /o/{id}",
5961
},
6062
AuthenticatedGroup: {
6163
"/api/oauth/redirect/{namespace}/{name}",
@@ -81,15 +83,23 @@ var devModeRules = map[string][]string{
8183
}
8284

8385
type Authorizer struct {
84-
rules []rule
85-
storage kclient.Client
86+
rules []rule
87+
storage kclient.Client
88+
resourcesMux *http.ServeMux
8689
}
8790

8891
func NewAuthorizer(storage kclient.Client, devMode bool) *Authorizer {
89-
return &Authorizer{
90-
rules: defaultRules(devMode),
91-
storage: storage,
92+
a := &Authorizer{
93+
rules: defaultRules(devMode),
94+
storage: storage,
95+
resourcesMux: http.NewServeMux(),
96+
}
97+
98+
for _, resource := range resources {
99+
a.resourcesMux.HandleFunc(resource, a.evaluateResources)
92100
}
101+
102+
return a
93103
}
94104

95105
func (a *Authorizer) Authorize(req *http.Request, user user.Info) bool {
@@ -102,19 +112,7 @@ func (a *Authorizer) Authorize(req *http.Request, user user.Info) bool {
102112
}
103113
}
104114

105-
if authorizeThread(req, user) {
106-
return true
107-
}
108-
109-
if a.authorizeThreadFileDownload(req, user) {
110-
return true
111-
}
112-
113-
if a.authorizeAssistant(req, user) {
114-
return true
115-
}
116-
117-
return authorizeUI(req, user)
115+
return a.authorizeResource(req, user)
118116
}
119117

120118
type rule struct {

pkg/api/authz/pendingauthorization.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package authz
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/obot-platform/nah/pkg/router"
7+
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
8+
"github.com/obot-platform/obot/pkg/system"
9+
"k8s.io/apiserver/pkg/authentication/user"
10+
)
11+
12+
func (a *Authorizer) checkPendingAuthorization(req *http.Request, resources *Resources, user user.Info) (bool, error) {
13+
if resources.PendingAuthorizationID == "" {
14+
return true, nil
15+
}
16+
17+
var (
18+
threadAuth v1.ThreadAuthorization
19+
)
20+
21+
if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, resources.PendingAuthorizationID), &threadAuth); err != nil {
22+
return false, err
23+
}
24+
25+
for _, uid := range getValidUserIDs(user) {
26+
if threadAuth.Spec.UserID == uid {
27+
resources.Authorizated.PendingAuthorization = &threadAuth
28+
return true, nil
29+
}
30+
}
31+
32+
return true, nil
33+
}

pkg/api/authz/project.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package authz
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"slices"
7+
"strings"
8+
9+
"github.com/obot-platform/nah/pkg/router"
10+
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
11+
"github.com/obot-platform/obot/pkg/system"
12+
"k8s.io/apiserver/pkg/authentication/user"
13+
kclient "sigs.k8s.io/controller-runtime/pkg/client"
14+
)
15+
16+
func (a *Authorizer) checkProject(req *http.Request, resources *Resources, user user.Info) (bool, error) {
17+
if resources.ProjectID == "" {
18+
return true, nil
19+
}
20+
21+
var (
22+
agentID string
23+
validUserIDs = getValidUserIDs(user)
24+
thread v1.Thread
25+
projectThreadID = strings.Replace(resources.ProjectID, system.ProjectPrefix, system.ThreadPrefix, 1)
26+
)
27+
28+
if err := a.storage.Get(req.Context(), router.Key(system.DefaultNamespace, projectThreadID), &thread); err != nil {
29+
return false, err
30+
}
31+
32+
if resources.Authorizated.Assistant != nil {
33+
agentID = resources.Authorizated.Assistant.Name
34+
}
35+
36+
if !a.projectIsAuthorized(req.Context(), agentID, &thread, validUserIDs) {
37+
return false, nil
38+
}
39+
40+
resources.Authorizated.Project = &thread
41+
return true, nil
42+
}
43+
44+
func (a *Authorizer) projectIsAuthorized(ctx context.Context, agentID string, thread *v1.Thread, validUserIDs []string) bool {
45+
if !thread.Spec.Project {
46+
return false
47+
}
48+
if agentID != "" {
49+
// If agent is available, make sure it's related
50+
if thread.Spec.AgentName != agentID {
51+
return false
52+
}
53+
}
54+
55+
if slices.Contains(validUserIDs, thread.Spec.UserUID) {
56+
return true
57+
}
58+
59+
for _, userID := range validUserIDs {
60+
var access v1.ThreadAuthorizationList
61+
err := a.storage.List(ctx, &access, kclient.InNamespace(system.DefaultNamespace), kclient.MatchingFields{
62+
"spec.userID": userID,
63+
"spec.threadID": thread.Name,
64+
"spec.accepted": "true",
65+
})
66+
if err == nil && len(access.Items) == 1 {
67+
return true
68+
}
69+
}
70+
return false
71+
}

0 commit comments

Comments
 (0)