Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 146 additions & 16 deletions MssqlMcp/Node/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
// External imports
import * as dotenv from "dotenv";
import sql from "mssql";
import * as http from "http";
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import { SSEServerTransport } from "@modelcontextprotocol/sdk/server/sse.js";
import {
CallToolRequestSchema,
ListToolsRequestSchema,
Expand All @@ -23,6 +25,7 @@ import { DescribeTableTool } from "./tools/DescribeTableTool.js";

// MSSQL Database connection configuration
// const credential = new DefaultAzureCredential();
dotenv.config();

// Globals for connection and token reuse
let globalSqlPool: sql.ConnectionPool | null = null;
Expand All @@ -31,30 +34,54 @@ let globalTokenExpiresOn: Date | null = null;

// Function to create SQL config with fresh access token, returns token and expiry
export async function createSqlConfig(): Promise<{ config: sql.config, token: string, expiresOn: Date }> {
const trustServerCertificate = process.env.TRUST_SERVER_CERTIFICATE?.toLowerCase() === 'true';
const connectionTimeout = process.env.CONNECTION_TIMEOUT ? parseInt(process.env.CONNECTION_TIMEOUT, 10) : 30;

// Base configuration shared by both auth methods
const baseConfig = {
server: process.env.SERVER_NAME!,
database: process.env.DATABASE_NAME!,
options: {
encrypt: true,
trustServerCertificate
},
connectionTimeout: connectionTimeout * 1000,
};

// ---------------------------------------------------------
// PATH A: Standard SQL Authentication (via .env)
// ---------------------------------------------------------
if (process.env.SQL_USER && process.env.SQL_PASSWORD) {
return {
config: {
...baseConfig,
user: process.env.SQL_USER,
password: process.env.SQL_PASSWORD,
},
token: '',
expiresOn: new Date(Date.now() + 30 * 60 * 1000)
};
}

// ---------------------------------------------------------
// PATH B: Azure Active Directory Authentication (Browser Token)
// ---------------------------------------------------------
const credential = new InteractiveBrowserCredential({
redirectUri: 'http://localhost'
// disableAutomaticAuthentication : true
});

const accessToken = await credential.getToken('https://database.windows.net/.default');

const trustServerCertificate = process.env.TRUST_SERVER_CERTIFICATE?.toLowerCase() === 'true';
const connectionTimeout = process.env.CONNECTION_TIMEOUT ? parseInt(process.env.CONNECTION_TIMEOUT, 10) : 30;

return {
config: {
server: process.env.SERVER_NAME!,
database: process.env.DATABASE_NAME!,
options: {
encrypt: true,
trustServerCertificate
},
...baseConfig,
authentication: {
type: 'azure-active-directory-access-token',
options: {
token: accessToken?.token!,
},
},
connectionTimeout: connectionTimeout * 1000, // convert seconds to milliseconds
},
token: accessToken?.token!,
expiresOn: accessToken?.expiresOnTimestamp ? new Date(accessToken.expiresOnTimestamp) : new Date(Date.now() + 30 * 60 * 1000)
Expand Down Expand Up @@ -156,11 +183,6 @@ async function runServer() {
}
}

runServer().catch((error) => {
console.error("Fatal error running server:", error);
process.exit(1);
});

// Connect to SQL only when handling a request

async function ensureSqlConnection() {
Expand Down Expand Up @@ -197,4 +219,112 @@ function wrapToolRun(tool: { run: (...args: any[]) => Promise<any> }) {
};
}

[insertDataTool, readDataTool, updateDataTool, createTableTool, createIndexTool, dropTableTool, listTableTool, describeTableTool].forEach(wrapToolRun);
[insertDataTool, readDataTool, updateDataTool, createTableTool, createIndexTool, dropTableTool, listTableTool, describeTableTool].forEach(wrapToolRun);

// ==========================================
// Transports & Server Startup
// ==========================================

// HTTP / Server-Sent Events Transport
async function runHttpServer(port: number) {

// Map to store active client sessions
const activeTransports = new Map<string, SSEServerTransport>();

const httpServer = http.createServer(async (req, res) => {
// Parse the URL to get the path and query parameters
const parsedUrl = new URL(req.url || "/", `http://${req.headers.host}`);

// Add basic CORS headers for web clients
res.setHeader("Access-Control-Allow-Origin", "*");
res.setHeader("Access-Control-Allow-Methods", "GET, POST, OPTIONS");
res.setHeader("Access-Control-Allow-Headers", "Content-Type");

// Handle preflight requests
if (req.method === "OPTIONS") {
res.writeHead(200);
res.end();
return;
}

// ---------------------------------------------------------
// SSE Connection Endpoint (Handshake & Stream Start)
// ---------------------------------------------------------
if (req.method === "GET" && parsedUrl.pathname === "/sse") {
console.error("New client connecting...");

const transport = new SSEServerTransport("/messages", res as any);

// Connect this specific transport to the server instance
await server.connect(transport);

// Store the transport session so we can route POST messages to it later
activeTransports.set(transport.sessionId, transport);
console.error(`Client connected. Session ID: ${transport.sessionId}`);

// Crucial: Clean up memory when the client disconnects
res.on("close", () => {
console.error(`Client disconnected. Cleaning up Session ID: ${transport.sessionId}`);
activeTransports.delete(transport.sessionId);
});
return;
}

// ---------------------------------------------------------
// Message Posting Endpoint (Receiving Client Tool Calls)
// ---------------------------------------------------------
if (req.method === "POST" && parsedUrl.pathname === "/messages") {
const sessionId = parsedUrl.searchParams.get("sessionId");

if (!sessionId) {
res.writeHead(400, { "Content-Type": "application/json" });
res.end(JSON.stringify({ error: "Missing sessionId in query parameters." }));
return;
}

const transport = activeTransports.get(sessionId);

if (transport) {
// Native Node http handles the stream internally in the SDK,
// no need to parse req.body like in Express
await transport.handlePostMessage(req as any, res as any);
} else {
res.writeHead(404, { "Content-Type": "application/json" });
res.end(JSON.stringify({ error: "Session not found or expired. Reconnect to /sse." }));
}
return;
}

// 404 Fallback
res.writeHead(404);
res.end("Not Found");
});

httpServer.listen(port, () => {
console.error(`Multi-Client MSSQL MCP Server running on HTTP/SSE.`);
console.error(`Connect to SSE stream at: http://localhost:${port}/sse`);
console.error(`Send POST messages to: http://localhost:${port}/messages?sessionId=...`);
});
}

// ==========================================
// Execution Logic
// ==========================================

// Parse command line arguments
const args = process.argv.slice(2);
const isSSE = args.includes("--sse");
const portArgIndex = args.indexOf("--port");
const port = portArgIndex !== -1 ? parseInt(args[portArgIndex + 1], 10) : 3000;

if (isSSE) {
runHttpServer(port).catch((error) => {
console.error("Fatal error running HTTP server:", error);
process.exit(1);
});
} else {
runServer().catch((error) => {
console.error("Fatal error running stdio server:", error);
process.exit(1);
});
}