diff --git a/MssqlMcp/Node/src/index.ts b/MssqlMcp/Node/src/index.ts index 8dc1f30..2a8472d 100644 --- a/MssqlMcp/Node/src/index.ts +++ b/MssqlMcp/Node/src/index.ts @@ -3,8 +3,10 @@ // External imports import * as dotenv from "dotenv"; import sql from "mssql"; +import * as http from "http"; import { Server } from "@modelcontextprotocol/sdk/server/index.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { SSEServerTransport } from "@modelcontextprotocol/sdk/server/sse.js"; import { CallToolRequestSchema, ListToolsRequestSchema, @@ -23,6 +25,7 @@ import { DescribeTableTool } from "./tools/DescribeTableTool.js"; // MSSQL Database connection configuration // const credential = new DefaultAzureCredential(); +dotenv.config(); // Globals for connection and token reuse let globalSqlPool: sql.ConnectionPool | null = null; @@ -31,30 +34,54 @@ let globalTokenExpiresOn: Date | null = null; // Function to create SQL config with fresh access token, returns token and expiry export async function createSqlConfig(): Promise<{ config: sql.config, token: string, expiresOn: Date }> { + const trustServerCertificate = process.env.TRUST_SERVER_CERTIFICATE?.toLowerCase() === 'true'; + const connectionTimeout = process.env.CONNECTION_TIMEOUT ? parseInt(process.env.CONNECTION_TIMEOUT, 10) : 30; + + // Base configuration shared by both auth methods + const baseConfig = { + server: process.env.SERVER_NAME!, + database: process.env.DATABASE_NAME!, + options: { + encrypt: true, + trustServerCertificate + }, + connectionTimeout: connectionTimeout * 1000, + }; + + // --------------------------------------------------------- + // PATH A: Standard SQL Authentication (via .env) + // --------------------------------------------------------- + if (process.env.SQL_USER && process.env.SQL_PASSWORD) { + return { + config: { + ...baseConfig, + user: process.env.SQL_USER, + password: process.env.SQL_PASSWORD, + }, + token: '', + expiresOn: new Date(Date.now() + 30 * 60 * 1000) + }; + } + + // --------------------------------------------------------- + // PATH B: Azure Active Directory Authentication (Browser Token) + // --------------------------------------------------------- const credential = new InteractiveBrowserCredential({ redirectUri: 'http://localhost' // disableAutomaticAuthentication : true }); + const accessToken = await credential.getToken('https://database.windows.net/.default'); - const trustServerCertificate = process.env.TRUST_SERVER_CERTIFICATE?.toLowerCase() === 'true'; - const connectionTimeout = process.env.CONNECTION_TIMEOUT ? parseInt(process.env.CONNECTION_TIMEOUT, 10) : 30; - return { config: { - server: process.env.SERVER_NAME!, - database: process.env.DATABASE_NAME!, - options: { - encrypt: true, - trustServerCertificate - }, + ...baseConfig, authentication: { type: 'azure-active-directory-access-token', options: { token: accessToken?.token!, }, }, - connectionTimeout: connectionTimeout * 1000, // convert seconds to milliseconds }, token: accessToken?.token!, expiresOn: accessToken?.expiresOnTimestamp ? new Date(accessToken.expiresOnTimestamp) : new Date(Date.now() + 30 * 60 * 1000) @@ -156,11 +183,6 @@ async function runServer() { } } -runServer().catch((error) => { - console.error("Fatal error running server:", error); - process.exit(1); -}); - // Connect to SQL only when handling a request async function ensureSqlConnection() { @@ -197,4 +219,112 @@ function wrapToolRun(tool: { run: (...args: any[]) => Promise }) { }; } -[insertDataTool, readDataTool, updateDataTool, createTableTool, createIndexTool, dropTableTool, listTableTool, describeTableTool].forEach(wrapToolRun); \ No newline at end of file +[insertDataTool, readDataTool, updateDataTool, createTableTool, createIndexTool, dropTableTool, listTableTool, describeTableTool].forEach(wrapToolRun); + +// ========================================== +// Transports & Server Startup +// ========================================== + +// HTTP / Server-Sent Events Transport +async function runHttpServer(port: number) { + + // Map to store active client sessions + const activeTransports = new Map(); + + const httpServer = http.createServer(async (req, res) => { + // Parse the URL to get the path and query parameters + const parsedUrl = new URL(req.url || "/", `http://${req.headers.host}`); + + // Add basic CORS headers for web clients + res.setHeader("Access-Control-Allow-Origin", "*"); + res.setHeader("Access-Control-Allow-Methods", "GET, POST, OPTIONS"); + res.setHeader("Access-Control-Allow-Headers", "Content-Type"); + + // Handle preflight requests + if (req.method === "OPTIONS") { + res.writeHead(200); + res.end(); + return; + } + + // --------------------------------------------------------- + // SSE Connection Endpoint (Handshake & Stream Start) + // --------------------------------------------------------- + if (req.method === "GET" && parsedUrl.pathname === "/sse") { + console.error("New client connecting..."); + + const transport = new SSEServerTransport("/messages", res as any); + + // Connect this specific transport to the server instance + await server.connect(transport); + + // Store the transport session so we can route POST messages to it later + activeTransports.set(transport.sessionId, transport); + console.error(`Client connected. Session ID: ${transport.sessionId}`); + + // Crucial: Clean up memory when the client disconnects + res.on("close", () => { + console.error(`Client disconnected. Cleaning up Session ID: ${transport.sessionId}`); + activeTransports.delete(transport.sessionId); + }); + return; + } + + // --------------------------------------------------------- + // Message Posting Endpoint (Receiving Client Tool Calls) + // --------------------------------------------------------- + if (req.method === "POST" && parsedUrl.pathname === "/messages") { + const sessionId = parsedUrl.searchParams.get("sessionId"); + + if (!sessionId) { + res.writeHead(400, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ error: "Missing sessionId in query parameters." })); + return; + } + + const transport = activeTransports.get(sessionId); + + if (transport) { + // Native Node http handles the stream internally in the SDK, + // no need to parse req.body like in Express + await transport.handlePostMessage(req as any, res as any); + } else { + res.writeHead(404, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ error: "Session not found or expired. Reconnect to /sse." })); + } + return; + } + + // 404 Fallback + res.writeHead(404); + res.end("Not Found"); + }); + + httpServer.listen(port, () => { + console.error(`Multi-Client MSSQL MCP Server running on HTTP/SSE.`); + console.error(`Connect to SSE stream at: http://localhost:${port}/sse`); + console.error(`Send POST messages to: http://localhost:${port}/messages?sessionId=...`); + }); +} + +// ========================================== +// Execution Logic +// ========================================== + +// Parse command line arguments +const args = process.argv.slice(2); +const isSSE = args.includes("--sse"); +const portArgIndex = args.indexOf("--port"); +const port = portArgIndex !== -1 ? parseInt(args[portArgIndex + 1], 10) : 3000; + +if (isSSE) { + runHttpServer(port).catch((error) => { + console.error("Fatal error running HTTP server:", error); + process.exit(1); + }); +} else { + runServer().catch((error) => { + console.error("Fatal error running stdio server:", error); + process.exit(1); + }); +} \ No newline at end of file