Skip to content

Commit 358bc58

Browse files
feat: adds filters for mcp session
1 parent 4385ce1 commit 358bc58

6 files changed

Lines changed: 723 additions & 98 deletions

File tree

framework/configstore/rdb.go

Lines changed: 97 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5138,33 +5138,45 @@ func (s *RDBConfigStore) GetOauthUserTokenByID(ctx context.Context, id string) (
51385138
return &token, nil
51395139
}
51405140

5141-
// ListAllOauthUserTokens returns every token row regardless of status. Used by
5142-
// the sessions tab UI, which renders distinct affordances per state; filtering
5143-
// here would only hide rows the user needs to see (especially needs_reauth).
5144-
// Runtime lookups apply their own status='active' filter and don't use this.
5145-
func (s *RDBConfigStore) ListAllOauthUserTokens(ctx context.Context) ([]tables.TableOauthUserToken, error) {
5141+
// ListOauthUserTokens returns token rows matching params, regardless of status.
5142+
// The sessions tab UI renders distinct affordances per state; default status
5143+
// filtering here would only hide rows the user needs to see (especially
5144+
// needs_reauth). Runtime lookups apply their own status='active' filter and
5145+
// don't use this. Pagination is handler-side because cross-table de-dup with
5146+
// the pending-session list happens after the merge.
5147+
func (s *RDBConfigStore) ListOauthUserTokens(ctx context.Context, params MCPSessionsFilterParams) ([]tables.TableOauthUserToken, error) {
5148+
query := s.ScopedDB(ctx).Model(&tables.TableOauthUserToken{})
5149+
query = applyMCPSessionFilters(query, params, mcpSessionFilterTable{
5150+
table: "oauth_user_tokens",
5151+
authModeColumn: "auth_mode",
5152+
})
51465153
var tokens []tables.TableOauthUserToken
5147-
if err := s.ScopedDB(ctx).
5154+
if err := query.
51485155
Preload("MCPClient", func(db *gorm.DB) *gorm.DB { return db.Select("client_id, name") }).
51495156
Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { return db.Select("id, name") }).
5150-
Order("created_at DESC").
5157+
Order("oauth_user_tokens.created_at DESC").
51515158
Find(&tokens).Error; err != nil {
5152-
return nil, fmt.Errorf("failed to list all oauth user tokens: %w", err)
5159+
return nil, fmt.Errorf("failed to list oauth user tokens: %w", err)
51535160
}
51545161
return tokens, nil
51555162
}
51565163

5157-
// ListAllPendingOauthUserSessions returns all pending OAuth flow rows whose
5158-
// expiry is in the future. Companion to ListAllOauthUserTokens.
5159-
func (s *RDBConfigStore) ListAllPendingOauthUserSessions(ctx context.Context) ([]tables.TableOauthUserSession, error) {
5164+
// ListPendingOauthUserSessions returns pending OAuth flow rows matching params
5165+
// whose expiry is in the future. Companion to ListOauthUserTokens.
5166+
func (s *RDBConfigStore) ListPendingOauthUserSessions(ctx context.Context, params MCPSessionsFilterParams) ([]tables.TableOauthUserSession, error) {
5167+
query := s.ScopedDB(ctx).Model(&tables.TableOauthUserSession{}).
5168+
Where("oauth_user_sessions.status = ? AND oauth_user_sessions.expires_at > ?", "pending", time.Now())
5169+
query = applyMCPSessionFilters(query, params, mcpSessionFilterTable{
5170+
table: "oauth_user_sessions",
5171+
authModeColumn: "flow_mode",
5172+
})
51605173
var sessions []tables.TableOauthUserSession
5161-
if err := s.ScopedDB(ctx).
5174+
if err := query.
51625175
Preload("MCPClient", func(db *gorm.DB) *gorm.DB { return db.Select("client_id, name") }).
51635176
Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { return db.Select("id, name") }).
5164-
Where("status = ? AND expires_at > ?", "pending", time.Now()).
5165-
Order("created_at DESC").
5177+
Order("oauth_user_sessions.created_at DESC").
51665178
Find(&sessions).Error; err != nil {
5167-
return nil, fmt.Errorf("failed to list all pending oauth user sessions: %w", err)
5179+
return nil, fmt.Errorf("failed to list pending oauth user sessions: %w", err)
51685180
}
51695181
return sessions, nil
51705182
}
@@ -5317,19 +5329,25 @@ func (s *RDBConfigStore) DeleteMCPPerUserHeaderCredential(ctx context.Context, i
53175329
return nil
53185330
}
53195331

