Skip to content

Commit 2c05f60

Browse files
feat: backend for per user headers mcp auth added
1 parent f04477f commit 2c05f60

36 files changed

Lines changed: 2863 additions & 114 deletions

core/bifrost.go

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
233233
requestQueues: sync.Map{},
234234
waitGroups: sync.Map{},
235235
keySelector: config.KeySelector,
236-
mcpCredStore: credstore.NewCredStore(config.OAuth2Provider, config.Logger),
236+
mcpCredStore: credstore.NewCredStore(config.OAuth2Provider, config.MCPHeadersProvider, config.Logger),
237237
logger: config.Logger,
238238
kvStore: config.KVStore,
239239
}
@@ -3824,6 +3824,34 @@ func (bifrost *Bifrost) VerifyPerUserOAuthConnection(ctx context.Context, config
38243824
return bifrost.MCPManager.VerifyPerUserOAuthConnection(ctx, config, accessToken)
38253825
}
38263826

3827+
// VerifyHeadersConnection delegates to the MCP manager to verify an MCP
3828+
// server using caller-supplied header values (admin sample or user-submitted)
3829+
// and discover available tools. Mirrors VerifyPerUserOAuthConnection's lazy
3830+
// MCP-manager init.
3831+
func (bifrost *Bifrost) VerifyHeadersConnection(ctx context.Context, config *schemas.MCPClientConfig, userHeaders map[string]string) (map[string]schemas.ChatTool, map[string]string, error) {
3832+
if bifrost.MCPManager == nil {
3833+
bifrost.mcpInitOnce.Do(func() {
3834+
mcpConfig := schemas.MCPConfig{
3835+
ClientConfigs: []*schemas.MCPClientConfig{},
3836+
}
3837+
mcpConfig.PluginPipelineProvider = func() interface{} {
3838+
return bifrost.getPluginPipeline()
3839+
}
3840+
mcpConfig.ReleasePluginPipeline = func(pipeline interface{}) {
3841+
if pp, ok := pipeline.(*PluginPipeline); ok {
3842+
bifrost.releasePluginPipeline(pp)
3843+
}
3844+
}
3845+
codeMode := starlark.NewStarlarkCodeMode(nil, bifrost.logger)
3846+
bifrost.MCPManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.mcpCredStore, bifrost.logger, codeMode)
3847+
})
3848+
}
3849+
if bifrost.MCPManager == nil {
3850+
return nil, nil, fmt.Errorf("MCP manager is not initialized")
3851+
}
3852+
return bifrost.MCPManager.VerifyHeadersConnection(ctx, config, userHeaders)
3853+
}
3854+
38273855
// SetClientTools delegates to the MCP manager to update the tool map for an
38283856
// existing MCP client.
38293857
func (bifrost *Bifrost) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) {

core/mcp/agent.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ func (a *AgentModeExecutor) executeAgent(
286286
wg := sync.WaitGroup{}
287287
wg.Add(len(autoExecutableTools))
288288
channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools))
289-
var authRequiredErr *schemas.MCPUserOAuthRequiredError
289+
var authRequiredErr *schemas.MCPAuthRequiredError
290290
var authRequiredOnce sync.Once
291291
for _, toolCall := range autoExecutableTools {
292292
go func(toolCall schemas.ChatAssistantMessageToolCall) {
@@ -304,11 +304,11 @@ func (a *AgentModeExecutor) executeAgent(
304304

305305
mcpResponse, toolErr := executeToolFunc(toolCtx, mcpRequest)
306306
if toolErr != nil {
307-
// Check if this is a per-user OAuth auth-required error
308-
var oauthErr *schemas.MCPUserOAuthRequiredError
309-
if errors.As(toolErr, &oauthErr) {
307+
// Check if this is a per-user auth-required error
308+
var authErr *schemas.MCPAuthRequiredError
309+
if errors.As(toolErr, &authErr) {
310310
authRequiredOnce.Do(func() {
311-
authRequiredErr = oauthErr
311+
authRequiredErr = authErr
312312
})
313313
channelToolResults <- createToolResultMessage(toolCall, "", toolErr)
314314
return

core/mcp/clientmanager.go

Lines changed: 154 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,20 @@ func (m *MCPManager) AcquireClientConn(ctx *schemas.BifrostContext, state *schem
6666
// Closure-captured outputs from the op so the caller can CallTool on the
6767
// live client after the gate returns.
6868
var tempClient *client.Client
69-
// MCPUserOAuthRequiredError is wrapped into a generic BifrostError by the
69+
// MCPAuthRequiredError is wrapped into a generic BifrostError by the
7070
// pipeline before PostConnectionHook runs, so capture it out-of-band to
71-
// preserve the typed-error info for the envelope path.
72-
var oauthErr *schemas.MCPUserOAuthRequiredError
71+
// preserve the typed-error info for the envelope path. Same capture
72+
// covers both per-user-OAuth (Kind=oauth) and per-user-headers
73+
// (Kind=headers) surfaces.
74+
var authRequiredErr *schemas.MCPAuthRequiredError
7375
start := time.Now()
7476

7577
_, gateErr := m.runConnectWithPluginPipeline(ctx, connectReq, func(preReq *schemas.BifrostMCPConnectRequest) (*schemas.BifrostMCPConnectResponse, error) {
7678
// Resolve auth headers AFTER PreConnectionHook ran. Plugins never see
7779
// the Authorization header — it lives only on the wire transport.
7880
authHeaders, credErr := m.credStore.ConnectionHeaders(ctx, config)
7981
if credErr != nil {
80-
errors.As(credErr, &oauthErr)
82+
errors.As(credErr, &authRequiredErr)
8183
return nil, credErr
8284
}
8385

@@ -158,8 +160,8 @@ func (m *MCPManager) AcquireClientConn(ctx *schemas.BifrostContext, state *schem
158160
if tempClient != nil {
159161
_ = tempClient.Close()
160162
}
161-
if oauthErr != nil {
162-
return nil, nil, oauthErr
163+
if authRequiredErr != nil {
164+
return nil, nil, authRequiredErr
163165
}
164166
if gateErr.Error != nil {
165167
return nil, nil, fmt.Errorf("%s", gateErr.Error.Message)
@@ -326,17 +328,20 @@ func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error {
326328
url := config.ConnectionString.GetValue()
327329
client.ConnectionInfo.ConnectionURL = &url
328330
}
329-
// Restore discovered tools from config (persisted in DB across restarts)
331+
// Restore discovered tools from config (persisted in DB across restarts).
332+
// Applies to every per-call-connection auth type — currently per-user
333+
// OAuth and per-user headers — since both populate DiscoveredTools at
334+
// admin-test time and never hold a persistent client.Conn.
330335
if len(config.DiscoveredTools) > 0 {
331336
for toolName, tool := range config.DiscoveredTools {
332337
client.ToolMap[toolName] = tool
333338
}
334339
client.ToolNameMapping = config.DiscoveredToolNameMapping
335340
client.State = schemas.MCPConnectionStateConnected
336-
m.logger.Debug("%s Per-user OAuth MCP client '%s' restored with %d tools", MCPLogPrefix, config.Name, len(config.DiscoveredTools))
341+
m.logger.Debug("%s Per-user (%s) MCP client '%s' restored with %d tools", MCPLogPrefix, config.AuthType, config.Name, len(config.DiscoveredTools))
337342
} else {
338343
client.State = schemas.MCPConnectionStatePendingTools
339-
m.logger.Debug("%s Per-user OAuth MCP client '%s' registered (connection deferred to runtime)", MCPLogPrefix, config.Name)
344+
m.logger.Debug("%s Per-user (%s) MCP client '%s' registered (connection deferred to runtime)", MCPLogPrefix, config.AuthType, config.Name)
340345
}
341346
}
342347
m.mu.Unlock()
@@ -492,6 +497,145 @@ func (m *MCPManager) VerifyPerUserOAuthConnection(ctx context.Context, config *s
492497
return tools, toolNameMapping, nil
493498
}
494499

500+
// VerifyHeadersConnection creates a temporary MCP connection using the
501+
// provided user-submitted header values to verify the server is reachable
502+
// and discover available tools. The connection is closed after verification.
503+
//
504+
// Used in two paths:
505+
// - Admin test flow: admin enters sample values during MCP client creation,
506+
// this runs an Initialize handshake against the upstream to validate the
507+
// schema (PerUserHeaderKeys) + discover tools. The discovered tools then
508+
// persist on the MCPClient row; the sample values are discarded.
509+
// - User submission flow: an end user submits their own values via the
510+
// workspace submit URL surfaced inline by MCPAuthRequiredError. The
511+
// handler runs this before upserting the row so a bad submission returns
512+
// 422 immediately instead of failing on the next tool call.
513+
//
514+
// Parameters:
515+
// - config: MCP client configuration (connection URL, name, PerUserHeaderKeys, etc.)
516+
// - userHeaders: caller-supplied header_name → value map (must cover every
517+
// PerUserHeaderKeys entry; the caller validates that before invoking).
518+
//
519+
// Returns:
520+
// - map[string]schemas.ChatTool: discovered tools keyed by prefixed name
521+
// - map[string]string: tool name mapping (sanitized → original MCP name)
522+
// - error: any error during verification
523+
func (m *MCPManager) VerifyHeadersConnection(ctx context.Context, config *schemas.MCPClientConfig, userHeaders map[string]string) (map[string]schemas.ChatTool, map[string]string, error) {
524+
if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" {
525+
return nil, nil, fmt.Errorf("connection URL is required for per-user headers verification")
526+
}
527+
if len(userHeaders) == 0 {
528+
return nil, nil, fmt.Errorf("user headers are required for per-user headers verification")
529+
}
530+
531+
// Build prepared inputs for the typed connect plugin gate. Static admin
532+
// headers (minus Authorization and minus any PerUserHeaderKeys) are
533+
// plugin-visible; user-supplied credentials are layered AFTER PreHooks
534+
// run so plugins cannot read or rewrite them. Mirrors
535+
// VerifyPerUserOAuthConnection's Authorization-injection pattern.
536+
url := config.ConnectionString.GetValue()
537+
preparedHeaders := utils.FlattenHeaders(utils.StaticConfigHeaders(config))
538+
connectReq := &schemas.BifrostMCPConnectRequest{
539+
ClientName: config.Name,
540+
ConnectionType: schemas.MCPConnectionTypeHTTP,
541+
AuthType: config.AuthType,
542+
ConnectionString: &url,
543+
Headers: preparedHeaders,
544+
}
545+
546+
verifyCtx, cancel := context.WithTimeout(ctx, MCPClientConnectionEstablishTimeout)
547+
defer cancel()
548+
gateCtx := schemas.NewBifrostContext(verifyCtx, schemas.NoDeadline)
549+
550+
var tempClient *client.Client
551+
defer func() {
552+
if tempClient != nil {
553+
tempClient.Close()
554+
}
555+
}()
556+
start := time.Now()
557+
558+
_, gateErr := m.runConnectWithPluginPipeline(gateCtx, connectReq, func(preReq *schemas.BifrostMCPConnectRequest) (*schemas.BifrostMCPConnectResponse, error) {
559+
finalURL := url
560+
if preReq.ConnectionString != nil {
561+
finalURL = *preReq.ConnectionString
562+
}
563+
564+
// Copy mutated headers, then layer the user's credential values on
565+
// top. Copying (rather than mutating preReq.Headers in place) avoids
566+
// leaking the values back into the request that PreHook plugins may
567+
// still reference.
568+
finalHeaders := make(map[string]string, len(preReq.Headers)+len(userHeaders))
569+
maps.Copy(finalHeaders, preReq.Headers)
570+
for k, v := range userHeaders {
571+
finalHeaders[k] = v
572+
}
573+
574+
httpTransport, hErr := transport.NewStreamableHTTP(finalURL, transport.WithHTTPHeaders(finalHeaders))
575+
if hErr != nil {
576+
return nil, fmt.Errorf("failed to create HTTP transport for verification: %w", hErr)
577+
}
578+
tempClient = client.NewClient(httpTransport)
579+
if startErr := tempClient.Start(verifyCtx); startErr != nil {
580+
return nil, fmt.Errorf("failed to start MCP connection for verification: %w", startErr)
581+
}
582+
583+
initRequest := mcp.InitializeRequest{
584+
Params: mcp.InitializeParams{
585+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
586+
Capabilities: mcp.ClientCapabilities{},
587+
ClientInfo: mcp.Implementation{
588+
Name: fmt.Sprintf("Bifrost-%s-verify", config.Name),
589+
Version: "1.0.0",
590+
},
591+
},
592+
}
593+
initResult, initErr := tempClient.Initialize(verifyCtx, initRequest)
594+
if initErr != nil {
595+
return nil, fmt.Errorf("failed to initialize MCP connection for verification: %w", initErr)
596+
}
597+
598+
resp := &schemas.BifrostMCPConnectResponse{
599+
ConnectionInfo: &schemas.MCPClientConnectionInfo{
600+
Type: schemas.MCPConnectionTypeHTTP,
601+
ConnectionURL: &finalURL,
602+
},
603+
ExtraFields: schemas.BifrostMCPResponseExtraFields{
604+
Latency: time.Since(start).Milliseconds(),
605+
},
606+
}
607+
if initResult != nil {
608+
resp.ProtocolVersion = initResult.ProtocolVersion
609+
resp.ServerInfo = &schemas.MCPServerInfo{
610+
Name: initResult.ServerInfo.Name,
611+
Version: initResult.ServerInfo.Version,
612+
}
613+
resp.ServerCapabilities = &schemas.MCPServerCapabilities{
614+
Tools: initResult.Capabilities.Tools != nil,
615+
Resources: initResult.Capabilities.Resources != nil,
616+
Prompts: initResult.Capabilities.Prompts != nil,
617+
Logging: initResult.Capabilities.Logging != nil,
618+
}
619+
}
620+
return resp, nil
621+
})
622+
623+
if gateErr != nil {
624+
return nil, nil, fmt.Errorf("failed to verify MCP connection: %s", gateErr.GetErrorString())
625+
}
626+
if tempClient == nil {
627+
return nil, nil, fmt.Errorf("headers verification was short-circuited by plugin; cannot discover tools without a live connection")
628+
}
629+
630+
tools, toolNameMapping, err := m.runListToolsWithHooks(verifyCtx, tempClient, config.Name)
631+
if err != nil {
632+
return nil, nil, fmt.Errorf("failed to discover tools during verification: %w", err)
633+
}
634+
635+
m.logger.Info("%s Per-user headers verification succeeded for '%s': discovered %d tools", MCPLogPrefix, config.Name, len(tools))
636+
return tools, toolNameMapping, nil
637+
}
638+
495639
// SetClientTools updates the tool map and name mapping for an existing client.
496640
// This is used to populate tools discovered during per-user OAuth verification,
497641
// where tool discovery happens separately from client creation.
@@ -798,6 +942,7 @@ func (m *MCPManager) UpdateClient(id string, updatedConfig *schemas.MCPClientCon
798942
ToolSyncInterval: updatedConfig.ToolSyncInterval,
799943
AllowOnAllVirtualKeys: updatedConfig.AllowOnAllVirtualKeys,
800944
Disabled: updatedConfig.Disabled,
945+
PerUserHeaderKeys: slices.Clone(updatedConfig.PerUserHeaderKeys),
801946
}
802947

803948
// Atomically replace the config pointer

core/mcp/codemode/starlark/executecode.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName,
528528
// Acquire a connection through the shared ClientManager abstraction:
529529
// shared-mode clients return their persistent state.Conn (release is a
530530
// no-op); per-user clients get a fresh ephemeral transport that the
531-
// release function closes. Credential errors (e.g. MCPUserOAuthRequiredError)
531+
// release function closes. Credential errors (e.g. MCPAuthRequiredError)
532532
// surface here.
533533
conn, release, err := s.clientManager.AcquireClientConn(nestedCtx, client)
534534
if err != nil {

core/mcp/credstore/credstore.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,18 @@ type CredStore struct {
3131

3232
// NewCredStore constructs the canonical MCPCredentialStore with one resolver
3333
// per known MCPAuthType. The oauth2Provider is injected into the OAuth-
34-
// flavored resolvers only; the None and StaticHeaders resolvers are stateless.
35-
func NewCredStore(oauth2Provider schemas.OAuth2Provider, logger schemas.Logger) *CredStore {
34+
// flavored resolvers only; the None and StaticHeaders resolvers are
35+
// stateless. The headersProvider is injected into the per-user-headers
36+
// resolver — pass nil if the configstore-backed provider isn't wired up
37+
// (the resolver returns a clear error rather than nil-pointering at use).
38+
func NewCredStore(oauth2Provider schemas.OAuth2Provider, headersProvider schemas.MCPHeadersProvider, logger schemas.Logger) *CredStore {
3639
return &CredStore{
3740
resolvers: map[schemas.MCPAuthType]resolver{
38-
schemas.MCPAuthTypeNone: &noneResolver{},
39-
schemas.MCPAuthTypeHeaders: &staticHeadersResolver{},
40-
schemas.MCPAuthTypeOauth: &serverOAuthResolver{provider: oauth2Provider},
41-
schemas.MCPAuthTypePerUserOauth: &perUserOAuthResolver{provider: oauth2Provider},
41+
schemas.MCPAuthTypeNone: &noneResolver{},
42+
schemas.MCPAuthTypeHeaders: &staticHeadersResolver{},
43+
schemas.MCPAuthTypeOauth: &serverOAuthResolver{provider: oauth2Provider},
44+
schemas.MCPAuthTypePerUserOauth: &perUserOAuthResolver{provider: oauth2Provider},
45+
schemas.MCPAuthTypePerUserHeaders: &perUserHeadersResolver{provider: headersProvider},
4246
},
4347
logger: logger,
4448
}

0 commit comments

Comments
 (0)