Skip to content

Commit 44d3180

Browse files
authored
feat: add ability to use a custom session store MCP-451 (#1024)
1 parent c01052f commit 44d3180

9 files changed

Lines changed: 305 additions & 106 deletions

File tree

src/common/sessionStore.ts

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,20 @@ export type CloseableTransport = {
1515

1616
export type SessionCloseReason = "idle_timeout" | "transport_closed" | "server_stop" | "unknown";
1717

18-
export class SessionStore<T extends CloseableTransport = CloseableTransport> {
18+
/**
19+
* Interface for managing MCP transport sessions.
20+
*
21+
* Implement this interface to provide custom session storage and lifecycle
22+
* management (e.g. database-based session storage).
23+
*/
24+
export interface ISessionStore<T extends CloseableTransport = CloseableTransport> {
25+
getSession(sessionId: string): T | undefined;
26+
addSession(params: { sessionId: string; transport: T; logger: LoggerBase }): void;
27+
closeSession(params: { sessionId: string; reason?: SessionCloseReason }): Promise<void>;
28+
closeAllSessions(): Promise<void>;
29+
}
30+
31+
export class SessionStore<T extends CloseableTransport = CloseableTransport> implements ISessionStore<T> {
1932
private sessions: {
2033
[sessionId: string]: {
2134
logger: LoggerBase;
@@ -25,19 +38,29 @@ export class SessionStore<T extends CloseableTransport = CloseableTransport> {
2538
};
2639
} = {};
2740

28-
constructor(
29-
private readonly idleTimeoutMS: number,
30-
private readonly notificationTimeoutMS: number,
31-
private readonly logger: LoggerBase,
32-
private readonly metrics: Metrics<DefaultMetrics>
33-
) {
34-
if (idleTimeoutMS <= 0) {
41+
private readonly idleTimeoutMS: number;
42+
private readonly notificationTimeoutMS: number;
43+
private readonly logger: LoggerBase;
44+
private readonly metrics: Metrics<DefaultMetrics>;
45+
46+
constructor(params: {
47+
options: { idleTimeoutMS: number; notificationTimeoutMS: number };
48+
logger: LoggerBase;
49+
metrics: Metrics<DefaultMetrics>;
50+
}) {
51+
const { options, logger, metrics } = params;
52+
this.idleTimeoutMS = options.idleTimeoutMS;
53+
this.notificationTimeoutMS = options.notificationTimeoutMS;
54+
this.logger = logger;
55+
this.metrics = metrics;
56+
57+
if (this.idleTimeoutMS <= 0) {
3558
throw new Error("idleTimeoutMS must be greater than 0");
3659
}
37-
if (notificationTimeoutMS <= 0) {
60+
if (this.notificationTimeoutMS <= 0) {
3861
throw new Error("notificationTimeoutMS must be greater than 0");
3962
}
40-
if (idleTimeoutMS <= notificationTimeoutMS) {
63+
if (this.idleTimeoutMS <= this.notificationTimeoutMS) {
4164
throw new Error("idleTimeoutMS must be greater than notificationTimeoutMS");
4265
}
4366
}
@@ -75,7 +98,8 @@ export class SessionStore<T extends CloseableTransport = CloseableTransport> {
7598
});
7699
}
77100

78-
setSession(sessionId: string, transport: T, logger: LoggerBase): void {
101+
addSession(params: { sessionId: string; transport: T; logger: LoggerBase }): void {
102+
const { sessionId, transport, logger } = params;
79103
const session = this.sessions[sessionId];
80104
if (session) {
81105
throw new Error(`Session ${sessionId} already exists`);
@@ -146,3 +170,31 @@ export class SessionStore<T extends CloseableTransport = CloseableTransport> {
146170
);
147171
}
148172
}
173+
174+
/**
175+
* Constructor arguments for creating a SessionStore instance.
176+
*/
177+
export type SessionStoreConstructorArgs<TMetrics extends DefaultMetrics = DefaultMetrics> = {
178+
options: { idleTimeoutMS: number; notificationTimeoutMS: number };
179+
logger: LoggerBase;
180+
metrics: Metrics<TMetrics>;
181+
};
182+
183+
/**
184+
* A function to create a custom SessionStore instance.
185+
* When provided, the runner will use this function instead of the default SessionStore constructor.
186+
*/
187+
export type CreateSessionStoreFn<
188+
TTransport extends CloseableTransport = CloseableTransport,
189+
TMetrics extends DefaultMetrics = DefaultMetrics,
190+
> = (args: SessionStoreConstructorArgs<TMetrics>) => ISessionStore<TTransport>;
191+
192+
/**
193+
* Creates a default SessionStore instance from the provided constructor arguments.
194+
*/
195+
export function createDefaultSessionStore<
196+
TTransport extends CloseableTransport = CloseableTransport,
197+
TMetrics extends DefaultMetrics = DefaultMetrics,
198+
>(params: SessionStoreConstructorArgs<TMetrics>): SessionStore<TTransport> {
199+
return new SessionStore(params);
200+
}

src/lib.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,15 @@ export { Keychain, registerGlobalSecretToRedact } from "./common/keychain.js";
7272
export type { Secret } from "./common/keychain.js";
7373
export { Elicitation } from "./elicitation.js";
7474
export { applyConfigOverrides, ConfigOverrideError } from "./common/config/configOverrides.js";
75-
export { type CloseableTransport, type SessionCloseReason } from "./common/sessionStore.js";
75+
export {
76+
SessionStore,
77+
createDefaultSessionStore,
78+
type ISessionStore,
79+
type CloseableTransport,
80+
type SessionCloseReason,
81+
type CreateSessionStoreFn,
82+
type SessionStoreConstructorArgs,
83+
} from "./common/sessionStore.js";
7684
export { ApiClient, type ApiClientOptions } from "./common/atlas/apiClient.js";
7785
export type { AuthProvider } from "./common/atlas/auth/authProvider.js";
7886
export { type UIRegistryOptions } from "./ui/registry/registry.js";

src/tools/mongodb/metadata/collectionIndexes.ts

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,51 @@ export class CollectionIndexesTool extends MongoDBToolBase {
101101
protected extractSearchIndexDetails(indexes: Record<string, unknown>[]): SearchIndexStatus[] {
102102
return indexes.map((index) => ({
103103
name: (index["name"] ?? "default") as string,
104-
type: (index["type"] ?? "UNKNOWN") as string,
104+
type: CollectionIndexesTool.resolveIndexType(index),
105105
status: (index["status"] ?? "UNKNOWN") as string,
106106
queryable: (index["queryable"] ?? false) as boolean,
107107
latestDefinition: (index["latestDefinition"] ?? {}) as Record<string, unknown>,
108108
}));
109109
}
110+
111+
/**
112+
* Resolves the search index type from the index document, falling back to
113+
* definition structure inference when the server doesn't provide a top-level
114+
* `type` field.
115+
*/
116+
private static resolveIndexType(index: Record<string, unknown>): string {
117+
// Direct type from server response.
118+
// TODO: This is undocumented and is not always present, should be removed in the future.
119+
const serverType = index["type"];
120+
if (serverType && typeof serverType === "string") {
121+
return serverType;
122+
}
123+
124+
const definition = (index["latestDefinition"] ?? {}) as Record<string, unknown>;
125+
const defType = definition["type"];
126+
if (defType && typeof defType === "string") {
127+
return defType;
128+
}
129+
130+
// Vector search uses a `fields` array, Atlas search uses `mappings`
131+
const fields = definition["fields"];
132+
if (Array.isArray(fields)) {
133+
// Check for auto-embed indexes (have autoEmbed field type)
134+
if (fields.some((field: Record<string, unknown>) => field["type"] === "autoEmbed")) {
135+
return "autoEmbed";
136+
}
137+
// Check for regular vector search indexes (have vector field type)
138+
if (fields.some((field: Record<string, unknown>) => field["type"] === "vector")) {
139+
return "vectorSearch";
140+
}
141+
// Other vector search variations (e.g., mixed with filter fields only)
142+
return "vectorSearch";
143+
}
144+
145+
if (definition["mappings"] !== undefined && definition["mappings"] !== null) {
146+
return "search";
147+
}
148+
149+
return "UNKNOWN";
150+
}
110151
}

src/transports/streamableHttp.ts

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import type {
77
import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js";
88
import type { LoggerBase } from "../common/logging/index.js";
99
import { CompositeLogger, LogId } from "../common/logging/index.js";
10-
import { SessionStore } from "../common/sessionStore.js";
10+
import { type ISessionStore, type CreateSessionStoreFn, createDefaultSessionStore } from "../common/sessionStore.js";
1111
import {
1212
TransportRunnerBase,
1313
type TransportRunnerConfig,
@@ -31,34 +31,6 @@ export type MonitoringServerConfig = {
3131
monitoringServerFeatures: MonitoringServerFeature[];
3232
};
3333

34-
/**
35-
* Constructor arguments for creating a MonitoringServer instance.
36-
*/
37-
export type MonitoringServerConstructorArgs<TMetrics extends DefaultMetrics = DefaultMetrics> = {
38-
host: string;
39-
port: number;
40-
features: MonitoringServerFeature[];
41-
logger: LoggerBase;
42-
metrics: Metrics<TMetrics>;
43-
};
44-
45-
/**
46-
* A function to create a custom MonitoringServer instance.
47-
* When provided, the runner will use this function instead of the default MonitoringServer constructor.
48-
*/
49-
export type CreateMonitoringServerFn<TMetrics extends DefaultMetrics = DefaultMetrics> = (
50-
args: MonitoringServerConstructorArgs<TMetrics>
51-
) => MonitoringServer<TMetrics> | undefined;
52-
53-
/**
54-
* Creates a default MonitoringServer instance from the provided constructor arguments.
55-
*/
56-
export const createDefaultMonitoringServer: <TMetrics extends DefaultMetrics = DefaultMetrics>(
57-
args: MonitoringServerConstructorArgs<TMetrics>
58-
) => MonitoringServer<TMetrics> = <TMetrics extends DefaultMetrics = DefaultMetrics>(
59-
args: MonitoringServerConstructorArgs<TMetrics>
60-
) => new MonitoringServer<TMetrics>(args);
61-
6234
/**
6335
* Configuration options for the StreamableHttpRunner.
6436
* Extends the base TransportRunnerConfig with HTTP-transport-specific options.
@@ -77,6 +49,15 @@ export type StreamableHttpTransportRunnerConfig<
7749
* receiving the constructor arguments that would normally be used.
7850
*/
7951
createMonitoringServer?: CreateMonitoringServerFn<TMetrics>;
52+
53+
/**
54+
* When provided, the runner will use this function to create the session store
55+
* instead of using the default SessionStore constructor. This allows for
56+
* customizing session storage (e.g., Redis-backed storage, custom timeout behavior,
57+
* or shared session state across instances) while still receiving the constructor
58+
* arguments that would normally be used.
59+
*/
60+
createSessionStore?: CreateSessionStoreFn<StreamableHTTPServerTransport, TMetrics>;
8061
};
8162

8263
const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000;
@@ -93,23 +74,30 @@ export class StreamableHttpRunner<
9374
> extends TransportRunnerBase<TUserConfig, TContext, TMetrics> {
9475
private mcpServer: MCPHttpServer<TUserConfig, TContext> | undefined;
9576
private readonly monitoringServer: MonitoringServer<TMetrics> | undefined;
77+
private readonly sessionStore: ISessionStore<StreamableHTTPServerTransport>;
9678

9779
constructor(config: StreamableHttpTransportRunnerConfig<TUserConfig, TMetrics>) {
9880
super(config);
81+
82+
this.sessionStore = (config.createSessionStore ?? createDefaultSessionStore<StreamableHTTPServerTransport>)({
83+
options: {
84+
idleTimeoutMS: this.userConfig.idleTimeoutMs,
85+
notificationTimeoutMS: this.userConfig.notificationTimeoutMs,
86+
},
87+
logger: this.logger,
88+
metrics: this.metrics,
89+
});
9990
// Create monitoring server if host/port are configured
10091
const host = config.userConfig.monitoringServerHost ?? config.userConfig.healthCheckHost;
10192
const port = config.userConfig.monitoringServerPort ?? config.userConfig.healthCheckPort;
10293
if (host !== undefined && port !== undefined) {
103-
const args: MonitoringServerConstructorArgs<TMetrics> = {
94+
this.monitoringServer = (config.createMonitoringServer ?? createDefaultMonitoringServer)({
10495
host,
10596
port,
10697
features: config.userConfig.monitoringServerFeatures,
10798
logger: this.logger,
10899
metrics: this.metrics,
109-
};
110-
this.monitoringServer = (config.createMonitoringServer ?? createDefaultMonitoringServer)(args);
111-
} else {
112-
this.monitoringServer = undefined;
100+
});
113101
}
114102
}
115103

@@ -131,6 +119,7 @@ export class StreamableHttpRunner<
131119
this.createServerForRequest({ request, serverOptions, sessionOptions }),
132120
logger: this.logger,
133121
metrics: this.metrics,
122+
sessionStore: this.sessionStore,
134123
});
135124
await this.mcpServer.start();
136125

@@ -340,7 +329,7 @@ abstract class ExpressBasedHttpServer {
340329
}
341330

342331
class MCPHttpServer<TUserConfig extends UserConfig = UserConfig, TContext = unknown> extends ExpressBasedHttpServer {
343-
private sessionStore!: SessionStore<StreamableHTTPServerTransport>;
332+
private readonly sessionStore: ISessionStore<StreamableHTTPServerTransport>;
344333
private readonly serverOptions?: CustomizableServerOptions<TUserConfig, TContext>;
345334
private readonly sessionOptions?: CustomizableSessionOptions<TUserConfig>;
346335
private readonly userConfig: UserConfig;
@@ -359,6 +348,7 @@ class MCPHttpServer<TUserConfig extends UserConfig = UserConfig, TContext = unkn
359348
sessionOptions,
360349
logger,
361350
metrics,
351+
sessionStore,
362352
}: {
363353
userConfig: TUserConfig;
364354
createServerForRequest: (createParams: {
@@ -370,6 +360,7 @@ class MCPHttpServer<TUserConfig extends UserConfig = UserConfig, TContext = unkn
370360
serverOptions?: CustomizableServerOptions<TUserConfig, TContext>;
371361
sessionOptions?: CustomizableSessionOptions<TUserConfig>;
372362
metrics: Metrics<DefaultMetrics>;
363+
sessionStore: ISessionStore<StreamableHTTPServerTransport>;
373364
}) {
374365
super({
375366
port: userConfig.httpPort,
@@ -382,6 +373,7 @@ class MCPHttpServer<TUserConfig extends UserConfig = UserConfig, TContext = unkn
382373
this.createServerForRequest = createServerForRequest;
383374
this.userConfig = userConfig;
384375
this.metrics = metrics;
376+
this.sessionStore = sessionStore;
385377
}
386378

387379
public async stop(): Promise<void> {
@@ -472,13 +464,6 @@ class MCPHttpServer<TUserConfig extends UserConfig = UserConfig, TContext = unkn
472464
protected override async setupRoutes(): Promise<void> {
473465
const { StreamableHTTPServerTransport } = await import("@modelcontextprotocol/sdk/server/streamableHttp.js");
474466

475-
this.sessionStore = new SessionStore(
476-
this.userConfig.idleTimeoutMs,
477-
this.userConfig.notificationTimeoutMs,
478-
this.logger,
479-
this.metrics
480-
);
481-
482467
this.app.use(express.json({ limit: this.userConfig.httpBodyLimit }));
483468
this.app.use((req, res, next) => {
484469
for (const [key, value] of Object.entries(this.userConfig.httpHeaders)) {
@@ -589,7 +574,7 @@ class MCPHttpServer<TUserConfig extends UserConfig = UserConfig, TContext = unkn
589574
// reuse it. This may cause issues if server.connect fails as we'll try to use a transport that's
590575
// not fully set up.
591576
server.session.logger.setAttribute("sessionId", sessionId);
592-
this.sessionStore.setSession(sessionId, transport, server.session.logger);
577+
this.sessionStore.addSession({ sessionId, transport, logger: server.session.logger });
593578

594579
const keepAliveLoop = this.startKeepAliveLoop(transport, server);
595580
transport.onclose = (): void => {
@@ -677,6 +662,34 @@ class MCPHttpServer<TUserConfig extends UserConfig = UserConfig, TContext = unkn
677662
}
678663
}
679664

665+
/**
666+
* Constructor arguments for creating a MonitoringServer instance.
667+
*/
668+
export type MonitoringServerConstructorArgs<TMetrics extends DefaultMetrics = DefaultMetrics> = {
669+
host: string;
670+
port: number;
671+
features: MonitoringServerFeature[];
672+
logger: LoggerBase;
673+
metrics: Metrics<TMetrics>;
674+
};
675+
676+
/**
677+
* A function to create a custom MonitoringServer instance.
678+
* When provided, the runner will use this function instead of the default MonitoringServer constructor.
679+
*/
680+
export type CreateMonitoringServerFn<TMetrics extends DefaultMetrics = DefaultMetrics> = (
681+
args: MonitoringServerConstructorArgs<TMetrics>
682+
) => MonitoringServer<TMetrics> | undefined;
683+
684+
/**
685+
* Creates a default MonitoringServer instance from the provided constructor arguments.
686+
*/
687+
export const createDefaultMonitoringServer: <TMetrics extends DefaultMetrics = DefaultMetrics>(
688+
args: MonitoringServerConstructorArgs<TMetrics>
689+
) => MonitoringServer<TMetrics> = <TMetrics extends DefaultMetrics = DefaultMetrics>(
690+
args: MonitoringServerConstructorArgs<TMetrics>
691+
) => new MonitoringServer<TMetrics>(args);
692+
680693
export class MonitoringServer<TMetrics extends DefaultMetrics = DefaultMetrics> extends ExpressBasedHttpServer {
681694
private readonly features: MonitoringServerFeature[];
682695
private readonly metrics: Metrics<TMetrics>;

tests/integration/tools/mongodb/create/createIndex.test.ts

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,6 @@ describeWithMongoDB(
515515
const indexes = (await collection.listSearchIndexes().toArray()) as unknown as Document[];
516516
expect(indexes).toHaveLength(1);
517517
expect(indexes[0]?.name).toEqual("vector_1_vector");
518-
expect(indexes[0]?.type).toEqual("vectorSearch");
519518
expect(indexes[0]?.status).toEqual(expect.stringMatching(/PENDING|BUILDING/));
520519
expect(indexes[0]?.queryable).toEqual(false);
521520
expect(indexes[0]?.latestDefinition).toEqual({
@@ -702,14 +701,9 @@ describeWithMongoDB(
702701
const indexes: Document[] = await collection.listSearchIndexes().toArray();
703702
expect(indexes).toHaveLength(1);
704703
expect(indexes[0]?.name).toEqual("vector_1_vector_auto_embed");
705-
expect(indexes[0]?.type).toEqual("vectorSearch");
706-
// Note: The status reporting here is because of an internal feature
707-
// flag. For auto-embed indexes we still don't have status
708-
// reporting.
709704
expect(indexes[0]?.status).toEqual(expect.stringMatching(/PENDING|BUILDING/));
710705
expect(indexes[0]?.latestDefinition).toEqual(
711706
expect.objectContaining({
712-
type: "vectorSearch",
713707
fields: [{ type: "autoEmbed", path: "plot", model: "voyage-4-large", modality: "text" }],
714708
})
715709
);

tests/integration/tools/mongodb/metadata/collectionIndexes.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ describeWithMongoDB(
437437
const vectorIndexDefinition = indexDefinitions.find((def) => def.name === "my-auto-embed-index");
438438
expectDefined(vectorIndexDefinition);
439439
expect(vectorIndexDefinition).toHaveProperty("name", "my-auto-embed-index");
440-
expect(vectorIndexDefinition).toHaveProperty("type", "vectorSearch");
440+
expect(vectorIndexDefinition).toHaveProperty("type", "autoEmbed");
441441

442442
const fields0 = vectorIndexDefinition.latestDefinition.fields;
443443
expect(fields0).toHaveLength(1);

0 commit comments

Comments
 (0)