5320-
// ListAllMCPPerUserHeaderCredentials returns every row regardless of status.
5321-
// The sessions UI surfaces non-active states (needs_update / orphaned) with
5322-
// distinct affordances; filtering here would only hide rows the user needs to
5323-
// act on. Runtime lookups apply their own status='active' filter and don't go
5324-
// through this method.
5325-
func (s *RDBConfigStore) ListAllMCPPerUserHeaderCredentials(ctx context.Context) ([]tables.TableMCPPerUserHeaderCredential, error) {
5332+
// ListMCPPerUserHeaderCredentials returns credential rows matching params,
5333+
// regardless of status. The sessions UI surfaces non-active states
5334+
// (needs_update / orphaned) with distinct affordances; default status
5335+
// filtering here would only hide rows the user needs to act on. Runtime
5336+
// lookups apply their own status='active' filter and don't go through
5337+
// this method.
5338+
func (s *RDBConfigStore) ListMCPPerUserHeaderCredentials(ctx context.Context, params MCPSessionsFilterParams) ([]tables.TableMCPPerUserHeaderCredential, error) {
5339+
query := s.ScopedDB(ctx).Model(&tables.TableMCPPerUserHeaderCredential{})
5340+
query = applyMCPSessionFilters(query, params, mcpSessionFilterTable{
5341+
table: "mcp_per_user_header_credentials",
5342+
authModeColumn: "auth_mode",
5343+
})
53265344
var creds []tables.TableMCPPerUserHeaderCredential
5327-
if err := s.ScopedDB(ctx).
5345+
if err := query.
53285346
Preload("MCPClient", func(db *gorm.DB) *gorm.DB { return db.Select("client_id, name") }).
53295347
Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { return db.Select("id, name") }).
5330-
Order("created_at DESC").
5348+
Order("mcp_per_user_header_credentials.created_at DESC").
53315349
Find(&creds).Error; err != nil {
5332-
return nil, fmt.Errorf("failed to list all mcp per-user header credentials: %w", err)
5350+
return nil, fmt.Errorf("failed to list mcp per-user header credentials: %w", err)
53335351
}
53345352
return creds, nil
53355353
}
@@ -5485,23 +5503,69 @@ func (s *RDBConfigStore) DeleteMCPPerUserHeaderFlowsByModeIdentityAndMCPClient(c
54855503
return nil
54865504
}
54875505

5488-
// ListAllPendingMCPPerUserHeaderFlows returns all pending header-submission
5489-
// flow rows whose expiry is in the future. Uses ScopedDB so a query-scope
5490-
// stashed on ctx (if any) narrows the result; otherwise returns every row.
5491-
// Mirrors ListAllPendingOauthUserSessions.
5492-
func (s *RDBConfigStore) ListAllPendingMCPPerUserHeaderFlows(ctx context.Context) ([]tables.TableMCPPerUserHeaderFlow, error) {
5506+
// ListPendingMCPPerUserHeaderFlows returns pending header-submission flow rows
5507+
// matching params whose expiry is in the future. Uses ScopedDB so a
5508+
// query-scope stashed on ctx (if any) narrows the result; otherwise returns
5509+
// every matching pending row. Mirrors ListPendingOauthUserSessions.
5510+
func (s *RDBConfigStore) ListPendingMCPPerUserHeaderFlows(ctx context.Context, params MCPSessionsFilterParams) ([]tables.TableMCPPerUserHeaderFlow, error) {
5511+
query := s.ScopedDB(ctx).Model(&tables.TableMCPPerUserHeaderFlow{}).
5512+
Where("mcp_per_user_header_flows.status = ? AND mcp_per_user_header_flows.expires_at > ?", "pending", time.Now())
5513+
query = applyMCPSessionFilters(query, params, mcpSessionFilterTable{
5514+
table: "mcp_per_user_header_flows",
5515+
authModeColumn: "flow_mode",
5516+
})
54935517
var flows []tables.TableMCPPerUserHeaderFlow
5494-
if err := s.ScopedDB(ctx).
5518+
if err := query.
54955519
Preload("MCPClient", func(db *gorm.DB) *gorm.DB { return db.Select("client_id, name") }).
54965520
Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { return db.Select("id, name") }).
5497-
Where("status = ? AND expires_at > ?", "pending", time.Now()).
5498-
Order("created_at DESC").
5521+
Order("mcp_per_user_header_flows.created_at DESC").
54995522
Find(&flows).Error; err != nil {
5500-
return nil, fmt.Errorf("failed to list all pending mcp per-user header flows: %w", err)
5523+
return nil, fmt.Errorf("failed to list pending mcp per-user header flows: %w", err)
55015524
}
55025525
return flows, nil
55035526
}
55045527

