Skip to content

Commit 2638b13

Browse files
authored
feat: add context.Context to database client interface (#1489)
## Description Fixes: #1488 - Add context.Context as the first parameter to all Client interface methods in the database layer - Update `gormClient`, `fakeClient`, and all HTTP handler/controller call sites to pass context through - Enables request cancellation and tracing to propagate from HTTP handlers and the controller reconciler down to database operations # Testing - Existing unit tests updated to pass `context.Background()` — all pass locally - Fake client updated to match new interface signatures - No behavioral changes; purely a mechanical refactor to follow Go context idioms Signed-off-by: Jeremy Alvis <jeremy.alvis@solo.io>
1 parent b43cf79 commit 2638b13

File tree

17 files changed

+331
-310
lines changed

17 files changed

+331
-310
lines changed

go/api/database/client.go

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package database
22

33
import (
4+
"context"
45
"time"
56

67
"github.com/kagent-dev/kagent/go/api/v1alpha2"
@@ -20,63 +21,63 @@ type LangGraphCheckpointTuple struct {
2021

2122
type Client interface {
2223
// Store methods
23-
StoreFeedback(feedback *Feedback) error
24-
StoreSession(session *Session) error
25-
StoreAgent(agent *Agent) error
26-
StoreTask(task *protocol.Task) error
27-
StorePushNotification(config *protocol.TaskPushNotificationConfig) error
28-
StoreToolServer(toolServer *ToolServer) (*ToolServer, error)
29-
StoreEvents(messages ...*Event) error
24+
StoreFeedback(ctx context.Context, feedback *Feedback) error
25+
StoreSession(ctx context.Context, session *Session) error
26+
StoreAgent(ctx context.Context, agent *Agent) error
27+
StoreTask(ctx context.Context, task *protocol.Task) error
28+
StorePushNotification(ctx context.Context, config *protocol.TaskPushNotificationConfig) error
29+
StoreToolServer(ctx context.Context, toolServer *ToolServer) (*ToolServer, error)
30+
StoreEvents(ctx context.Context, messages ...*Event) error
3031

3132
// Delete methods
32-
DeleteSession(sessionID string, userID string) error
33-
DeleteAgent(agentID string) error
34-
DeleteToolServer(serverName string, groupKind string) error
35-
DeleteTask(taskID string) error
36-
DeletePushNotification(taskID string) error
37-
DeleteToolsForServer(serverName string, groupKind string) error
33+
DeleteSession(ctx context.Context, sessionID string, userID string) error
34+
DeleteAgent(ctx context.Context, agentID string) error
35+
DeleteToolServer(ctx context.Context, serverName string, groupKind string) error
36+
DeleteTask(ctx context.Context, taskID string) error
37+
DeletePushNotification(ctx context.Context, taskID string) error
38+
DeleteToolsForServer(ctx context.Context, serverName string, groupKind string) error
3839

3940
// Get methods
40-
GetSession(sessionID string, userID string) (*Session, error)
41-
GetAgent(name string) (*Agent, error)
42-
GetTask(id string) (*protocol.Task, error)
43-
GetTool(name string) (*Tool, error)
44-
GetToolServer(name string) (*ToolServer, error)
45-
GetPushNotification(taskID string, configID string) (*protocol.TaskPushNotificationConfig, error)
41+
GetSession(ctx context.Context, sessionID string, userID string) (*Session, error)
42+
GetAgent(ctx context.Context, name string) (*Agent, error)
43+
GetTask(ctx context.Context, id string) (*protocol.Task, error)
44+
GetTool(ctx context.Context, name string) (*Tool, error)
45+
GetToolServer(ctx context.Context, name string) (*ToolServer, error)
46+
GetPushNotification(ctx context.Context, taskID string, configID string) (*protocol.TaskPushNotificationConfig, error)
4647

4748
// List methods
48-
ListTools() ([]Tool, error)
49-
ListFeedback(userID string) ([]Feedback, error)
50-
ListTasksForSession(sessionID string) ([]*protocol.Task, error)
51-
ListSessions(userID string) ([]Session, error)
52-
ListSessionsForAgent(agentID string, userID string) ([]Session, error)
53-
ListAgents() ([]Agent, error)
54-
ListToolServers() ([]ToolServer, error)
55-
ListToolsForServer(serverName string, groupKind string) ([]Tool, error)
56-
ListEventsForSession(sessionID, userID string, options QueryOptions) ([]*Event, error)
57-
ListPushNotifications(taskID string) ([]*protocol.TaskPushNotificationConfig, error)
49+
ListTools(ctx context.Context) ([]Tool, error)
50+
ListFeedback(ctx context.Context, userID string) ([]Feedback, error)
51+
ListTasksForSession(ctx context.Context, sessionID string) ([]*protocol.Task, error)
52+
ListSessions(ctx context.Context, userID string) ([]Session, error)
53+
ListSessionsForAgent(ctx context.Context, agentID string, userID string) ([]Session, error)
54+
ListAgents(ctx context.Context) ([]Agent, error)
55+
ListToolServers(ctx context.Context) ([]ToolServer, error)
56+
ListToolsForServer(ctx context.Context, serverName string, groupKind string) ([]Tool, error)
57+
ListEventsForSession(ctx context.Context, sessionID, userID string, options QueryOptions) ([]*Event, error)
58+
ListPushNotifications(ctx context.Context, taskID string) ([]*protocol.TaskPushNotificationConfig, error)
5859

5960
// Helper methods
60-
RefreshToolsForServer(serverName string, groupKind string, tools ...*v1alpha2.MCPTool) error
61+
RefreshToolsForServer(ctx context.Context, serverName string, groupKind string, tools ...*v1alpha2.MCPTool) error
6162

6263
// LangGraph Checkpoint methods
63-
StoreCheckpoint(checkpoint *LangGraphCheckpoint) error
64-
StoreCheckpointWrites(writes []*LangGraphCheckpointWrite) error
65-
ListCheckpoints(userID, threadID, checkpointNS string, checkpointID *string, limit int) ([]*LangGraphCheckpointTuple, error)
66-
DeleteCheckpoint(userID, threadID string) error
64+
StoreCheckpoint(ctx context.Context, checkpoint *LangGraphCheckpoint) error
65+
StoreCheckpointWrites(ctx context.Context, writes []*LangGraphCheckpointWrite) error
66+
ListCheckpoints(ctx context.Context, userID, threadID, checkpointNS string, checkpointID *string, limit int) ([]*LangGraphCheckpointTuple, error)
67+
DeleteCheckpoint(ctx context.Context, userID, threadID string) error
6768

6869
// CrewAI methods
69-
StoreCrewAIMemory(memory *CrewAIAgentMemory) error
70-
SearchCrewAIMemoryByTask(userID, threadID, taskDescription string, limit int) ([]*CrewAIAgentMemory, error)
71-
ResetCrewAIMemory(userID, threadID string) error
72-
StoreCrewAIFlowState(state *CrewAIFlowState) error
73-
GetCrewAIFlowState(userID, threadID string) (*CrewAIFlowState, error)
70+
StoreCrewAIMemory(ctx context.Context, memory *CrewAIAgentMemory) error
71+
SearchCrewAIMemoryByTask(ctx context.Context, userID, threadID, taskDescription string, limit int) ([]*CrewAIAgentMemory, error)
72+
ResetCrewAIMemory(ctx context.Context, userID, threadID string) error
73+
StoreCrewAIFlowState(ctx context.Context, state *CrewAIFlowState) error
74+
GetCrewAIFlowState(ctx context.Context, userID, threadID string) (*CrewAIFlowState, error)
7475

7576
// Agent memory (vector search) methods
76-
StoreAgentMemory(memory *Memory) error
77-
StoreAgentMemories(memories []*Memory) error
78-
SearchAgentMemory(agentName, userID string, embedding pgvector.Vector, limit int) ([]AgentMemorySearchResult, error)
79-
ListAgentMemories(agentName, userID string) ([]Memory, error)
80-
DeleteAgentMemory(agentName, userID string) error
81-
PruneExpiredMemories() error
77+
StoreAgentMemory(ctx context.Context, memory *Memory) error
78+
StoreAgentMemories(ctx context.Context, memories []*Memory) error
79+
SearchAgentMemory(ctx context.Context, agentName, userID string, embedding pgvector.Vector, limit int) ([]AgentMemorySearchResult, error)
80+
ListAgentMemories(ctx context.Context, agentName, userID string) ([]Memory, error)
81+
DeleteAgentMemory(ctx context.Context, agentName, userID string) error
82+
PruneExpiredMemories(ctx context.Context) error
8283
}

go/core/internal/controller/reconciler/reconciler.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func (a *kagentReconciler) ReconcileKagentAgent(ctx context.Context, req ctrl.Re
8585
agent := &v1alpha2.Agent{}
8686
if err := a.kube.Get(ctx, req.NamespacedName, agent); err != nil {
8787
if apierrors.IsNotFound(err) {
88-
return a.handleAgentDeletion(req)
88+
return a.handleAgentDeletion(ctx, req)
8989
}
9090

9191
return fmt.Errorf("failed to get agent %s: %w", req.NamespacedName, err)
@@ -99,9 +99,9 @@ func (a *kagentReconciler) ReconcileKagentAgent(ctx context.Context, req ctrl.Re
9999
return a.reconcileAgentStatus(ctx, agent, err)
100100
}
101101

102-
func (a *kagentReconciler) handleAgentDeletion(req ctrl.Request) error {
102+
func (a *kagentReconciler) handleAgentDeletion(ctx context.Context, req ctrl.Request) error {
103103
id := utils.ConvertToPythonIdentifier(req.String())
104-
if err := a.dbClient.DeleteAgent(id); err != nil {
104+
if err := a.dbClient.DeleteAgent(ctx, id); err != nil {
105105
return fmt.Errorf("failed to delete agent %s: %w",
106106
req.String(), err)
107107
}
@@ -205,11 +205,11 @@ func (a *kagentReconciler) ReconcileKagentMCPService(ctx context.Context, req ct
205205
Name: req.String(),
206206
GroupKind: schema.GroupKind{Group: "", Kind: "Service"}.String(),
207207
}
208-
if err := a.dbClient.DeleteToolServer(dbService.Name, dbService.GroupKind); err != nil {
208+
if err := a.dbClient.DeleteToolServer(ctx, dbService.Name, dbService.GroupKind); err != nil {
209209
reconcileLog.Error(err, "failed to delete tool server for mcp service", "service", req.String())
210210
}
211211
reconcileLog.Info("mcp service was deleted", "service", req.String())
212-
if err := a.dbClient.DeleteToolsForServer(dbService.Name, dbService.GroupKind); err != nil {
212+
if err := a.dbClient.DeleteToolsForServer(ctx, dbService.Name, dbService.GroupKind); err != nil {
213213
reconcileLog.Error(err, "failed to delete tools for mcp service", "service", req.String())
214214
}
215215
return nil
@@ -378,11 +378,11 @@ func (a *kagentReconciler) ReconcileKagentMCPServer(ctx context.Context, req ctr
378378
Name: req.String(),
379379
GroupKind: schema.GroupKind{Group: "kagent.dev", Kind: "MCPServer"}.String(),
380380
}
381-
if err := a.dbClient.DeleteToolServer(dbServer.Name, dbServer.GroupKind); err != nil {
381+
if err := a.dbClient.DeleteToolServer(ctx, dbServer.Name, dbServer.GroupKind); err != nil {
382382
reconcileLog.Error(err, "failed to delete tool server for mcp server", "mcpServer", req.String())
383383
}
384384
reconcileLog.Info("mcp server was deleted", "mcpServer", req.String())
385-
if err := a.dbClient.DeleteToolsForServer(dbServer.Name, dbServer.GroupKind); err != nil {
385+
if err := a.dbClient.DeleteToolsForServer(ctx, dbServer.Name, dbServer.GroupKind); err != nil {
386386
reconcileLog.Error(err, "failed to delete tools for mcp server", "mcpServer", req.String())
387387
}
388388
return nil
@@ -428,11 +428,11 @@ func (a *kagentReconciler) ReconcileKagentRemoteMCPServer(ctx context.Context, r
428428
GroupKind: schema.GroupKind{Group: "kagent.dev", Kind: "RemoteMCPServer"}.String(),
429429
}
430430

431-
if err := a.dbClient.DeleteToolServer(dbServer.Name, dbServer.GroupKind); err != nil {
431+
if err := a.dbClient.DeleteToolServer(ctx, dbServer.Name, dbServer.GroupKind); err != nil {
432432
l.Error(err, "failed to delete tool server for remote mcp server")
433433
}
434434

435-
if err := a.dbClient.DeleteToolsForServer(dbServer.Name, dbServer.GroupKind); err != nil {
435+
if err := a.dbClient.DeleteToolsForServer(ctx, dbServer.Name, dbServer.GroupKind); err != nil {
436436
l.Error(err, "failed to delete tools for remote mcp server")
437437
}
438438

@@ -835,15 +835,15 @@ func (a *kagentReconciler) upsertAgent(ctx context.Context, agent *v1alpha2.Agen
835835
Config: agentOutputs.Config,
836836
}
837837

838-
if err := a.dbClient.StoreAgent(dbAgent); err != nil {
838+
if err := a.dbClient.StoreAgent(ctx, dbAgent); err != nil {
839839
return fmt.Errorf("failed to store agent %s: %w", id, err)
840840
}
841841

842842
return nil
843843
}
844844

845845
func (a *kagentReconciler) upsertToolServerForRemoteMCPServer(ctx context.Context, toolServer *database.ToolServer, remoteMcpServer *v1alpha2.RemoteMCPServer) ([]*v1alpha2.MCPTool, error) {
846-
if _, err := a.dbClient.StoreToolServer(toolServer); err != nil {
846+
if _, err := a.dbClient.StoreToolServer(ctx, toolServer); err != nil {
847847
return nil, fmt.Errorf("failed to store toolServer %s: %w", toolServer.Name, err)
848848
}
849849

@@ -858,7 +858,7 @@ func (a *kagentReconciler) upsertToolServerForRemoteMCPServer(ctx context.Contex
858858
}
859859

860860
// Refresh tools in database - uses transaction for atomicity
861-
if err := a.dbClient.RefreshToolsForServer(toolServer.Name, toolServer.GroupKind, tools...); err != nil {
861+
if err := a.dbClient.RefreshToolsForServer(ctx, toolServer.Name, toolServer.GroupKind, tools...); err != nil {
862862
return nil, fmt.Errorf("failed to refresh tools for toolServer %s: %w", toolServer.Name, err)
863863
}
864864

@@ -956,7 +956,7 @@ func (a *kagentReconciler) listTools(ctx context.Context, tsp mcp.Transport, too
956956

957957
func (a *kagentReconciler) getDiscoveredMCPTools(ctx context.Context, serverRef string) ([]*v1alpha2.MCPTool, error) {
958958
// This function is currently only used for RemoteMCPServer
959-
allTools, err := a.dbClient.ListToolsForServer(serverRef, schema.GroupKind{Group: "kagent.dev", Kind: "RemoteMCPServer"}.String())
959+
allTools, err := a.dbClient.ListToolsForServer(ctx, serverRef, schema.GroupKind{Group: "kagent.dev", Kind: "RemoteMCPServer"}.String())
960960
if err != nil {
961961
return nil, err
962962
}

0 commit comments

Comments
 (0)