diff --git a/.changeset/custom-method-overloads.md b/.changeset/custom-method-overloads.md new file mode 100644 index 000000000..0f416940e --- /dev/null +++ b/.changeset/custom-method-overloads.md @@ -0,0 +1,6 @@ +--- +'@modelcontextprotocol/client': minor +'@modelcontextprotocol/server': minor +--- + +`setRequestHandler`/`setNotificationHandler` accept the v1 `(ZodSchema, handler)` form as a first-class alternative to `(methodString, handler)`. `request()` and `ctx.mcpReq.send()` accept an explicit result schema (`request(req, resultSchema, options?)`) and have method-keyed return types for spec methods. `callTool(params, resultSchema?)` accepts the v1 schema arg (ignored). `removeRequestHandler`/`removeNotificationHandler`/`assertCanSetRequestHandler` accept any method string. diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index a37b5e206..2c753f4ab 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -204,7 +204,7 @@ if (error instanceof OAuthError && error.code === OAuthErrorCode.InvalidClient) ``` **Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all -Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). +Zod schemas, all callback return types. Note: `callTool()` and `request()` schema parameter is now optional (see section 11). ## 6. McpServer API Changes @@ -340,7 +340,7 @@ The server package now exports framework-agnostic alternatives: `validateHostHea ## 9. `setRequestHandler` / `setNotificationHandler` API -The low-level handler registration methods now take a method string instead of a Zod schema. +The low-level handler registration methods now accept a method string in addition to the v1 Zod-schema form (both are supported). ```typescript // v1: schema-based @@ -377,6 +377,15 @@ Schema to method string mapping: Request/notification params remain fully typed. Remove unused schema imports after migration. +**Custom (non-standard) methods** — vendor extensions or sub-protocols whose method strings are not in the MCP spec — work on `Client`/`Server` directly using the same v1 Zod-schema form: + +| Form | Notes | +| ------------------------------------------------------------ | --------------------------------------------------------------------- | +| `setRequestHandler(CustomReqSchema, (req, ctx) => ...)` | unchanged | +| `setNotificationHandler(CustomNotifSchema, n => ...)` | unchanged | +| `this.request({ method: 'vendor/x', params }, ResultSchema)` | unchanged | +| `this.notification({ method: 'vendor/x', params })` | unchanged | + ## 10. Request Handler Context Types `RequestHandlerExtra` → structured context types with nested groups. Rename `extra` → `ctx` in all handler callbacks. @@ -407,9 +416,9 @@ Request/notification params remain fully typed. Remove unused schema imports aft | `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | | `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | -## 11. Schema parameter removed from `request()`, `send()`, and `callTool()` +## 11. Schema parameter on `request()` / `callTool()` / `mcpReq.send()` is optional -`Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` no longer take a Zod result schema argument. The SDK resolves the schema internally from the method name. +`Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` still accept a result schema as the second argument (the v1 form), but for spec methods it is optional — the SDK resolves the schema internally from the method name. The schema argument remains the supported call form for custom (non-spec) methods. ```typescript // v1: schema required @@ -418,22 +427,22 @@ const result = await client.request({ method: 'tools/call', params: { ... } }, C const elicit = await ctx.mcpReq.send({ method: 'elicitation/create', params: { ... } }, ElicitResultSchema); const tool = await client.callTool({ name: 'my-tool', arguments: {} }, CompatibilityCallToolResultSchema); -// v2: no schema argument +// v2: schema optional on request()/callTool()/mcpReq.send() for spec methods const result = await client.request({ method: 'tools/call', params: { ... } }); const elicit = await ctx.mcpReq.send({ method: 'elicitation/create', params: { ... } }); const tool = await client.callTool({ name: 'my-tool', arguments: {} }); ``` -| v1 call | v2 call | -| ------------------------------------------------------------ | ---------------------------------- | -| `client.request(req, ResultSchema)` | `client.request(req)` | -| `client.request(req, ResultSchema, options)` | `client.request(req, options)` | -| `ctx.mcpReq.send(req, ResultSchema)` | `ctx.mcpReq.send(req)` | -| `ctx.mcpReq.send(req, ResultSchema, options)` | `ctx.mcpReq.send(req, options)` | -| `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | -| `client.callTool(params, schema, options)` | `client.callTool(params, options)` | +| v1 call | v2 call | +| ------------------------------------------------------------ | ---------------------------------------------- | +| `client.request(req, ResultSchema)` | unchanged (schema optional), or `client.request(req)` | +| `client.request(req, ResultSchema, options)` | unchanged, or `client.request(req, options)` | +| `ctx.mcpReq.send(req, ResultSchema)` | unchanged (schema optional), or `ctx.mcpReq.send(req)` | +| `ctx.mcpReq.send(req, ResultSchema, options)` | unchanged, or `ctx.mcpReq.send(req, options)` | +| `client.callTool(params, CompatibilityCallToolResultSchema)` | unchanged (schema ignored), or `client.callTool(params)` | +| `client.callTool(params, schema, options)` | unchanged, or `client.callTool(params, options)` | -Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResultSchema`, `ElicitResultSchema`, `CreateMessageResultSchema`, etc., when they were only used in `request()`/`send()`/`callTool()` calls. +For spec methods you can drop now-unused schema imports (`CallToolResultSchema`, `CompatibilityCallToolResultSchema`, `ElicitResultSchema`, `CreateMessageResultSchema`, etc.) when they were only used in `request()`/`send()`/`callTool()` calls. If `CallToolResultSchema` was used for **runtime validation** (not just as a `request()` argument), replace with the `isCallToolResult` type guard: diff --git a/docs/migration.md b/docs/migration.md index 7cb7d58f6..7be893290 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -324,7 +324,7 @@ Note: the v2 signature takes a plain `string[]` instead of an options object. ### `setRequestHandler` and `setNotificationHandler` use method strings -The low-level `setRequestHandler` and `setNotificationHandler` methods on `Client`, `Server`, and `Protocol` now take a method string instead of a Zod schema. +The low-level `setRequestHandler` and `setNotificationHandler` methods on `Client`, `Server`, and `Protocol` now accept a method string in addition to the v1 Zod-schema form (both are supported). **Before (v1):** @@ -382,10 +382,25 @@ Common method string replacements: | `ResourceListChangedNotificationSchema` | `'notifications/resources/list_changed'` | | `PromptListChangedNotificationSchema` | `'notifications/prompts/list_changed'` | -### `Protocol.request()`, `ctx.mcpReq.send()`, and `Client.callTool()` no longer take a schema parameter +### Custom (non-standard) protocol methods -The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas -like `CallToolResultSchema` or `ElicitResultSchema` when making requests. +Vendor-specific methods are registered directly on `Client` or `Server` using the same Zod-schema form as v1: `setRequestHandler(zodSchemaWithMethodLiteral, handler)`. `request({ method, params }, ResultSchema)` and `notification({ method, params })` are unchanged from v1. + +```typescript +import { Server } from '@modelcontextprotocol/server'; + +const server = new Server({ name: 'app', version: '1.0.0' }, { capabilities: {} }); + +server.setRequestHandler(SearchRequestSchema, req => ({ hits: [req.params.query] })); + +// Calling from a Client — unchanged from v1: +const result = await client.request({ method: 'acme/search', params: { query: 'x' } }, SearchResult); +``` + +### `Protocol.request()`, `ctx.mcpReq.send()`, and `Client.callTool()` schema parameter is now optional + +The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods still accept a result schema argument, but for spec methods it is optional — the SDK resolves the correct schema internally from the method name. You no longer need to import result schemas +like `CallToolResultSchema` or `ElicitResultSchema` when making spec-method requests. The schema argument remains the supported call form for custom (non-spec) methods. **`client.request()` — Before (v1):** @@ -888,7 +903,7 @@ import { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/server/valida The following APIs are unchanged between v1 and v2 (only the import paths changed): -- `Client` constructor and most client methods (`connect`, `listTools`, `listPrompts`, `listResources`, `readResource`, etc.) — note: `callTool()` signature changed (schema parameter removed) +- `Client` constructor and most client methods (`connect`, `listTools`, `listPrompts`, `listResources`, `readResource`, etc.) — note: `callTool()` schema parameter is now optional - `McpServer` constructor, `server.connect(transport)`, `server.close()` - `Server` (low-level) constructor and all methods - `StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport` constructors and options diff --git a/examples/client/README.md b/examples/client/README.md index 12a2b0d68..d20984e2f 100644 --- a/examples/client/README.md +++ b/examples/client/README.md @@ -24,18 +24,19 @@ Most clients expect a server to be running. Start one from [`../server/README.md ## Example index -| Scenario | Description | File | -| --------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | -| Interactive Streamable HTTP client | CLI client that exercises tools/resources/prompts, notifications, elicitation, and tasks. | [`src/simpleStreamableHttp.ts`](src/simpleStreamableHttp.ts) | -| Backwards-compatible client (Streamable HTTP → SSE) | Tries Streamable HTTP first, falls back to legacy SSE on 4xx responses. | [`src/streamableHttpWithSseFallbackClient.ts`](src/streamableHttpWithSseFallbackClient.ts) | -| SSE polling client (legacy) | Polls a legacy HTTP+SSE server and demonstrates notification handling. | [`src/ssePollingClient.ts`](src/ssePollingClient.ts) | -| Parallel tool calls | Runs multiple tool calls in parallel. | [`src/parallelToolCallsClient.ts`](src/parallelToolCallsClient.ts) | -| Multiple clients in parallel | Connects multiple clients concurrently to the same server. | [`src/multipleClientsParallel.ts`](src/multipleClientsParallel.ts) | -| OAuth client (interactive) | OAuth-enabled client (dynamic registration, auth flow). | [`src/simpleOAuthClient.ts`](src/simpleOAuthClient.ts) | -| OAuth provider helper | Demonstrates reusable OAuth providers. | [`src/simpleOAuthClientProvider.ts`](src/simpleOAuthClientProvider.ts) | -| Client credentials (M2M) | Machine-to-machine OAuth client credentials example. | [`src/simpleClientCredentials.ts`](src/simpleClientCredentials.ts) | -| URL elicitation client | Drives URL-mode elicitation flows (sensitive input in a browser). | [`src/elicitationUrlExample.ts`](src/elicitationUrlExample.ts) | -| Task interactive client | Demonstrates task-based execution + interactive server→client requests. | [`src/simpleTaskInteractiveClient.ts`](src/simpleTaskInteractiveClient.ts) | +| Scenario | Description | File | +| --------------------------------------------------- | --------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | +| Interactive Streamable HTTP client | CLI client that exercises tools/resources/prompts, notifications, elicitation, and tasks. | [`src/simpleStreamableHttp.ts`](src/simpleStreamableHttp.ts) | +| Backwards-compatible client (Streamable HTTP → SSE) | Tries Streamable HTTP first, falls back to legacy SSE on 4xx responses. | [`src/streamableHttpWithSseFallbackClient.ts`](src/streamableHttpWithSseFallbackClient.ts) | +| SSE polling client (legacy) | Polls a legacy HTTP+SSE server and demonstrates notification handling. | [`src/ssePollingClient.ts`](src/ssePollingClient.ts) | +| Parallel tool calls | Runs multiple tool calls in parallel. | [`src/parallelToolCallsClient.ts`](src/parallelToolCallsClient.ts) | +| Multiple clients in parallel | Connects multiple clients concurrently to the same server. | [`src/multipleClientsParallel.ts`](src/multipleClientsParallel.ts) | +| OAuth client (interactive) | OAuth-enabled client (dynamic registration, auth flow). | [`src/simpleOAuthClient.ts`](src/simpleOAuthClient.ts) | +| OAuth provider helper | Demonstrates reusable OAuth providers. | [`src/simpleOAuthClientProvider.ts`](src/simpleOAuthClientProvider.ts) | +| Client credentials (M2M) | Machine-to-machine OAuth client credentials example. | [`src/simpleClientCredentials.ts`](src/simpleClientCredentials.ts) | +| URL elicitation client | Drives URL-mode elicitation flows (sensitive input in a browser). | [`src/elicitationUrlExample.ts`](src/elicitationUrlExample.ts) | +| Task interactive client | Demonstrates task-based execution + interactive server→client requests. | [`src/simpleTaskInteractiveClient.ts`](src/simpleTaskInteractiveClient.ts) | +| Custom (non-standard) methods client | Sends `acme/*` custom requests + notifications and handles custom progress notifications from the server. | [`src/customMethodExample.ts`](src/customMethodExample.ts) | ## URL elicitation example (server + client) diff --git a/examples/client/src/customMethodExample.ts b/examples/client/src/customMethodExample.ts new file mode 100644 index 000000000..d0ce0e994 --- /dev/null +++ b/examples/client/src/customMethodExample.ts @@ -0,0 +1,36 @@ +#!/usr/bin/env node +/** + * Calling vendor-specific (non-spec) JSON-RPC methods from a `Client`. + * + * - Send a custom request: `client.request({ method, params }, resultSchema)` + * - Send a custom notification: `client.notification({ method, params })` + * - Receive a custom notification: `client.setNotificationHandler(ZodSchemaWithMethodLiteral, handler)` + * + * Pair with the server in examples/server/src/customMethodExample.ts. + */ + +import { Client, StdioClientTransport } from '@modelcontextprotocol/client'; +import { z } from 'zod'; + +const SearchResult = z.object({ hits: z.array(z.string()) }); + +const ProgressNotification = z.object({ + method: z.literal('acme/searchProgress'), + params: z.object({ stage: z.string(), pct: z.number() }) +}); + +const client = new Client({ name: 'custom-method-client', version: '1.0.0' }, { capabilities: {} }); + +client.setNotificationHandler(ProgressNotification, n => { + console.log(`[client] progress: ${n.params.stage} ${n.params.pct}%`); +}); + +await client.connect(new StdioClientTransport({ command: 'npx', args: ['tsx', '../server/src/customMethodExample.ts'] })); + +const r = await client.request({ method: 'acme/search', params: { query: 'widgets' } }, SearchResult); +console.log('[client] hits=' + JSON.stringify(r.hits)); + +await client.notification({ method: 'acme/tick', params: { n: 1 } }); +await client.notification({ method: 'acme/tick', params: { n: 2 } }); + +await client.close(); diff --git a/examples/server/README.md b/examples/server/README.md index 384e4f2c2..cf0d9313c 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -25,19 +25,20 @@ pnpm tsx src/simpleStreamableHttp.ts ## Example index -| Scenario | Description | File | -| ----------------------------------------- | ----------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | -| Streamable HTTP server (stateful) | Feature-rich server with tools/resources/prompts, logging, tasks, sampling, and optional OAuth. | [`src/simpleStreamableHttp.ts`](src/simpleStreamableHttp.ts) | -| Streamable HTTP server (stateless) | No session tracking; good for simple API-style servers. | [`src/simpleStatelessStreamableHttp.ts`](src/simpleStatelessStreamableHttp.ts) | -| JSON response mode (no SSE) | Streamable HTTP with JSON-only responses and limited notifications. | [`src/jsonResponseStreamableHttp.ts`](src/jsonResponseStreamableHttp.ts) | -| Server notifications over Streamable HTTP | Demonstrates server-initiated notifications via GET+SSE. | [`src/standaloneSseWithGetStreamableHttp.ts`](src/standaloneSseWithGetStreamableHttp.ts) | -| Output schema server | Demonstrates tool output validation with structured output schemas. | [`src/mcpServerOutputSchema.ts`](src/mcpServerOutputSchema.ts) | -| Form elicitation server | Collects **non-sensitive** user input via schema-driven forms. | [`src/elicitationFormExample.ts`](src/elicitationFormExample.ts) | -| URL elicitation server | Secure browser-based flows for **sensitive** input (API keys, OAuth, payments). | [`src/elicitationUrlExample.ts`](src/elicitationUrlExample.ts) | -| Sampling + tasks server | Demonstrates sampling and experimental task-based execution. | [`src/toolWithSampleServer.ts`](src/toolWithSampleServer.ts) | -| Task interactive server | Task-based execution with interactive server→client requests. | [`src/simpleTaskInteractive.ts`](src/simpleTaskInteractive.ts) | -| Hono Streamable HTTP server | Streamable HTTP server built with Hono instead of Express. | [`src/honoWebStandardStreamableHttp.ts`](src/honoWebStandardStreamableHttp.ts) | -| SSE polling demo server | Legacy SSE server intended for polling demos. | [`src/ssePollingExample.ts`](src/ssePollingExample.ts) | +| Scenario | Description | File | +| ----------------------------------------- | -------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | +| Streamable HTTP server (stateful) | Feature-rich server with tools/resources/prompts, logging, tasks, sampling, and optional OAuth. | [`src/simpleStreamableHttp.ts`](src/simpleStreamableHttp.ts) | +| Streamable HTTP server (stateless) | No session tracking; good for simple API-style servers. | [`src/simpleStatelessStreamableHttp.ts`](src/simpleStatelessStreamableHttp.ts) | +| JSON response mode (no SSE) | Streamable HTTP with JSON-only responses and limited notifications. | [`src/jsonResponseStreamableHttp.ts`](src/jsonResponseStreamableHttp.ts) | +| Server notifications over Streamable HTTP | Demonstrates server-initiated notifications via GET+SSE. | [`src/standaloneSseWithGetStreamableHttp.ts`](src/standaloneSseWithGetStreamableHttp.ts) | +| Output schema server | Demonstrates tool output validation with structured output schemas. | [`src/mcpServerOutputSchema.ts`](src/mcpServerOutputSchema.ts) | +| Form elicitation server | Collects **non-sensitive** user input via schema-driven forms. | [`src/elicitationFormExample.ts`](src/elicitationFormExample.ts) | +| URL elicitation server | Secure browser-based flows for **sensitive** input (API keys, OAuth, payments). | [`src/elicitationUrlExample.ts`](src/elicitationUrlExample.ts) | +| Sampling + tasks server | Demonstrates sampling and experimental task-based execution. | [`src/toolWithSampleServer.ts`](src/toolWithSampleServer.ts) | +| Task interactive server | Task-based execution with interactive server→client requests. | [`src/simpleTaskInteractive.ts`](src/simpleTaskInteractive.ts) | +| Hono Streamable HTTP server | Streamable HTTP server built with Hono instead of Express. | [`src/honoWebStandardStreamableHttp.ts`](src/honoWebStandardStreamableHttp.ts) | +| SSE polling demo server | Legacy SSE server intended for polling demos. | [`src/ssePollingExample.ts`](src/ssePollingExample.ts) | +| Custom (non-standard) methods server | Registers `acme/*` custom request + notification handlers and emits custom progress notifications. | [`src/customMethodExample.ts`](src/customMethodExample.ts) | ## OAuth demo flags (Streamable HTTP server) diff --git a/examples/server/src/customMethodExample.ts b/examples/server/src/customMethodExample.ts new file mode 100644 index 000000000..b8b2e222f --- /dev/null +++ b/examples/server/src/customMethodExample.ts @@ -0,0 +1,42 @@ +#!/usr/bin/env node +/** + * Registering vendor-specific (non-spec) JSON-RPC methods on a `Server`. + * + * Custom methods use the Zod-schema form of `setRequestHandler` / `setNotificationHandler`: + * pass a Zod object schema whose `method` field is `z.literal('')`. The same overload + * is available on `Client` (for server→client custom methods). + * + * To call these from the client side, use: + * await client.request({ method: 'acme/search', params: { query: 'widgets' } }, SearchResult) + * await client.notification({ method: 'acme/tick', params: { n: 1 } }) + * See examples/client/src/customMethodExample.ts. + */ + +import { Server, StdioServerTransport } from '@modelcontextprotocol/server'; +import { z } from 'zod'; + +const SearchRequest = z.object({ + method: z.literal('acme/search'), + params: z.object({ query: z.string() }) +}); + +const TickNotification = z.object({ + method: z.literal('acme/tick'), + params: z.object({ n: z.number() }) +}); + +const server = new Server({ name: 'custom-method-server', version: '1.0.0' }, { capabilities: {} }); + +server.setRequestHandler(SearchRequest, async (request, ctx) => { + console.error('[server] acme/search query=' + request.params.query); + await ctx.mcpReq.notify({ method: 'acme/searchProgress', params: { stage: 'start', pct: 0 } }); + const hits = [request.params.query, request.params.query + '-result']; + await ctx.mcpReq.notify({ method: 'acme/searchProgress', params: { stage: 'done', pct: 100 } }); + return { hits }; +}); + +server.setNotificationHandler(TickNotification, n => { + console.error('[server] acme/tick n=' + n.params.n); +}); + +await server.connect(new StdioServerTransport()); diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 21a43bd15..e1ae948d7 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -2,6 +2,7 @@ import { DefaultJsonSchemaValidator } from '@modelcontextprotocol/client/_shims' import type { BaseContext, CallToolRequest, + CallToolResult, ClientCapabilities, ClientContext, ClientNotification, @@ -24,16 +25,19 @@ import type { NotificationMethod, ProtocolOptions, ReadResourceRequest, + Request, RequestMethod, RequestOptions, RequestTypeMap, + Result, ResultTypeMap, ServerCapabilities, SubscribeRequest, TaskManagerOptions, Tool, Transport, - UnsubscribeRequest + UnsubscribeRequest, + ZodLikeRequestSchema } from '@modelcontextprotocol/core'; import { assertClientRequestTaskCapability, @@ -47,9 +51,11 @@ import { ElicitRequestSchema, ElicitResultSchema, EmptyResultSchema, + extractMethodLiteral, extractTaskManagerOptions, GetPromptResultSchema, InitializeResultSchema, + isZodLikeSchema, LATEST_PROTOCOL_VERSION, ListChangedOptionsBaseSchema, ListPromptsResultSchema, @@ -336,9 +342,26 @@ export class Client extends Protocol { public override setRequestHandler( method: M, handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise - ): void { + ): void; + /** For spec methods the method-string form is more concise; this overload is the supported call form for non-spec methods or when you want full-envelope validation. */ + public override setRequestHandler( + requestSchema: T, + handler: (request: ReturnType, ctx: ClientContext) => Result | Promise + ): void; + public override setRequestHandler(methodOrSchema: string | ZodLikeRequestSchema, schemaHandler: unknown): void { + let method: string; + let handler: (request: Request, ctx: ClientContext) => ClientResult | Promise; + if (isZodLikeSchema(methodOrSchema)) { + const schema = methodOrSchema; + const userHandler = schemaHandler as (request: unknown, ctx: ClientContext) => Result | Promise; + method = extractMethodLiteral(schema); + handler = (req, ctx) => userHandler(schema.parse(req), ctx); + } else { + method = methodOrSchema; + handler = schemaHandler as (request: Request, ctx: ClientContext) => ClientResult | Promise; + } if (method === 'elicitation/create') { - const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise => { + const wrappedHandler = async (request: Request, ctx: ClientContext): Promise => { const validatedRequest = parseSchema(ElicitRequestSchema, request); if (!validatedRequest.success) { // Type guard: if success is false, error is guaranteed to exist @@ -404,11 +427,11 @@ export class Client extends Protocol { }; // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler); + return this._setRequestHandlerByMethod(method, wrappedHandler); } if (method === 'sampling/createMessage') { - const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise => { + const wrappedHandler = async (request: Request, ctx: ClientContext): Promise => { const validatedRequest = parseSchema(CreateMessageRequestSchema, request); if (!validatedRequest.success) { const errorMessage = @@ -447,11 +470,11 @@ export class Client extends Protocol { }; // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler); + return this._setRequestHandlerByMethod(method, wrappedHandler); } // Other handlers use default behavior - return super.setRequestHandler(method, handler); + return this._setRequestHandlerByMethod(method, handler); } protected assertCapability(capability: keyof ServerCapabilities, method: string): void { @@ -867,7 +890,20 @@ export class Client extends Protocol { * } * ``` */ - async callTool(params: CallToolRequest['params'], options?: RequestOptions) { + async callTool(params: CallToolRequest['params'], options?: RequestOptions): Promise; + /** The `resultSchema` argument is accepted for v1 source compatibility and ignored; output validation uses the tool's declared `outputSchema`. Prefer `callTool(params, options)`. */ + async callTool(params: CallToolRequest['params'], resultSchema: unknown, options?: RequestOptions): Promise; + async callTool( + params: CallToolRequest['params'], + optionsOrSchema?: RequestOptions | unknown, + maybeOptions?: RequestOptions + ): Promise { + const arg2IsSchema = optionsOrSchema != null && typeof optionsOrSchema === 'object' && 'parse' in optionsOrSchema; + // v1 allowed `callTool(params, undefined, opts)` (resultSchema was optional-with-default); + // when arg2 is not a schema, prefer arg3 if present so opts aren't dropped. + const options: RequestOptions | undefined = arg2IsSchema + ? maybeOptions + : (maybeOptions ?? (optionsOrSchema as RequestOptions | undefined)); // Guard: required-task tools need experimental API if (this.isToolTaskRequired(params.name)) { throw new ProtocolError( diff --git a/packages/client/src/validators/cfWorker.ts b/packages/client/src/validators/cfWorker.ts index b068e69a1..7d1c843e5 100644 --- a/packages/client/src/validators/cfWorker.ts +++ b/packages/client/src/validators/cfWorker.ts @@ -6,5 +6,5 @@ * import { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/client/validators/cf-worker'; * ``` */ -export { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/core'; export type { CfWorkerSchemaDraft } from '@modelcontextprotocol/core'; +export { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/core'; diff --git a/packages/client/test/client/callTool.compat.test.ts b/packages/client/test/client/callTool.compat.test.ts new file mode 100644 index 000000000..087298d82 --- /dev/null +++ b/packages/client/test/client/callTool.compat.test.ts @@ -0,0 +1,41 @@ +import { describe, expect, it, vi } from 'vitest'; +import { Client } from '../../src/client/client.js'; + +describe('callTool v1-compat overload dispatch', () => { + function makeClient() { + const client = new Client({ name: 't', version: '1.0.0' }, { capabilities: {} }); + const spy = vi + .spyOn(client as unknown as { _requestWithSchema: (...a: unknown[]) => Promise }, '_requestWithSchema') + .mockResolvedValue({ content: [] }); + return { client, spy }; + } + + it('callTool(params, undefined, options) preserves options (v1: optional resultSchema)', async () => { + const { client, spy } = makeClient(); + const opts = { timeout: 5000 }; + await client.callTool({ name: 'x', arguments: {} }, undefined, opts); + expect(spy).toHaveBeenCalledTimes(1); + expect(spy.mock.calls[0]?.[2]).toBe(opts); + }); + + it('callTool(params, schema, options) preserves options', async () => { + const { client, spy } = makeClient(); + const opts = { timeout: 5000 }; + const schema = { parse: (x: unknown) => x }; + await client.callTool({ name: 'x', arguments: {} }, schema, opts); + expect(spy.mock.calls[0]?.[2]).toBe(opts); + }); + + it('callTool(params, options) — 2-arg form still works', async () => { + const { client, spy } = makeClient(); + const opts = { timeout: 5000 }; + await client.callTool({ name: 'x', arguments: {} }, opts); + expect(spy.mock.calls[0]?.[2]).toBe(opts); + }); + + it('callTool(params) — no options', async () => { + const { client, spy } = makeClient(); + await client.callTool({ name: 'x', arguments: {} }); + expect(spy.mock.calls[0]?.[2]).toBeUndefined(); + }); +}); diff --git a/packages/core/src/exports/public/index.ts b/packages/core/src/exports/public/index.ts index 2dc1e13a8..227cd3e24 100644 --- a/packages/core/src/exports/public/index.ts +++ b/packages/core/src/exports/public/index.ts @@ -49,6 +49,7 @@ export type { ServerContext } from '../../shared/protocol.js'; export { DEFAULT_REQUEST_TIMEOUT_MSEC } from '../../shared/protocol.js'; +export type { ZodLikeRequestSchema } from '../../util/compatSchema.js'; // Task manager types (NOT TaskManager class itself — internal) export type { RequestTaskStore, TaskContext, TaskManagerOptions, TaskRequestOptions } from '../../shared/taskManager.js'; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index e707d9939..fb39c2868 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -48,4 +48,6 @@ export * from './validators/fromJsonSchema.js'; */ // Core types only - implementations are exported via separate entry points +export type { ZodLikeRequestSchema } from './util/compatSchema.js'; +export { extractMethodLiteral, isZodLikeSchema } from './util/compatSchema.js'; export type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator, JsonSchemaValidatorResult } from './validators/types.js'; diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 57eab6932..2508cb51e 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -44,6 +44,8 @@ import { ProtocolErrorCode, SUPPORTED_PROTOCOL_VERSIONS } from '../types/index.js'; +import type { ZodLikeRequestSchema } from '../util/compatSchema.js'; +import { extractMethodLiteral, isZodLikeSchema } from '../util/compatSchema.js'; import type { AnySchema, SchemaOutput } from '../util/schema.js'; import { parseSchema } from '../util/schema.js'; import type { TaskContext, TaskManagerHost, TaskManagerOptions, TaskRequestOptions } from './taskManager.js'; @@ -199,11 +201,21 @@ export type BaseContext = { * Sends a request that relates to the current request being handled. * * This is used by certain transports to correctly associate related messages. + * + * Two call forms (mirrors {@linkcode Protocol.request | request()}): + * - **Spec method** — `send({ method: 'sampling/createMessage', params }, options?)`. + * The result schema is resolved from the method name and the return is typed by it. + * - **With explicit result schema** — `send({ method, params }, resultSchema, options?)` + * for non-spec methods or custom result shapes. */ - send: ( - request: { method: M; params?: Record }, - options?: TaskRequestOptions - ) => Promise; + send: { + ( + request: { method: M; params?: Record }, + options?: TaskRequestOptions + ): Promise; + /** For spec methods the one-argument form is more concise; this overload is the supported call form for non-spec methods or custom result shapes. */ + (request: Request, resultSchema: T, options?: TaskRequestOptions): Promise>; + }; /** * Sends a notification that relates to the current request being handled. @@ -390,8 +402,10 @@ export abstract class Protocol { } /** - * Builds the context object for request handlers. Subclasses must override - * to return the appropriate context type (e.g., ServerContext adds HTTP request info). + * Builds the context object for request handlers. + * + * Subclasses implement this to enrich the {@linkcode BaseContext} (e.g. `Server` adds `http` + * and `mcpReq.log` to produce `ServerContext`). */ protected abstract buildContext(ctx: BaseContext, transportInfo?: MessageExtraInfo): ContextT; @@ -596,10 +610,13 @@ export abstract class Protocol { method: request.method, _meta: request.params?._meta, signal: abortController.signal, - send: (r: { method: M; params?: Record }, options?: TaskRequestOptions) => { - const resultSchema = getResultSchema(r.method); - return sendRequest(r as Request, resultSchema, options) as Promise; - }, + send: ((r: Request, optionsOrSchema?: TaskRequestOptions | AnySchema, maybeOptions?: TaskRequestOptions) => { + if (optionsOrSchema && '~standard' in optionsOrSchema) { + return sendRequest(r, optionsOrSchema, maybeOptions); + } + const resultSchema = getResultSchema(r.method as RequestMethod); + return sendRequest(r, resultSchema, optionsOrSchema); + }) as BaseContext['mcpReq']['send'], notify: sendNotification }, http: extra?.authInfo ? { authInfo: extra.authInfo } : undefined, @@ -757,33 +774,44 @@ export abstract class Protocol { protected abstract assertRequestHandlerCapability(method: string): void; /** - * A method to check if the remote side supports task creation for the given method. - * - * Called when sending a task-augmented outbound request (only when enforceStrictCapabilities is true). + * A method to check if a task creation is supported by the remote side, for the given method to be called. + * This is called by request when a task-augmented request is being sent and enforceStrictCapabilities is true. * This should be implemented by subclasses. */ protected abstract assertTaskCapability(method: string): void; /** - * A method to check if this side supports handling task creation for the given method. - * - * Called when receiving a task-augmented inbound request. + * A method to check if task creation is supported by the local side, for the given method to be handled. + * This is called when a task-augmented request is received. * This should be implemented by subclasses. */ protected abstract assertTaskHandlerCapability(method: string): void; /** - * Sends a request and waits for a response, resolving the result schema - * automatically from the method name. + * Sends a request and waits for a response. + * + * Two call forms: + * - **Spec method** — `request({ method: 'tools/call', params }, options?)`. The result schema + * is resolved automatically from the method name and the return type is `ResultTypeMap[M]`. + * - **With explicit result schema** — `request({ method, params }, resultSchema, options?)`. + * The result is validated against the supplied schema and typed by it. Use this for non-spec + * methods, or to supply a custom result shape for a spec method. * - * Do not use this method to emit notifications! Use {@linkcode Protocol.notification | notification()} instead. + * Do not use this method to emit notifications! Use + * {@linkcode Protocol.notification | notification()} instead. */ request( request: { method: M; params?: Record }, options?: RequestOptions - ): Promise { - const resultSchema = getResultSchema(request.method); - return this._requestWithSchema(request as Request, resultSchema, options) as Promise; + ): Promise; + /** For spec methods the one-argument form is more concise; this overload is the supported call form for non-spec methods or custom result shapes. */ + request(request: Request, resultSchema: T, options?: RequestOptions): Promise>; + request(request: Request, optionsOrSchema?: RequestOptions | AnySchema, maybeOptions?: RequestOptions): Promise { + if (optionsOrSchema && '~standard' in optionsOrSchema) { + return this._requestWithSchema(request, optionsOrSchema, maybeOptions); + } + const schema = getResultSchema(request.method as RequestMethod); + return this._requestWithSchema(request, schema, optionsOrSchema); } /** @@ -1001,19 +1029,47 @@ export abstract class Protocol { } /** - * Registers a handler to invoke when this protocol object receives a request with the given method. + * Registers a handler to invoke when this protocol object receives a request with the given + * method. Replaces any previous handler for the same method. * - * Note that this will replace any previous request handler for the same method. + * Call forms: + * - **Spec method** — `setRequestHandler('tools/call', (request, ctx) => …)`. + * The full `RequestTypeMap[M]` request object is validated by the SDK and passed to the + * handler. This is the form `Client`/`Server` use and override. + * - **Zod schema** — `setRequestHandler(RequestZodSchema, (request, ctx) => …)`. The method + * name is read from the schema's `method` literal; the handler receives the parsed request. */ setRequestHandler( method: M, handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise - ): void { - this.assertRequestHandlerCapability(method); - const schema = getRequestSchema(method); + ): void; + /** For spec methods the method-string form is more concise; this overload is the supported call form for non-spec methods or when you want full-envelope validation. */ + setRequestHandler( + requestSchema: T, + handler: (request: ReturnType, ctx: ContextT) => Result | Promise + ): void; + setRequestHandler(method: string | ZodLikeRequestSchema, handler: (request: Request, ctx: ContextT) => Result | Promise): void { + if (isZodLikeSchema(method)) { + const requestSchema = method; + const methodStr = extractMethodLiteral(requestSchema); + this.assertRequestHandlerCapability(methodStr); + this._requestHandlers.set(methodStr, (request, ctx) => + Promise.resolve((handler as (req: unknown, ctx: ContextT) => Result | Promise)(requestSchema.parse(request), ctx)) + ); + return; + } + this._setRequestHandlerByMethod(method, handler); + } + /** + * Registers a request handler by method string, bypassing the public overload set. + * Used by `Client`/`Server` overrides to forward without `as RequestMethod` casts. + */ + protected _setRequestHandlerByMethod(method: string, handler: (request: Request, ctx: ContextT) => Result | Promise): void { + this.assertRequestHandlerCapability(method); + const schema = getRequestSchema(method as RequestMethod); this._requestHandlers.set(method, (request, ctx) => { - const parsed = schema.parse(request) as RequestTypeMap[M]; + const parsed = schema ? (schema.parse(request) as Request) : request; return Promise.resolve(handler(parsed, ctx)); }); } @@ -1021,32 +1077,47 @@ export abstract class Protocol { /** * Removes the request handler for the given method. */ - removeRequestHandler(method: RequestMethod): void { + removeRequestHandler(method: string): void { this._requestHandlers.delete(method); } /** * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. */ - assertCanSetRequestHandler(method: RequestMethod): void { + assertCanSetRequestHandler(method: string): void { if (this._requestHandlers.has(method)) { throw new Error(`A request handler for ${method} already exists, which would be overridden`); } } /** - * Registers a handler to invoke when this protocol object receives a notification with the given method. + * Registers a handler to invoke when this protocol object receives a notification with the + * given method. Replaces any previous handler for the same method. * - * Note that this will replace any previous notification handler for the same method. + * Mirrors {@linkcode setRequestHandler}: a spec-method form (handler receives the full + * notification object) and a Zod-schema form (method read from the schema's `method` literal). */ setNotificationHandler( method: M, handler: (notification: NotificationTypeMap[M]) => void | Promise - ): void { - const schema = getNotificationSchema(method); - + ): void; + /** For spec methods the method-string form is more concise; this overload is the supported call form for non-spec methods or when you want full-envelope validation. */ + setNotificationHandler( + notificationSchema: T, + handler: (notification: ReturnType) => void | Promise + ): void; + setNotificationHandler(method: string | ZodLikeRequestSchema, handler: (notification: Notification) => void | Promise): void { + if (isZodLikeSchema(method)) { + const notificationSchema = method; + const methodStr = extractMethodLiteral(notificationSchema); + this._notificationHandlers.set(methodStr, n => + Promise.resolve((handler as (n: unknown) => void | Promise)(notificationSchema.parse(n))) + ); + return; + } + const schema = getNotificationSchema(method as NotificationMethod); this._notificationHandlers.set(method, notification => { - const parsed = schema.parse(notification); + const parsed = schema ? schema.parse(notification) : notification; return Promise.resolve(handler(parsed)); }); } @@ -1054,7 +1125,7 @@ export abstract class Protocol { /** * Removes the notification handler for the given method. */ - removeNotificationHandler(method: NotificationMethod): void { + removeNotificationHandler(method: string): void { this._notificationHandlers.delete(method); } } diff --git a/packages/core/src/util/compatSchema.ts b/packages/core/src/util/compatSchema.ts new file mode 100644 index 000000000..63956c97b --- /dev/null +++ b/packages/core/src/util/compatSchema.ts @@ -0,0 +1,41 @@ +/** + * Helpers for the Zod-schema form of `setRequestHandler` / `setNotificationHandler`. + * + * v1 accepted a Zod object whose `.shape.method` is `z.literal('')`. + * v2 also accepts the method string directly. These helpers detect the schema + * form and extract the literal so the dispatcher can route to the correct path. + * + * @internal + */ + +/** + * Minimal structural type for a Zod object schema. The `method` literal is + * checked at runtime by `extractMethodLiteral`; the type-level constraint + * is intentionally loose because zod v4's `ZodLiteral` doesn't surface `.value` + * in its declared type (only at runtime). + */ +export interface ZodLikeRequestSchema { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + shape: any; + parse(input: unknown): unknown; +} + +/** True if `arg` looks like a Zod object schema (has `.shape` and `.parse`). */ +export function isZodLikeSchema(arg: unknown): arg is ZodLikeRequestSchema { + return typeof arg === 'object' && arg !== null && 'shape' in arg && typeof (arg as { parse?: unknown }).parse === 'function'; +} + +/** + * Extracts the string value from a Zod-like schema's `shape.method` literal. + * Throws if no string `method` literal is present. + */ +export function extractMethodLiteral(schema: ZodLikeRequestSchema): string { + const methodField = (schema.shape as Record | undefined)?.method as + | { value?: unknown; def?: { values?: unknown[] } } + | undefined; + const value = methodField?.value ?? methodField?.def?.values?.[0]; + if (typeof value !== 'string') { + throw new TypeError('Schema passed to setRequestHandler/setNotificationHandler is missing a string `method` literal'); + } + return value; +} diff --git a/packages/core/test/shared/customMethods.test.ts b/packages/core/test/shared/customMethods.test.ts new file mode 100644 index 000000000..0868b5805 --- /dev/null +++ b/packages/core/test/shared/customMethods.test.ts @@ -0,0 +1,143 @@ +import { describe, expect, it } from 'vitest'; +import { z } from 'zod'; + +import type { BaseContext } from '../../src/shared/protocol.js'; +import { Protocol } from '../../src/shared/protocol.js'; +import { InMemoryTransport } from '../../src/util/inMemory.js'; + +// Minimal concrete Protocol for tests; capability checks are no-ops. +class TestProtocol extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + protected buildContext(ctx: BaseContext): BaseContext { + return ctx; + } +} + +async function makePair() { + const [t1, t2] = InMemoryTransport.createLinkedPair(); + const a = new TestProtocol(); + const b = new TestProtocol(); + await a.connect(t1); + await b.connect(t2); + return { a, b }; +} + +const EchoRequest = z.object({ method: z.literal('acme/echo'), params: z.object({ msg: z.string() }) }); +const TickNotification = z.object({ method: z.literal('acme/tick'), params: z.object({ n: z.number() }) }); + +describe('setRequestHandler — Zod-schema form', () => { + it('round-trips a custom request via Zod schema', async () => { + const { a, b } = await makePair(); + b.setRequestHandler(EchoRequest, req => ({ reply: req.params.msg.toUpperCase() })); + const result = await a.request({ method: 'acme/echo', params: { msg: 'hi' } }, z.object({ reply: z.string() })); + expect(result).toEqual({ reply: 'HI' }); + }); + + it('rejects invalid params via the Zod schema', async () => { + const { a, b } = await makePair(); + b.setRequestHandler(EchoRequest, req => ({ reply: req.params.msg })); + await expect(a.request({ method: 'acme/echo', params: { msg: 42 } }, z.object({ reply: z.string() }))).rejects.toThrow(); + }); + + it('removeRequestHandler works for any method string', async () => { + const { a, b } = await makePair(); + b.setRequestHandler(EchoRequest, req => ({ reply: req.params.msg })); + await expect(a.request({ method: 'acme/echo', params: { msg: 'x' } }, z.object({ reply: z.string() }))).resolves.toEqual({ + reply: 'x' + }); + b.removeRequestHandler('acme/echo'); + await expect(a.request({ method: 'acme/echo', params: { msg: 'x' } }, z.object({ reply: z.string() }))).rejects.toThrow( + /Method not found/ + ); + }); + + it('two-arg spec-method form still works', async () => { + const { a, b } = await makePair(); + let pinged = false; + b.setRequestHandler('ping', () => { + pinged = true; + return {}; + }); + await a.request({ method: 'ping' }); + expect(pinged).toBe(true); + }); +}); + +describe('setNotificationHandler — Zod-schema form', () => { + it('receives a custom notification via Zod schema', async () => { + const { a, b } = await makePair(); + const received: unknown[] = []; + b.setNotificationHandler(TickNotification, n => { + received.push(n.params); + }); + await a.notification({ method: 'acme/tick', params: { n: 1 } }); + await a.notification({ method: 'acme/tick', params: { n: 2 } }); + await new Promise(r => setTimeout(r, 0)); + expect(received).toEqual([{ n: 1 }, { n: 2 }]); + }); + + it('two-arg spec-method form still works', async () => { + const { a, b } = await makePair(); + let got = false; + b.setNotificationHandler('notifications/initialized', () => { + got = true; + }); + await a.notification({ method: 'notifications/initialized' }); + await new Promise(r => setTimeout(r, 0)); + expect(got).toBe(true); + }); +}); + +describe('request() — explicit result schema overload', () => { + it('uses the supplied result schema for a non-spec method', async () => { + const { a, b } = await makePair(); + b.setRequestHandler(EchoRequest, req => ({ reply: req.params.msg })); + const r = await a.request({ method: 'acme/echo', params: { msg: 'ok' } }, z.object({ reply: z.string() })); + expect(r.reply).toBe('ok'); + }); + + it('spec method without schema uses method-keyed return type', async () => { + const { a, b } = await makePair(); + b.setRequestHandler('ping', () => ({})); + const r = await a.request({ method: 'ping' }); + expect(r).toEqual({}); + }); +}); + +describe('ctx.mcpReq.send() — explicit result schema overload', () => { + it('forwards to a related request and validates result via the supplied schema', async () => { + const { a, b } = await makePair(); + a.setRequestHandler(EchoRequest, req => ({ reply: req.params.msg })); + let captured: unknown; + b.setRequestHandler(z.object({ method: z.literal('acme/outer') }), async (_req, ctx) => { + captured = await ctx.mcpReq.send({ method: 'acme/echo', params: { msg: 'via-send' } }, z.object({ reply: z.string() })); + return {}; + }); + await a.request({ method: 'acme/outer' }, z.object({})); + expect(captured).toEqual({ reply: 'via-send' }); + }); + + it('spec-method form (no schema) uses method-keyed return', async () => { + const { a, b } = await makePair(); + a.setRequestHandler('ping', () => ({})); + let pingResult: unknown; + b.setRequestHandler(z.object({ method: z.literal('acme/outer') }), async (_req, ctx) => { + pingResult = await ctx.mcpReq.send({ method: 'ping' }); + return {}; + }); + await a.request({ method: 'acme/outer' }, z.object({})); + expect(pingResult).toEqual({}); + }); +}); + +describe('notification() mock-assignability', () => { + it('single-signature notification() is assignable from a simple mock (compile-time check)', () => { + const p = new TestProtocol(); + p.notification = async (_n: { method: string }) => {}; + expect(typeof p.notification).toBe('function'); + }); +}); diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 4361f3e1e..a4690b7f4 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -21,17 +21,20 @@ import type { NotificationMethod, NotificationOptions, ProtocolOptions, + Request, RequestMethod, RequestOptions, RequestTypeMap, ResourceUpdatedNotification, + Result, ResultTypeMap, ServerCapabilities, ServerContext, ServerResult, TaskManagerOptions, ToolResultContent, - ToolUseContent + ToolUseContent, + ZodLikeRequestSchema } from '@modelcontextprotocol/core'; import { assertClientRequestTaskCapability, @@ -43,7 +46,9 @@ import { CreateTaskResultSchema, ElicitResultSchema, EmptyResultSchema, + extractMethodLiteral, extractTaskManagerOptions, + isZodLikeSchema, LATEST_PROTOCOL_VERSION, ListRootsResultSchema, LoggingLevelSchema, @@ -225,9 +230,26 @@ export class Server extends Protocol { public override setRequestHandler( method: M, handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise - ): void { + ): void; + /** For spec methods the method-string form is more concise; this overload is the supported call form for non-spec methods or when you want full-envelope validation. */ + public override setRequestHandler( + requestSchema: T, + handler: (request: ReturnType, ctx: ServerContext) => Result | Promise + ): void; + public override setRequestHandler(methodOrSchema: string | ZodLikeRequestSchema, schemaHandler: unknown): void { + let method: string; + let handler: (request: Request, ctx: ServerContext) => ServerResult | Promise; + if (isZodLikeSchema(methodOrSchema)) { + const schema = methodOrSchema; + const userHandler = schemaHandler as (request: unknown, ctx: ServerContext) => Result | Promise; + method = extractMethodLiteral(schema); + handler = (req, ctx) => userHandler(schema.parse(req), ctx); + } else { + method = methodOrSchema; + handler = schemaHandler as (request: Request, ctx: ServerContext) => ServerResult | Promise; + } if (method === 'tools/call') { - const wrappedHandler = async (request: RequestTypeMap[M], ctx: ServerContext): Promise => { + const wrappedHandler = async (request: Request, ctx: ServerContext): Promise => { const validatedRequest = parseSchema(CallToolRequestSchema, request); if (!validatedRequest.success) { const errorMessage = @@ -264,11 +286,11 @@ export class Server extends Protocol { }; // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler); + return this._setRequestHandlerByMethod(method, wrappedHandler); } // Other handlers use default behavior - return super.setRequestHandler(method, handler); + return this._setRequestHandlerByMethod(method, handler); } protected assertCapabilityForMethod(method: RequestMethod): void { diff --git a/packages/server/src/validators/cfWorker.ts b/packages/server/src/validators/cfWorker.ts index 9a3a88405..e04436dbd 100644 --- a/packages/server/src/validators/cfWorker.ts +++ b/packages/server/src/validators/cfWorker.ts @@ -6,5 +6,5 @@ * import { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/server/validators/cf-worker'; * ``` */ -export { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/core'; export type { CfWorkerSchemaDraft } from '@modelcontextprotocol/core'; +export { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/core'; diff --git a/packages/server/test/server/setRequestHandlerSchemaParity.test.ts b/packages/server/test/server/setRequestHandlerSchemaParity.test.ts new file mode 100644 index 000000000..313cd0e8e --- /dev/null +++ b/packages/server/test/server/setRequestHandlerSchemaParity.test.ts @@ -0,0 +1,66 @@ +import { describe, expect, it } from 'vitest'; +import { z } from 'zod'; + +import { CallToolRequestSchema, InMemoryTransport } from '@modelcontextprotocol/core'; + +import { Server } from '../../src/server/server.js'; + +/** + * Regression test: setRequestHandler(CallToolRequestSchema, h) and + * setRequestHandler('tools/call', h) must apply the same per-method + * wrapping (task-result validation when params.task is set). + */ +describe('Server.setRequestHandler — Zod-schema form parity', () => { + async function setup(register: (s: Server) => void) { + const server = new Server( + { name: 't', version: '1.0' }, + { capabilities: { tools: {}, tasks: { requests: { tools: { call: {} } } } } } + ); + register(server); + const [ct, st] = InMemoryTransport.createLinkedPair(); + await server.connect(st); + await ct.start(); + return { ct }; + } + + async function callToolWithTask(ct: InMemoryTransport): Promise<{ result?: unknown; error?: unknown }> { + return await new Promise(resolve => { + ct.onmessage = m => { + const msg = m as { result?: unknown; error?: unknown }; + if ('result' in msg || 'error' in msg) resolve(msg); + }; + ct.send({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'x', arguments: {}, task: { ttl: 1000 } } + }); + }); + } + + it('schema form gets the same task-result validation as string form', async () => { + const invalidTaskResult = { content: [{ type: 'text' as const, text: 'not a task result' }] }; + + const viaString = await setup(s => s.setRequestHandler('tools/call', () => invalidTaskResult)); + const viaSchema = await setup(s => s.setRequestHandler(CallToolRequestSchema, () => invalidTaskResult)); + + const stringRes = await callToolWithTask(viaString.ct); + const schemaRes = await callToolWithTask(viaSchema.ct); + + expect((stringRes.error as { message: string }).message).toContain('Invalid task creation result'); + expect(schemaRes.error).toEqual(stringRes.error); + }); + + it('schema form handles non-spec methods through Server (no spec-schema crash)', async () => { + const Echo = z.object({ method: z.literal('acme/echo'), params: z.object({ msg: z.string() }) }); + const { ct } = await setup(s => s.setRequestHandler(Echo, req => ({ reply: req.params.msg }))); + const res = await new Promise<{ result?: unknown; error?: unknown }>(resolve => { + ct.onmessage = m => { + const msg = m as { result?: unknown; error?: unknown }; + if ('result' in msg || 'error' in msg) resolve(msg); + }; + ct.send({ jsonrpc: '2.0', id: 1, method: 'acme/echo', params: { msg: 'hi' } }); + }); + expect(res.result).toEqual({ reply: 'hi' }); + }); +});