5528+
// mcpSessionFilterTable carries the table-specific column names needed to
5529+
// build a generic filter chain. The auth-mode column is named differently
5530+
// on the credential tables ("auth_mode") and the pending-flow tables
5531+
// ("flow_mode"), but the value space is identical.
5532+
type mcpSessionFilterTable struct {
5533+
table string
5534+
authModeColumn string // "auth_mode" or "flow_mode"
5535+
}
5536+
5537+
// applyMCPSessionFilters appends the shared MCP-sessions WHERE clauses and
5538+
// the search LEFT JOINs to a query. The search JOINs (config_mcp_clients,
5539+
// governance_virtual_keys) and the LIKE WHERE are emitted only when
5540+
// params.Search is non-empty; when absent the columns are never referenced
5541+
// and no JOIN is added. The join cardinality is 1:1 on FK columns, so
5542+
// Count is safe without DISTINCT.
5543+
func applyMCPSessionFilters(query *gorm.DB, params MCPSessionsFilterParams, t mcpSessionFilterTable) *gorm.DB {
5544+
if len(params.Statuses) > 0 {
5545+
query = query.Where(t.table+".status IN ?", params.Statuses)
5546+
}
5547+
if len(params.AuthModes) > 0 {
5548+
query = query.Where(t.table+"."+t.authModeColumn+" IN ?", params.AuthModes)
5549+
}
5550+
if len(params.MCPClientIDs) > 0 {
5551+
query = query.Where(t.table+".mcp_client_id IN ?", params.MCPClientIDs)
5552+
}
5553+
if params.Search != "" {
5554+
needle := "%" + strings.ToLower(params.Search) + "%"
5555+
query = query.
5556+
Joins("LEFT JOIN config_mcp_clients ON config_mcp_clients.client_id = " + t.table + ".mcp_client_id").
5557+
Joins("LEFT JOIN governance_virtual_keys ON governance_virtual_keys.id = " + t.table + ".virtual_key_id")
5558+
whereClause := "LOWER(config_mcp_clients.name) LIKE ? OR LOWER(config_mcp_clients.client_id) LIKE ? OR LOWER(" + t.table + ".user_id) LIKE ? OR LOWER(" + t.table + ".session_id) LIKE ? OR LOWER(governance_virtual_keys.id) LIKE ? OR LOWER(governance_virtual_keys.name) LIKE ?"
5559+
whereArgs := []any{needle, needle, needle, needle, needle, needle}
5560+
if len(params.MatchedUserIDs) > 0 {
5561+
whereClause += " OR " + t.table + ".user_id IN ?"
5562+
whereArgs = append(whereArgs, params.MatchedUserIDs)
5563+
}
5564+
query = query.Where(whereClause, whereArgs...)
5565+
}
5566+
return query
5567+
}
5568+
55055569
// DeleteExpiredMCPPerUserHeaderFlows hard-deletes pending flow rows whose
55065570
// ExpiresAt has passed. Status filter excludes already-completed rows
55075571
// (which the submit path deletes immediately anyway).

0 commit comments

Comments
 (0)