@@ -17,15 +17,21 @@ import (
17
17
mcpclient "github.com/docker/mcp-gateway/cmd/docker-mcp/internal/mcp"
18
18
)
19
19
20
+ type clientKey struct {
21
+ serverName string
22
+ session * mcp.ServerSession
23
+ }
24
+
20
25
type keptClient struct {
21
- Name string
22
- Getter * clientGetter
23
- Config catalog.ServerConfig
26
+ Name string
27
+ Getter * clientGetter
28
+ Config * catalog.ServerConfig
29
+ ClientConfig * clientConfig
24
30
}
25
31
26
32
type clientPool struct {
27
33
Options
28
- keptClients [ ]keptClient
34
+ keptClients map [ clientKey ]keptClient
29
35
clientLock sync.RWMutex
30
36
networks []string
31
37
docker docker.Client
@@ -41,20 +47,42 @@ func newClientPool(options Options, docker docker.Client) *clientPool {
41
47
return & clientPool {
42
48
Options : options ,
43
49
docker : docker ,
44
- keptClients : []keptClient {},
50
+ keptClients : make (map [clientKey ]keptClient ),
51
+ }
52
+ }
53
+
54
+ func (cp * clientPool ) UpdateRoots (ss * mcp.ServerSession , roots []* mcp.Root ) {
55
+ cp .clientLock .RLock ()
56
+ defer cp .clientLock .RUnlock ()
57
+
58
+ for _ , kc := range cp .keptClients {
59
+ if kc .ClientConfig != nil && (kc .ClientConfig .serverSession == ss ) {
60
+ client , err := kc .Getter .GetClient (context .TODO ()) // should be cached
61
+ if err == nil {
62
+ client .AddRoots (roots )
63
+ }
64
+ }
45
65
}
46
66
}
47
67
48
- func (cp * clientPool ) AcquireClient (ctx context.Context , serverConfig catalog.ServerConfig , config * clientConfig ) (mcpclient.Client , error ) {
68
+ func (cp * clientPool ) longLived (serverConfig * catalog.ServerConfig , config * clientConfig ) bool {
69
+ keep := config != nil && config .serverSession != nil && (serverConfig .Spec .LongLived || cp .LongLived )
70
+ return keep
71
+ }
72
+
73
+ func (cp * clientPool ) AcquireClient (ctx context.Context , serverConfig * catalog.ServerConfig , config * clientConfig ) (mcpclient.Client , error ) {
49
74
var getter * clientGetter
75
+ c := ctx
50
76
51
77
// Check if client is kept, can be returned immediately
78
+ var session * mcp.ServerSession
79
+ if config != nil {
80
+ session = config .serverSession
81
+ }
82
+ key := clientKey {serverName : serverConfig .Name , session : session }
52
83
cp .clientLock .RLock ()
53
- for _ , kc := range cp .keptClients {
54
- if kc .Name == serverConfig .Name {
55
- getter = kc .Getter
56
- break
57
- }
84
+ if kc , exists := cp .keptClients [key ]; exists {
85
+ getter = kc .Getter
58
86
}
59
87
cp .clientLock .RUnlock ()
60
88
@@ -63,30 +91,27 @@ func (cp *clientPool) AcquireClient(ctx context.Context, serverConfig catalog.Se
63
91
getter = newClientGetter (serverConfig , cp , config )
64
92
65
93
// If the client is long running, save it for later
66
- if serverConfig .Spec .LongLived || cp .LongLived {
94
+ if cp .longLived (serverConfig , config ) {
95
+ c = context .Background ()
67
96
cp .clientLock .Lock ()
68
- cp .keptClients = append (cp .keptClients , keptClient {
69
- Name : serverConfig .Name ,
70
- Getter : getter ,
71
- Config : serverConfig ,
72
- })
97
+ cp .keptClients [key ] = keptClient {
98
+ Name : serverConfig .Name ,
99
+ Getter : getter ,
100
+ Config : serverConfig ,
101
+ ClientConfig : config ,
102
+ }
73
103
cp .clientLock .Unlock ()
74
104
}
75
105
}
76
106
77
- client , err := getter .GetClient (ctx ) // first time creates the client, can take some time
107
+ client , err := getter .GetClient (c ) // first time creates the client, can take some time
78
108
if err != nil {
79
109
cp .clientLock .Lock ()
80
110
defer cp .clientLock .Unlock ()
81
111
82
112
// Wasn't successful, remove it
83
- if serverConfig .Spec .LongLived || cp .LongLived {
84
- for i , kc := range cp .keptClients {
85
- if kc .Getter == getter {
86
- cp .keptClients = append (cp .keptClients [:i ], cp .keptClients [i + 1 :]... )
87
- break
88
- }
89
- }
113
+ if cp .longLived (serverConfig , config ) {
114
+ delete (cp .keptClients , key )
90
115
}
91
116
92
117
return nil , err
@@ -111,14 +136,12 @@ func (cp *clientPool) ReleaseClient(client mcpclient.Client) {
111
136
client .Session ().Close ()
112
137
return
113
138
}
114
-
115
- // Otherwise, leave the client as is
116
139
}
117
140
118
141
func (cp * clientPool ) Close () {
119
142
cp .clientLock .Lock ()
120
143
existingMap := cp .keptClients
121
- cp .keptClients = [ ]keptClient {}
144
+ cp .keptClients = make ( map [ clientKey ]keptClient )
122
145
cp .clientLock .Unlock ()
123
146
124
147
// Close all clients
@@ -215,7 +238,7 @@ func (cp *clientPool) baseArgs(name string) []string {
215
238
return args
216
239
}
217
240
218
- func (cp * clientPool ) argsAndEnv (serverConfig catalog.ServerConfig , readOnly * bool , targetConfig proxies.TargetConfig ) ([]string , []string ) {
241
+ func (cp * clientPool ) argsAndEnv (serverConfig * catalog.ServerConfig , readOnly * bool , targetConfig proxies.TargetConfig ) ([]string , []string ) {
219
242
args := cp .baseArgs (serverConfig .Name )
220
243
var env []string
221
244
@@ -308,13 +331,13 @@ type clientGetter struct {
308
331
client mcpclient.Client
309
332
err error
310
333
311
- serverConfig catalog.ServerConfig
334
+ serverConfig * catalog.ServerConfig
312
335
cp * clientPool
313
336
314
337
clientConfig * clientConfig
315
338
}
316
339
317
- func newClientGetter (serverConfig catalog.ServerConfig , cp * clientPool , config * clientConfig ) * clientGetter {
340
+ func newClientGetter (serverConfig * catalog.ServerConfig , cp * clientPool , config * clientConfig ) * clientGetter {
318
341
return & clientGetter {
319
342
serverConfig : serverConfig ,
320
343
cp : cp ,
@@ -388,6 +411,7 @@ func (cg *clientGetter) GetClient(ctx context.Context) (mcpclient.Client, error)
388
411
// ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
389
412
// defer cancel()
390
413
414
+ // TODO add initial roots
391
415
if err := client .Initialize (ctx , initParams , cg .cp .Verbose , ss , server ); err != nil {
392
416
return nil , err
393
417
}
0 commit comments