diff --git a/packages/db/prisma/migrations/20260324182442_support_mcp_clients/migration.sql b/packages/db/prisma/migrations/20260324182442_support_mcp_clients/migration.sql new file mode 100644 index 000000000..30e6d30f9 --- /dev/null +++ b/packages/db/prisma/migrations/20260324182442_support_mcp_clients/migration.sql @@ -0,0 +1,41 @@ +-- CreateTable +CREATE TABLE "McpServer" ( + "id" TEXT NOT NULL, + "serverUrl" TEXT NOT NULL, + "clientInfo" TEXT, + "orgId" INTEGER NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "McpServer_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "UserMcpServer" ( + "userId" TEXT NOT NULL, + "serverId" TEXT NOT NULL, + "name" TEXT NOT NULL, + "tokens" TEXT, + "tokensExpiresAt" TIMESTAMP(3), + "codeVerifier" TEXT, + "state" TEXT, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "UserMcpServer_pkey" PRIMARY KEY ("userId","serverId") +); + +-- CreateIndex +CREATE UNIQUE INDEX "McpServer_serverUrl_orgId_key" ON "McpServer"("serverUrl", "orgId"); + +-- CreateIndex +CREATE INDEX "UserMcpServer_state_idx" ON "UserMcpServer"("state"); + +-- AddForeignKey +ALTER TABLE "McpServer" ADD CONSTRAINT "McpServer_orgId_fkey" FOREIGN KEY ("orgId") REFERENCES "Org"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "UserMcpServer" ADD CONSTRAINT "UserMcpServer_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "UserMcpServer" ADD CONSTRAINT "UserMcpServer_serverId_fkey" FOREIGN KEY ("serverId") REFERENCES "McpServer"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/db/prisma/migrations/20260524000000_org_approved_mcp_servers/migration.sql b/packages/db/prisma/migrations/20260524000000_org_approved_mcp_servers/migration.sql new file mode 100644 index 000000000..99d1bc446 --- /dev/null +++ b/packages/db/prisma/migrations/20260524000000_org_approved_mcp_servers/migration.sql @@ -0,0 +1,27 @@ +-- Add org-approved display/tool identity to shared MCP servers. +ALTER TABLE "McpServer" ADD COLUMN "name" TEXT; +ALTER TABLE "McpServer" ADD COLUMN "sanitizedName" TEXT; + +-- This branch has not shipped, but keep local development databases migratable. +UPDATE "McpServer" +SET "name" = COALESCE( + ( + SELECT "UserMcpServer"."name" + FROM "UserMcpServer" + WHERE "UserMcpServer"."serverId" = "McpServer"."id" + ORDER BY "UserMcpServer"."createdAt" ASC + LIMIT 1 + ), + "McpServer"."serverUrl" +); + +UPDATE "McpServer" +SET "sanitizedName" = regexp_replace(lower("name"), '[^a-z0-9]', '_', 'g'); + +ALTER TABLE "McpServer" ALTER COLUMN "name" SET NOT NULL; +ALTER TABLE "McpServer" ALTER COLUMN "sanitizedName" SET NOT NULL; + +-- Remove per-user display identity now that MCP servers are org-approved. +ALTER TABLE "UserMcpServer" DROP COLUMN "name"; + +CREATE UNIQUE INDEX "McpServer_orgId_sanitizedName_key" ON "McpServer"("orgId", "sanitizedName"); diff --git a/packages/db/prisma/migrations/20260525000000_add_user_mcp_server_server_id_index/migration.sql b/packages/db/prisma/migrations/20260525000000_add_user_mcp_server_server_id_index/migration.sql new file mode 100644 index 000000000..d171bca2c --- /dev/null +++ b/packages/db/prisma/migrations/20260525000000_add_user_mcp_server_server_id_index/migration.sql @@ -0,0 +1 @@ +CREATE INDEX "UserMcpServer_serverId_idx" ON "UserMcpServer"("serverId"); diff --git a/packages/db/prisma/migrations/20260526000000_add_mcp_server_client_info_source/migration.sql b/packages/db/prisma/migrations/20260526000000_add_mcp_server_client_info_source/migration.sql new file mode 100644 index 000000000..1f03e8968 --- /dev/null +++ b/packages/db/prisma/migrations/20260526000000_add_mcp_server_client_info_source/migration.sql @@ -0,0 +1,5 @@ +-- Track whether McpServer.clientInfo came from dynamic client registration or admin-provided static credentials. +CREATE TYPE "McpServerClientInfoSource" AS ENUM ('DYNAMIC', 'STATIC'); + +ALTER TABLE "McpServer" +ADD COLUMN "clientInfoSource" "McpServerClientInfoSource" NOT NULL DEFAULT 'DYNAMIC'; diff --git a/packages/db/prisma/schema.prisma b/packages/db/prisma/schema.prisma index 7e1af6be7..38c4b1e5d 100644 --- a/packages/db/prisma/schema.prisma +++ b/packages/db/prisma/schema.prisma @@ -292,6 +292,8 @@ model Org { chats Chat[] + mcpServers McpServer[] + license License? /// Set the first time this instance is seen to be on a trial subscription. @@ -328,6 +330,11 @@ enum OrgRole { MEMBER } +enum McpServerClientInfoSource { + DYNAMIC + STATIC +} + model UserToOrg { joinedAt DateTime @default(now()) @@ -409,6 +416,8 @@ model User { /// claim baked into the JWT cookie at mint time. sessionVersion Int @default(0) + userMcpServers UserMcpServer[] + createdAt DateTime @default(now()) updatedAt DateTime @updatedAt @@ -603,3 +612,59 @@ model OAuthToken { createdAt DateTime @default(now()) lastUsedAt DateTime? } + +/// An external MCP server endpoint, unique per org. +/// Stores the dynamic client registration (client_id/client_secret) once per org. +model McpServer { + id String @id @default(cuid()) + name String /// Org-approved display name (e.g., "Linear") + sanitizedName String /// Stable tool-name prefix (e.g., "linear") + serverUrl String /// MCP server endpoint (e.g., "https://mcp.linear.app/mcp") + + /// Dynamic client registration result (RFC 7591) or admin-provided static OAuth client credentials. + /// Encrypted JSON of OAuthClientInformation: { client_id, client_secret, client_id_issued_at, client_secret_expires_at } + /// Null for DYNAMIC rows until first user in the org triggers registration. + clientInfo String? + clientInfoSource McpServerClientInfoSource @default(DYNAMIC) + + org Org @relation(fields: [orgId], references: [id], onDelete: Cascade) + orgId Int + + userMcpServers UserMcpServer[] + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@unique([serverUrl, orgId]) + @@unique([orgId, sanitizedName]) +} + +/// A user's personal connection to an MCP server. +/// Stores per-user OAuth tokens and ephemeral auth-flow state. +model UserMcpServer { + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String + + server McpServer @relation(fields: [serverId], references: [id], onDelete: Cascade) + serverId String + + /// OAuth tokens (access_token, refresh_token, etc.) — encrypted JSON of OAuthTokens. + tokens String? + + /// Absolute expiry time of the access token, computed at issuance from expires_in. + /// Null when no tokens are stored or the provider did not include expires_in. + tokensExpiresAt DateTime? + + /// PKCE code verifier — ephemeral, only used between redirect and callback. + codeVerifier String? + + /// OAuth state parameter — ephemeral, for CSRF protection during auth flow. + state String? + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@id([userId, serverId]) + @@index([serverId]) + @@index([state]) +} diff --git a/packages/shared/src/env.server.ts b/packages/shared/src/env.server.ts index 036655018..21d5e2b37 100644 --- a/packages/shared/src/env.server.ts +++ b/packages/shared/src/env.server.ts @@ -278,6 +278,7 @@ const options = { */ SOURCEBOT_CHAT_MODEL_TEMPERATURE: numberSchema.optional(), SOURCEBOT_CHAT_MAX_STEP_COUNT: numberSchema.default(100), + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: numberSchema.default(60000), DEBUG_WRITE_CHAT_MESSAGES_TO_FILE: booleanSchema.default('false'), DEBUG_ENABLE_REACT_SCAN: booleanSchema.default('false'), diff --git a/packages/web/package.json b/packages/web/package.json index 8825a21c5..9ca986d6a 100644 --- a/packages/web/package.json +++ b/packages/web/package.json @@ -20,6 +20,7 @@ "@ai-sdk/deepseek": "^2.0.29", "@ai-sdk/google": "^3.0.64", "@ai-sdk/google-vertex": "^4.0.111", + "@ai-sdk/mcp": "^2.0.0-beta.11", "@ai-sdk/mistral": "^3.0.30", "@ai-sdk/openai": "^3.0.53", "@ai-sdk/openai-compatible": "^2.0.41", @@ -196,7 +197,7 @@ "use-stick-to-bottom": "^1.1.3", "usehooks-ts": "^3.1.0", "vscode-icons-js": "^11.6.1", - "zod": "^3.25.74", + "zod": "^3.25.76", "zod-to-json-schema": "^3.24.5" }, "devDependencies": { diff --git a/packages/web/src/app/(app)/@sidebar/components/settingsSidebar/nav.tsx b/packages/web/src/app/(app)/@sidebar/components/settingsSidebar/nav.tsx index 862298a61..9a560199e 100644 --- a/packages/web/src/app/(app)/@sidebar/components/settingsSidebar/nav.tsx +++ b/packages/web/src/app/(app)/@sidebar/components/settingsSidebar/nav.tsx @@ -16,6 +16,7 @@ import { type LucideIcon, PlugIcon, ScrollTextIcon, + ServerIcon, Settings2Icon, ShieldIcon, UserIcon, @@ -32,6 +33,7 @@ const iconMap = { "plug": PlugIcon, "chart-area": ChartAreaIcon, "scroll-text": ScrollTextIcon, + "server": ServerIcon, "settings": Settings2Icon, "user": UserIcon, } satisfies Record; diff --git a/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx b/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx index 43ccb1a87..6bc248ce8 100644 --- a/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx +++ b/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx @@ -69,7 +69,7 @@ export const LandingPage = ({
{ - createNewChatThread(children, selectedSearchScopes); + createNewChatThread(children, selectedSearchScopes, []); }} className="min-h-[50px]" isRedirecting={isLoading} diff --git a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx b/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx index 574001e5f..3cf15df48 100644 --- a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx +++ b/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx @@ -40,11 +40,13 @@ export const ChatThreadPanel = ({ localStorage.removeItem(SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY); }, []); - // Use the last user's last message to determine what repos and contexts we should select by default. + // Use the last user message to determine what repos, contexts, and MCP state we should select by default. const lastUserMessage = messages.findLast((message) => message.role === "user"); const defaultSelectedSearchScopes = lastUserMessage?.metadata?.selectedSearchScopes ?? []; + const defaultDisabledMcpServerIds = lastUserMessage?.metadata?.disabledMcpServerIds ?? []; const [selectedSearchScopes, setSelectedSearchScopes] = useState(defaultSelectedSearchScopes); - + const [disabledMcpServerIds, setDisabledMcpServerIds] = useState(defaultDisabledMcpServerIds); + useEffect(() => { if (!chatState) { return; @@ -53,6 +55,7 @@ export const ChatThreadPanel = ({ try { setInputMessage(chatState.inputMessage); setSelectedSearchScopes(chatState.selectedSearchScopes); + setDisabledMcpServerIds(chatState.disabledMcpServerIds); } catch { console.error('Invalid chat state in session storage'); } finally { @@ -72,6 +75,8 @@ export const ChatThreadPanel = ({ searchContexts={searchContexts} selectedSearchScopes={selectedSearchScopes} onSelectedSearchScopesChange={setSelectedSearchScopes} + disabledMcpServerIds={disabledMcpServerIds} + onDisabledMcpServerIdsChange={setDisabledMcpServerIds} isOwner={isOwner} isAuthenticated={isAuthenticated} chatName={chatName} diff --git a/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx b/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx index 9d6b92381..99c2a5fb7 100644 --- a/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx +++ b/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx @@ -8,7 +8,7 @@ import { useCreateNewChatThread } from "@/features/chat/useCreateNewChatThread"; import { RepositoryQuery, SearchContextQuery } from "@/lib/types"; import { useState } from "react"; import { useLocalStorage } from "usehooks-ts"; -import { SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY } from "@/features/chat/constants"; +import { DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY, SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY } from "@/features/chat/constants"; import { SearchModeSelector } from "../../components/searchModeSelector"; import { NotConfiguredErrorBanner } from "@/features/chat/components/notConfiguredErrorBanner"; import { LoginModal } from "@/app/components/loginModal"; @@ -28,6 +28,7 @@ export const LandingPageChatBox = ({ }: LandingPageChatBox) => { const { createNewChatThread, isLoading, loginWall } = useCreateNewChatThread({ isAuthenticated }); const [selectedSearchScopes, setSelectedSearchScopes] = useLocalStorage(SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY, [], { initializeWithValue: false }); + const [disabledMcpServerIds, setDisabledMcpServerIds] = useLocalStorage(DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY, [], { initializeWithValue: false }); const [isContextSelectorOpen, setIsContextSelectorOpen] = useState(false); const isChatBoxDisabled = languageModels.length === 0; @@ -36,7 +37,7 @@ export const LandingPageChatBox = ({
{ - createNewChatThread(children, selectedSearchScopes); + createNewChatThread(children, selectedSearchScopes, disabledMcpServerIds); }} className="min-h-[50px]" isRedirecting={isLoading} @@ -56,6 +57,8 @@ export const LandingPageChatBox = ({ onSelectedSearchScopesChange={setSelectedSearchScopes} isContextSelectorOpen={isContextSelectorOpen} onContextSelectorOpenChanged={setIsContextSelectorOpen} + disabledMcpServerIds={disabledMcpServerIds} + onDisabledMcpServerIdsChange={setDisabledMcpServerIds} /> { + if (didHandleStatusRef.current) { + return; + } + + const status = searchParams.get('status'); + if (status !== 'connected' && status !== 'error') { + return; + } + + didHandleStatusRef.current = true; + const server = searchParams.get('server'); + const message = searchParams.get('message'); + + if (status === 'connected') { + toast({ description: `Successfully connected${server ? ` to ${server}` : ''}.` }); + } else { + toast({ + title: "Connection failed", + description: message ?? 'Failed to connect MCP server.', + variant: "destructive", + }); + } + + const nextSearchParams = new URLSearchParams(searchParams.toString()); + nextSearchParams.delete('status'); + nextSearchParams.delete('server'); + nextSearchParams.delete('message'); + + const query = nextSearchParams.toString(); + router.replace(`${pathname}${query ? `?${query}` : ''}`, { scroll: false }); + }, [pathname, router, searchParams, toast]); + + return null; +} diff --git a/packages/web/src/app/(app)/chat/layout.tsx b/packages/web/src/app/(app)/chat/layout.tsx index 6f2094209..b4bdcdda5 100644 --- a/packages/web/src/app/(app)/chat/layout.tsx +++ b/packages/web/src/app/(app)/chat/layout.tsx @@ -1,6 +1,8 @@ import { AGENTIC_SEARCH_TUTORIAL_DISMISSED_COOKIE_NAME } from '@/lib/constants'; import { NavigationGuardProvider } from 'next-navigation-guard'; import { cookies } from 'next/headers'; +import { Suspense } from 'react'; +import { McpOAuthStatusToast } from './components/mcpOAuthStatusToast'; import { TutorialDialog } from './components/tutorialDialog'; interface LayoutProps { @@ -14,8 +16,11 @@ export default async function Layout({ children }: LayoutProps) { // @note: we use a navigation guard here since we don't support resuming streams yet. // @see: https://ai-sdk.dev/docs/ai-sdk-ui/chatbot-message-persistence#resuming-ongoing-streams + + + {children} ) -} \ No newline at end of file +} diff --git a/packages/web/src/app/(app)/settings/layout.tsx b/packages/web/src/app/(app)/settings/layout.tsx index a03ef575a..489063843 100644 --- a/packages/web/src/app/(app)/settings/layout.tsx +++ b/packages/web/src/app/(app)/settings/layout.tsx @@ -44,7 +44,7 @@ export default async function SettingsLayout( } export const getSidebarNavGroups = async () => - withAuth(async ({ role }) => { + withAuth(async ({ org, role, prisma }) => { let numJoinRequests: number | undefined; if (role === OrgRole.OWNER) { const requests = await getOrgAccountRequests(); @@ -58,6 +58,12 @@ export const getSidebarNavGroups = async () => if (isServiceError(connectionStats)) { throw new ServiceErrorException(connectionStats); } + const hasOAuthEntitlement = await hasEntitlement("oauth"); + const hasApprovedMcpServers = role === OrgRole.OWNER && !hasOAuthEntitlement + ? await prisma.mcpServer.count({ + where: { orgId: org.id }, + }) > 0 + : false; const groups: NavGroup[] = [ { @@ -82,6 +88,12 @@ export const getSidebarNavGroups = async () => icon: "link" as const, } ] : []), + ...(hasOAuthEntitlement ? [ + { + title: "MCP Servers", + href: `/settings/mcpServers`, + } + ] : []), ], }, ]; @@ -113,6 +125,13 @@ export const getSidebarNavGroups = async () => href: `/settings/analytics`, icon: "chart-area" as const, }, + ...(hasOAuthEntitlement || hasApprovedMcpServers ? [ + { + title: "MCP Configuration", + href: `/settings/mcpConfiguration`, + icon: "server" as const, + } + ] : []), { title: "License", href: `/settings/license`, @@ -123,4 +142,4 @@ export const getSidebarNavGroups = async () => } return groups.filter(g => g.items.length > 0); - }); \ No newline at end of file + }); diff --git a/packages/web/src/app/(app)/settings/mcpConfiguration/mcpConfigurationPage.tsx b/packages/web/src/app/(app)/settings/mcpConfiguration/mcpConfigurationPage.tsx new file mode 100644 index 000000000..dfd50c929 --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpConfiguration/mcpConfigurationPage.tsx @@ -0,0 +1,441 @@ +'use client'; + +import { useState } from "react"; +import { getMcpConfiguration } from "@/app/api/(client)/client"; +import { useToast } from "@/components/hooks/use-toast"; +import { + AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, + AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { + Dialog, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle, +} from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Skeleton } from "@/components/ui/skeleton"; +import { checkMcpServerDynamicClientRegistration, createMcpServer, createStaticOAuthMcpServer, deleteMcpServer } from "@/ee/features/mcp/actions"; +import { McpFavicon } from "@/ee/features/mcp/components/mcpFavicon"; +import { invalidateMcpConfigurationQueries, mcpQueryKeys } from "@/ee/features/mcp/queryKeys"; +import { isServiceError } from "@/lib/utils"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { AlertTriangleIcon, Loader2, MinusIcon, PlusIcon, ServerIcon } from "lucide-react"; +import { PrefabMcpServerPopover } from "./prefabMcpServerPopover"; +import type { PrefabMcpServer } from "@/ee/features/mcp/prefabMcpServers"; + +function pluralize(count: number, singular: string, plural = `${singular}s`) { + return count === 1 ? singular : plural; +} + +export function McpConfigurationPage() { + const { toast } = useToast(); + const queryClient = useQueryClient(); + const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); + const [newServerName, setNewServerName] = useState(""); + const [newServerUrl, setNewServerUrl] = useState(""); + const [isClientCredentialsDialogOpen, setIsClientCredentialsDialogOpen] = useState(false); + const [pendingClientCredentialsServer, setPendingClientCredentialsServer] = useState<{ name: string; serverUrl: string } | null>(null); + const [clientId, setClientId] = useState(""); + const [clientSecret, setClientSecret] = useState(""); + const [isCreating, setIsCreating] = useState(false); + const [deletingServerId, setDeletingServerId] = useState(null); + + const { data, isLoading, isError } = useQuery({ + queryKey: mcpQueryKeys.configuration, + queryFn: async () => { + const result = await getMcpConfiguration(); + if (isServiceError(result)) { + throw new Error(result.message); + } + return result; + }, + }); + + const servers = data?.servers ?? []; + const totalSavedConnectionCount = data?.totalSavedConnectionCount ?? 0; + const canCreateMcpServers = data?.isOAuthAvailable === true; + const isOAuthUnavailable = data?.isOAuthAvailable === false; + + const handleCreateDialogOpenChange = (open: boolean) => { + setIsCreateDialogOpen(open); + + if (!open) { + setNewServerName(""); + setNewServerUrl(""); + } + }; + + const handleCloseCreateDialog = () => { + handleCreateDialogOpenChange(false); + }; + + const handleCloseClientCredentialsDialog = () => { + setIsClientCredentialsDialogOpen(false); + setPendingClientCredentialsServer(null); + setClientId(""); + setClientSecret(""); + }; + + const handleOpenCustomUrlDialog = () => { + setNewServerName(""); + setNewServerUrl(""); + setIsCreateDialogOpen(true); + }; + + const handleCreateStaticOAuthServer = async () => { + if (!pendingClientCredentialsServer) { + toast({ title: "Error", description: "Missing MCP server details", variant: "destructive" }); + return; + } + + if (process.env.NODE_ENV === "production" && window.location.protocol !== "https:") { + toast({ + title: "HTTPS required", + description: "Static OAuth client credentials can only be submitted over HTTPS in production.", + variant: "destructive", + }); + return; + } + + setIsCreating(true); + try { + const result = await createStaticOAuthMcpServer({ + name: pendingClientCredentialsServer.name, + serverUrl: pendingClientCredentialsServer.serverUrl, + clientId, + clientSecret, + }); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to add MCP server: ${result.message}`, variant: "destructive" }); + return; + } + + await invalidateMcpConfigurationQueries(queryClient); + handleCloseClientCredentialsDialog(); + } catch { + toast({ title: "Error", description: "Failed to add MCP server.", variant: "destructive" }); + } finally { + setIsCreating(false); + } + }; + + const handleCreateServer = async ( + name: string, + serverUrl: string, + onSuccess?: () => void, + options: { checkDynamicClientRegistration?: boolean } = {}, + ) => { + const displayName = name.trim(); + const normalizedServerUrl = serverUrl.trim(); + + if (!displayName || !normalizedServerUrl) { + toast({ title: "Error", description: "Name and server URL are required", variant: "destructive" }); + return; + } + + setIsCreating(true); + try { + if (options.checkDynamicClientRegistration) { + const dcrSupport = await checkMcpServerDynamicClientRegistration(normalizedServerUrl); + if (isServiceError(dcrSupport)) { + toast({ title: "Error", description: `Failed to check MCP server: ${dcrSupport.message}`, variant: "destructive" }); + return; + } + + if (dcrSupport.isKnown && !dcrSupport.supportsDcr) { + setPendingClientCredentialsServer({ name: displayName, serverUrl: normalizedServerUrl }); + setIsCreateDialogOpen(false); + setIsClientCredentialsDialogOpen(true); + return; + } + } + + const result = await createMcpServer(displayName, normalizedServerUrl); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to add MCP server: ${result.message}`, variant: "destructive" }); + return; + } + + await invalidateMcpConfigurationQueries(queryClient); + onSuccess?.(); + } catch (error) { + toast({ title: "Error", description: `Failed to add MCP server: ${error}`, variant: "destructive" }); + } finally { + setIsCreating(false); + } + }; + + const handleCreate = async () => { + await handleCreateServer(newServerName, newServerUrl, handleCloseCreateDialog, { + checkDynamicClientRegistration: true, + }); + }; + + const handleCreatePrefabServer = async (server: PrefabMcpServer) => { + await handleCreateServer(server.name, server.serverUrl, undefined, { + checkDynamicClientRegistration: true, + }); + }; + + const handleDelete = async (serverId: string) => { + setDeletingServerId(serverId); + try { + const result = await deleteMcpServer(serverId); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to delete MCP server: ${result.message}`, variant: "destructive" }); + return; + } + + await invalidateMcpConfigurationQueries(queryClient); + } catch (error) { + toast({ title: "Error", description: `Failed to delete MCP server: ${error}`, variant: "destructive" }); + } finally { + setDeletingServerId(null); + } + }; + + if (isError) { + return
Error loading MCP configuration
; + } + + return ( +
+
+

MCP Configuration

+

+ Configure the MCP servers that workspace members can connect to. +

+
+ + {!isLoading && isOAuthUnavailable && ( + + + +
+

OAuth MCP is unavailable

+

+ You can remove existing approved servers and stored credentials, but cannot add new MCP servers. +

+
+
+
+ )} + + + +
+
+

Saved MCP connections

+

+ Current workspace members with saved MCP server credentials. +

+
+ {isLoading ? ( + + ) : ( +

+ {totalSavedConnectionCount} {pluralize(totalSavedConnectionCount, "connection")} +

+ )} +
+
+
+ + + +
+ {isLoading ? "Allowed servers" : `${servers.length} allowed ${pluralize(servers.length, "server")}`} + + {isOAuthUnavailable + ? "Remove existing server approvals and their stored credentials." + : "Approve server URLs that workspace members can connect to."} + +
+ {canCreateMcpServers ? ( + <> + server.serverUrl)} + disabled={isCreating} + onSelectCustomUrl={handleOpenCustomUrlDialog} + onSelectPrefabServer={handleCreatePrefabServer} + /> + + + + Add MCP Server + + Add a workspace-approved MCP server that members can connect to from Ask Sourcebot. + + +
+
+ + setNewServerName(event.target.value)} + placeholder="e.g. Linear" + /> +
+
+ + setNewServerUrl(event.target.value)} + placeholder="https://mcp.linear.app/mcp" + /> +
+
+ + + + +
+
+ { + if (!open) { + handleCloseClientCredentialsDialog(); + return; + } + + setIsClientCredentialsDialogOpen(true); + }}> + + + OAuth Client Credentials Required + + This MCP server does not advertise dynamic client registration. Provide OAuth client credentials from a pre-registered app before members can connect to it. + + +
+ {pendingClientCredentialsServer && ( +
+

{pendingClientCredentialsServer.name}

+

{pendingClientCredentialsServer.serverUrl}

+
+ )} +
+ + setClientId(event.target.value)} + placeholder="OAuth client ID" + /> +
+
+ + setClientSecret(event.target.value)} + placeholder="OAuth client secret" + /> +
+
+ + + + +
+
+ + ) : ( + + )} +
+ + {isLoading ? ( +
+ {Array.from({ length: 3 }).map((_, index) => ( +
+ +
+ + +
+ +
+ ))} +
+ ) : servers.length === 0 ? ( +
+
+ +
+

No MCP servers configured yet

+

+ {isOAuthUnavailable + ? "OAuth MCP is unavailable on this Sourcebot instance." + : "Add a workspace-approved MCP server so members can connect it to Ask Sourcebot."} +

+
+ ) : ( +
+ {servers.map((server) => ( +
+
+ +
+
+

{server.name || server.serverUrl}

+

{server.serverUrl}

+
+

+ {server.savedConnectionCount} {pluralize(server.savedConnectionCount, "saved connection")} +

+ + + + + + + Delete MCP Server + + Are you sure you want to remove {server.name || server.serverUrl}? Workspace members will lose access and stored credentials for this server. + + + + Cancel + handleDelete(server.id)} + disabled={deletingServerId === server.id} + className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + {deletingServerId === server.id ? "Deleting..." : "Delete"} + + + + +
+ ))} +
+ )} +
+
+
+ ); +} diff --git a/packages/web/src/app/(app)/settings/mcpConfiguration/mcpConfigurationUnavailableMessage.tsx b/packages/web/src/app/(app)/settings/mcpConfiguration/mcpConfigurationUnavailableMessage.tsx new file mode 100644 index 000000000..6ef7ded41 --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpConfiguration/mcpConfigurationUnavailableMessage.tsx @@ -0,0 +1,29 @@ +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { ServerIcon } from "lucide-react"; + +export function McpConfigurationUnavailableMessage() { + return ( +
+ + +
+
+ +
+
+ + MCP Configuration Is Unavailable + + + OAuth-backed MCP servers are not supported on this Sourcebot instance. + +
+ +

+ Use Sourcebot API keys for MCP access on this deployment. +

+
+
+
+ ); +} diff --git a/packages/web/src/app/(app)/settings/mcpConfiguration/page.test.tsx b/packages/web/src/app/(app)/settings/mcpConfiguration/page.test.tsx new file mode 100644 index 000000000..f349a072a --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpConfiguration/page.test.tsx @@ -0,0 +1,66 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'; +import { cleanup, render, screen } from '@testing-library/react'; +import type React from 'react'; + +const mocks = vi.hoisted(() => ({ + authContext: { + org: { id: 1 }, + prisma: { + mcpServer: { + count: vi.fn(), + }, + }, + }, + hasEntitlement: vi.fn(), +})); + +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/middleware/authenticatedPage', () => ({ + authenticatedPage: vi.fn((page: (auth: typeof mocks.authContext) => Promise) => () => page(mocks.authContext)), +})); +vi.mock('./mcpConfigurationPage', () => ({ + McpConfigurationPage: () =>
MCP configuration client
, +})); + +const { default: Page } = await import('./page'); + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.authContext.prisma.mcpServer.count.mockResolvedValue(0); +}); + +afterEach(() => { + cleanup(); +}); + +describe('MCP configuration settings page', () => { + test('renders the client configuration page when OAuth is available', async () => { + render(await Page({})); + + expect(screen.getByText('MCP configuration client')).toBeTruthy(); + }); + + test('renders the client configuration page when OAuth is unavailable but servers exist for cleanup', async () => { + mocks.hasEntitlement.mockResolvedValue(false); + mocks.authContext.prisma.mcpServer.count.mockResolvedValue(1); + + render(await Page({})); + + expect(screen.getByText('MCP configuration client')).toBeTruthy(); + expect(mocks.authContext.prisma.mcpServer.count).toHaveBeenCalledWith({ + where: { orgId: 1 }, + }); + }); + + test('renders an unavailable message when OAuth is not available and no cleanup is needed', async () => { + mocks.hasEntitlement.mockResolvedValue(false); + + render(await Page({})); + + expect(screen.getByText('MCP Configuration Is Unavailable')).toBeTruthy(); + expect(screen.queryByText('MCP configuration client')).toBeNull(); + }); +}); diff --git a/packages/web/src/app/(app)/settings/mcpConfiguration/page.tsx b/packages/web/src/app/(app)/settings/mcpConfiguration/page.tsx new file mode 100644 index 000000000..c6c1015f5 --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpConfiguration/page.tsx @@ -0,0 +1,19 @@ +import { hasEntitlement } from "@/lib/entitlements"; +import { authenticatedPage } from "@/middleware/authenticatedPage"; +import { OrgRole } from "@sourcebot/db"; +import { McpConfigurationPage } from "./mcpConfigurationPage"; +import { McpConfigurationUnavailableMessage } from "./mcpConfigurationUnavailableMessage"; + +export default authenticatedPage(async ({ org, prisma }) => { + if (!(await hasEntitlement("oauth"))) { + const serverCount = await prisma.mcpServer.count({ + where: { orgId: org.id }, + }); + + if (serverCount === 0) { + return ; + } + } + + return ; +}, { minRole: OrgRole.OWNER, redirectTo: '/settings' }); diff --git a/packages/web/src/app/(app)/settings/mcpConfiguration/prefabMcpServerPopover.tsx b/packages/web/src/app/(app)/settings/mcpConfiguration/prefabMcpServerPopover.tsx new file mode 100644 index 000000000..f09ba07c9 --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpConfiguration/prefabMcpServerPopover.tsx @@ -0,0 +1,133 @@ +'use client'; + +import { useMemo, useState } from "react"; +import { + Command, + CommandGroup, + CommandInput, + CommandItem, + CommandList, + CommandSeparator, +} from "@/components/ui/command"; +import { Button } from "@/components/ui/button"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { McpFavicon } from "@/ee/features/mcp/components/mcpFavicon"; +import { + getAvailablePrefabMcpServers, + type PrefabMcpServer, +} from "@/ee/features/mcp/prefabMcpServers"; +import { getMcpFaviconUrl } from "@/ee/features/mcp/utils"; +import { PlusIcon } from "lucide-react"; + +interface PrefabMcpServerPopoverProps { + configuredServerUrls: string[]; + disabled?: boolean; + onSelectCustomUrl: () => void; + onSelectPrefabServer: (server: PrefabMcpServer) => void; +} + +function getDisplayServerUrl(serverUrl: string) { + try { + const url = new URL(serverUrl); + return `${url.host}${url.pathname}${url.search}`.replace(/\/$/, ""); + } catch { + return serverUrl; + } +} + +export function PrefabMcpServerPopover({ + configuredServerUrls, + disabled, + onSelectCustomUrl, + onSelectPrefabServer, +}: PrefabMcpServerPopoverProps) { + const [isOpen, setIsOpen] = useState(false); + const [search, setSearch] = useState(""); + + const availablePrefabServers = useMemo(() => ( + getAvailablePrefabMcpServers(configuredServerUrls) + ), [configuredServerUrls]); + + const filteredPrefabServers = useMemo(() => { + const normalizedSearch = search.trim().toLowerCase(); + + if (!normalizedSearch) { + return availablePrefabServers; + } + + return availablePrefabServers.filter((server) => server.name.toLowerCase().includes(normalizedSearch)); + }, [availablePrefabServers, search]); + + const handleOpenChange = (open: boolean) => { + setIsOpen(open); + + if (!open) { + setSearch(""); + } + }; + + const handleSelectPrefabServer = (server: PrefabMcpServer) => { + handleOpenChange(false); + onSelectPrefabServer(server); + }; + + const handleSelectCustomUrl = () => { + handleOpenChange(false); + onSelectCustomUrl(); + }; + + return ( + + + + + + + + + + {filteredPrefabServers.map((server) => ( + handleSelectPrefabServer(server)} + className="cursor-pointer" + > +
+ +
+
+

{server.name}

+

{getDisplayServerUrl(server.serverUrl)}

+
+
+ ))} + {search.trim() && filteredPrefabServers.length === 0 && ( +
+ No servers found. +
+ )} +
+ + + + + Custom URL... + + +
+
+
+
+ ); +} diff --git a/packages/web/src/app/(app)/settings/mcpServers/mcpServersPage.test.tsx b/packages/web/src/app/(app)/settings/mcpServers/mcpServersPage.test.tsx new file mode 100644 index 000000000..6f9221d80 --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpServers/mcpServersPage.test.tsx @@ -0,0 +1,25 @@ +import { afterEach, describe, expect, test } from 'vitest'; +import { cleanup, render, screen } from '@testing-library/react'; +import { McpServersEmptyState } from './mcpServersPage'; + +afterEach(() => { + cleanup(); +}); + +describe('McpServersEmptyState', () => { + test('points owners to workspace MCP configuration', () => { + render(); + + expect(screen.getByText('No MCP servers configured yet')).toBeTruthy(); + expect(screen.getByText(/Go to Workspace MCP Configuration/)).toBeTruthy(); + expect(screen.getByRole('link', { name: /Open MCP Configuration/ }).getAttribute('href')).toBe('/settings/mcpConfiguration'); + }); + + test('tells members to contact an admin', () => { + render(); + + expect(screen.getByText('No MCP servers available')).toBeTruthy(); + expect(screen.getByText(/Contact your workspace admin/)).toBeTruthy(); + expect(screen.queryByRole('link', { name: /Open MCP Configuration/ })).toBeNull(); + }); +}); diff --git a/packages/web/src/app/(app)/settings/mcpServers/mcpServersPage.tsx b/packages/web/src/app/(app)/settings/mcpServers/mcpServersPage.tsx new file mode 100644 index 000000000..d06239bfb --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpServers/mcpServersPage.tsx @@ -0,0 +1,424 @@ +'use client'; + +import { useEffect, useMemo, useRef, useState } from "react"; +import Link from "next/link"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { ExternalLink, MoreHorizontal, SearchIcon, ServerIcon, Settings2Icon, Unplug } from "lucide-react"; +import { getMcpServersWithStatus } from "@/app/api/(client)/client"; +import { useToast } from "@/components/hooks/use-toast"; +import { + AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, + AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { + DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Input } from "@/components/ui/input"; +import { Skeleton } from "@/components/ui/skeleton"; +import { ConnectMcpButton } from "@/ee/features/mcp/components/connectMcpButton"; +import { McpFavicon } from "@/ee/features/mcp/components/mcpFavicon"; +import { useConnectMcp } from "@/ee/features/mcp/hooks/useConnectMcp"; +import { disconnectMcpServer } from "@/ee/features/mcp/actions"; +import { invalidateMcpConfigurationQueries, mcpQueryKeys } from "@/ee/features/mcp/queryKeys"; +import { cn, isServiceError } from "@/lib/utils"; + +type FilterTab = "all" | "connected"; + +function displayUrl(url: string) { + return url.replace(/^https?:\/\//, ""); +} + +function pluralize(count: number, singular: string, plural = `${singular}s`) { + return count === 1 ? singular : plural; +} + +function clearCallbackParams() { + const url = new URL(window.location.href); + url.searchParams.delete('status'); + url.searchParams.delete('server'); + url.searchParams.delete('message'); + window.history.replaceState({}, '', url.toString()); +} + +interface McpServersPageProps { + callbackStatus?: string; + callbackServer?: string; + callbackMessage?: string; + canManageMcpServers: boolean; +} + +export function McpServersEmptyState({ canManageMcpServers }: { canManageMcpServers: boolean }) { + return ( + + +
+ +
+

+ {canManageMcpServers ? "No MCP servers configured yet" : "No MCP servers available"} +

+

+ {canManageMcpServers + ? "Go to Workspace MCP Configuration to add servers before connecting them to Ask Sourcebot." + : "No MCP servers have been approved for this workspace yet. Contact your workspace admin."} +

+ {canManageMcpServers && ( + + )} +
+
+ ); +} + +export function McpServersPage({ callbackStatus, callbackServer, callbackMessage, canManageMcpServers }: McpServersPageProps) { + const { toast } = useToast(); + const queryClient = useQueryClient(); + const didHandleCallbackRef = useRef(false); + const [searchQuery, setSearchQuery] = useState(""); + const [activeTab, setActiveTab] = useState("all"); + const [disconnectingServerId, setDisconnectingServerId] = useState(null); + const [confirmDisconnectServer, setConfirmDisconnectServer] = useState<{ id: string; name: string } | null>(null); + const { connect: reconnectMcp } = useConnectMcp(); + + useEffect(() => { + if (didHandleCallbackRef.current) { + return; + } + if (callbackStatus === 'connected') { + didHandleCallbackRef.current = true; + toast({ description: `Successfully connected${callbackServer ? ` to ${callbackServer}` : ''}.` }); + clearCallbackParams(); + } else if (callbackStatus === 'error') { + didHandleCallbackRef.current = true; + toast({ title: "Connection failed", description: callbackMessage ?? 'Failed to connect MCP server.', variant: "destructive" }); + clearCallbackParams(); + } + }, [callbackStatus, callbackServer, callbackMessage, toast]); + + const { data: servers = [], isLoading, isError } = useQuery({ + queryKey: mcpQueryKeys.serversWithStatus, + queryFn: async () => { + const result = await getMcpServersWithStatus(); + if (isServiceError(result)) { + throw new Error("Failed to load MCP servers"); + } + return result; + }, + }); + + const connectedServers = useMemo( + () => servers.filter((s) => s.isConnected || s.isAuthExpired), + [servers], + ); + + const suggestedServers = useMemo( + () => servers.filter((s) => !s.isConnected && !s.isAuthExpired), + [servers], + ); + + const filteredConnected = useMemo(() => { + const list = connectedServers; + if (!searchQuery.trim()) { + return list; + } + const q = searchQuery.toLowerCase(); + return list.filter( + (s) => (s.name?.toLowerCase().includes(q)) || s.serverUrl.toLowerCase().includes(q), + ); + }, [connectedServers, searchQuery]); + + const filteredSuggested = useMemo(() => { + const list = suggestedServers; + if (!searchQuery.trim()) { + return list; + } + const q = searchQuery.toLowerCase(); + return list.filter( + (s) => (s.name?.toLowerCase().includes(q)) || s.serverUrl.toLowerCase().includes(q), + ); + }, [suggestedServers, searchQuery]); + + const visibleConnected = filteredConnected; + const visibleSuggested = activeTab === "all" ? filteredSuggested : []; + + const handleDisconnect = async (serverId: string) => { + setDisconnectingServerId(serverId); + setConfirmDisconnectServer(null); + try { + const result = await disconnectMcpServer(serverId); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to disconnect: ${result.message}`, variant: "destructive" }); + return; + } + toast({ description: "MCP server disconnected." }); + await invalidateMcpConfigurationQueries(queryClient); + } catch { + toast({ title: "Error", description: "Failed to disconnect MCP server.", variant: "destructive" }); + } finally { + setDisconnectingServerId(null); + } + }; + + if (isError) { + return
Error loading MCP servers
; + } + + if (!isLoading && servers.length === 0) { + return ( +
+
+

MCP Servers

+

+ Connect to workspace-approved MCP servers to use them with Ask Sourcebot. +

+
+ +
+ ); + } + + return ( +
+
+

MCP Servers

+

+ Connect to workspace-approved MCP servers to use them with Ask Sourcebot. +

+
+ + {/* Search + filter bar */} +
+
+ + setSearchQuery(e.target.value)} + className="pl-9" + /> +
+
+ + +
+
+ + {isLoading ? ( +
+ {Array.from({ length: 3 }).map((_, index) => ( + + + +
+ + +
+ +
+
+ ))} +
+ ) : ( + <> + {/* Connected section */} +
+
+

+ Connected +

+

+ {connectedServers.length} {pluralize(connectedServers.length, "server")} +

+
+ + {visibleConnected.length === 0 ? ( + + +

+ {searchQuery.trim() + ? "No connected servers match your search." + : "No servers connected yet."} +

+
+
+ ) : ( + visibleConnected.map((server) => ( + + +
+ +
+
+

+ {server.name || server.serverUrl} +

+

+ {displayUrl(server.serverUrl)} +

+
+ {server.isConnected && ( + <> + + Connected + + )} + {server.isAuthExpired && ( + <> + + Authorization expired + + )} +
+
+
+ + + + + + + reconnectMcp(server.id)}> + + Reconnect + + setConfirmDisconnectServer({ + id: server.id, + name: server.name || server.serverUrl, + })} + > + + {disconnectingServerId === server.id ? "Disconnecting..." : "Disconnect"} + + + +
+
+
+ )) + )} +
+ + {/* Suggested section */} + {activeTab === "all" && ( +
+
+

+ Suggested +

+

+ workspace-approved +

+
+ + {visibleSuggested.length === 0 ? ( + + +

+ {searchQuery.trim() + ? "No suggested servers match your search." + : "All servers are connected."} +

+
+
+ ) : ( + visibleSuggested.map((server) => ( + + +
+ +
+
+

+ {server.name || server.serverUrl} +

+

+ {displayUrl(server.serverUrl)} +

+
+ +
+
+ )) + )} +
+ )} + + )} + + {/* Disconnect confirmation dialog */} + { + if (!open) { + setConfirmDisconnectServer(null); + } + }} + > + + + Disconnect MCP Server + + Are you sure you want to disconnect from {confirmDisconnectServer?.name}? Your stored credentials for this server will be removed. + + + + Cancel + { + if (confirmDisconnectServer) { + handleDisconnect(confirmDisconnectServer.id); + } + }} + className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + Disconnect + + + + +
+ ); +} diff --git a/packages/web/src/app/(app)/settings/mcpServers/page.tsx b/packages/web/src/app/(app)/settings/mcpServers/page.tsx new file mode 100644 index 000000000..7c6c43ad1 --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpServers/page.tsx @@ -0,0 +1,23 @@ +import { McpServersPage } from "./mcpServersPage"; +import { authenticatedPage } from "@/middleware/authenticatedPage"; +import { OrgRole } from "@sourcebot/db"; + +interface PageProps extends Record { + searchParams: Promise<{ + status?: string; + server?: string; + message?: string; + }>; +} + +export default authenticatedPage(async ({ role }, { searchParams }) => { + const { status, server, message } = await searchParams; + return ( + + ); +}); diff --git a/packages/web/src/app/api/(client)/client.ts b/packages/web/src/app/api/(client)/client.ts index 22c689278..1b6b8573c 100644 --- a/packages/web/src/app/api/(client)/client.ts +++ b/packages/web/src/app/api/(client)/client.ts @@ -29,6 +29,9 @@ import type { SearchChatShareableMembersQueryParams, SearchChatShareableMembersResponse, } from "../(server)/ee/chat/[chatId]/searchMembers/route"; +import { ConnectMcpResponse } from "../(server)/ee/askmcp/connect/types"; +import type { GetMcpServersResponse } from "../(server)/ee/askmcp/servers/route"; +import type { GetMcpConfigurationResponse } from "@/ee/features/mcp/types"; export const search = async (body: SearchRequest): Promise => { const result = await fetch("/api/search", { @@ -214,4 +217,44 @@ export const listChats = async (queryParams: ListChatsQueryParams): Promise response.json()); return result as ListChatsResponse | ServiceError; -} \ No newline at end of file +} + +export const connectMcpToAsk = async (body: { serverId: string; returnTo?: string }): Promise => { + const result = await fetch('/api/ee/askmcp/connect', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + body: JSON.stringify(body), + }).then(response => response.json()); + + if (isServiceError(result)) { + return result; + } + + return result as ConnectMcpResponse; +} + +export const getMcpServersWithStatus = async (): Promise => { + const result = await fetch('/api/ee/askmcp/servers', { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + }).then(response => response.json()); + + return result as GetMcpServersResponse | ServiceError; +} + +export const getMcpConfiguration = async (): Promise => { + const result = await fetch('/api/ee/askmcp/configuration', { + method: 'GET', + headers: { + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + }).then(response => response.json()); + + return result as GetMcpConfigurationResponse | ServiceError; +} diff --git a/packages/web/src/app/api/(server)/chat/route.ts b/packages/web/src/app/api/(server)/chat/route.ts index 4c0b12819..5953cbe0d 100644 --- a/packages/web/src/app/api/(server)/chat/route.ts +++ b/packages/web/src/app/api/(server)/chat/route.ts @@ -33,7 +33,7 @@ export const POST = apiHandler(async (req: NextRequest) => { return serviceErrorResponse(requestBodySchemaValidationError(parsed.error)); } - const { messages, id, selectedSearchScopes, languageModel: _languageModel } = parsed.data; + const { messages, id, selectedSearchScopes, disabledMcpServerIds, languageModel: _languageModel } = parsed.data; // @note: a bit of type massaging is required here since the // zod schema does not enum on `model` or `provider`. // @see: chat/types.ts @@ -108,10 +108,14 @@ export const POST = apiHandler(async (req: NextRequest) => { selectedSearchScopes, }, selectedRepos: expandedRepos, + prisma, + disabledMcpServerIds, model, modelName: languageModelConfig.displayName ?? languageModelConfig.model, modelProviderOptions: providerOptions, modelTemperature: temperature, + userId: user?.id, + orgId: org.id, onFinish: async ({ messages }) => { await updateChatMessages({ chatId: id, messages, prisma }); }, diff --git a/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts new file mode 100644 index 000000000..31ce476fd --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts @@ -0,0 +1,186 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; + +const mocks = vi.hoisted(() => ({ + auth: vi.fn(), + hasEntitlement: vi.fn(), + logger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, + mcpAuth: vi.fn(), + unsafePrisma: { + mcpServer: { + updateMany: vi.fn(), + }, + userMcpServer: { + findFirst: vi.fn(), + update: vi.fn(), + updateMany: vi.fn(), + }, + userToOrg: { + findUnique: vi.fn(), + }, + }, +})); + +vi.mock('server-only', () => ({})); +vi.mock('@/lib/posthog', () => ({ + captureEvent: vi.fn(), +})); +vi.mock('@/auth', () => ({ + auth: mocks.auth, +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/prisma', () => ({ + prisma: mocks.unsafePrisma, + __unsafePrisma: mocks.unsafePrisma, +})); +vi.mock('@sourcebot/shared', () => ({ + env: { + AUTH_URL: 'https://sourcebot.example.com', + }, + createLogger: () => mocks.logger, + encryptOAuthToken: vi.fn((text: string | null | undefined) => text ? `encrypted:${text}` : undefined), + decryptOAuthToken: vi.fn((text: string | null | undefined) => text?.startsWith('encrypted:') ? text.slice('encrypted:'.length) : text), +})); +vi.mock('@ai-sdk/mcp', () => ({ + auth: mocks.mcpAuth, +})); + +const { GET } = await import('./route'); +const { createMcpOAuthState } = await import('@/features/mcp/mcpOAuthReturnTo'); + +function createRequest(state = 'state-1') { + return new NextRequest(`https://sourcebot.example.com/api/ee/askmcp/callback?code=code-1&state=${encodeURIComponent(state)}`, { + method: 'GET', + }); +} + +function createOAuthErrorRequest(state: string) { + return new NextRequest(`https://sourcebot.example.com/api/ee/askmcp/callback?error=access_denied&error_description=Denied&state=${encodeURIComponent(state)}`, { + method: 'GET', + }); +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.auth.mockResolvedValue({ user: { id: 'user-1' } }); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.unsafePrisma.userMcpServer.findFirst.mockResolvedValue({ + serverId: 'server-1', + server: { + orgId: 1, + name: 'Linear', + serverUrl: 'https://mcp.linear.app/mcp', + }, + }); + mocks.unsafePrisma.userMcpServer.update.mockResolvedValue({ userId: 'user-1', serverId: 'server-1' }); + mocks.unsafePrisma.userToOrg.findUnique.mockResolvedValue({ orgId: 1, userId: 'user-1' }); +}); + +describe('GET /api/ee/askmcp/callback', () => { + test('redirects successful chat-originated auth back to chat', async () => { + const state = createMcpOAuthState('state-1', '/chat'); + mocks.mcpAuth.mockResolvedValue('AUTHORIZED'); + + const response = await GET(createRequest(state)); + const location = response.headers.get('location'); + const url = new URL(location ?? ''); + + expect(url.pathname).toBe('/chat'); + expect(url.searchParams.get('status')).toBe('connected'); + expect(url.searchParams.get('server')).toBe('Linear'); + expect(mocks.unsafePrisma.userMcpServer.findFirst).toHaveBeenCalledWith({ + where: { + state, + userId: 'user-1', + }, + select: { + serverId: true, + server: { + select: { + orgId: true, + name: true, + serverUrl: true, + }, + }, + }, + }); + }); + + test('redirects denied chat-originated auth back to chat', async () => { + const state = createMcpOAuthState('state-1', '/chat'); + + const response = await GET(createOAuthErrorRequest(state)); + const url = new URL(response.headers.get('location') ?? ''); + + expect(url.pathname).toBe('/chat'); + expect(url.searchParams.get('status')).toBe('error'); + expect(url.searchParams.get('message')).toBe('Denied'); + expect(mocks.mcpAuth).not.toHaveBeenCalled(); + }); + + test('redirects with a friendly reconnect error when callback auth cannot complete', async () => { + mocks.mcpAuth.mockImplementation(async (provider) => { + expect('saveClientInformation' in provider).toBe(false); + await provider.invalidateCredentials('all'); + const error = new Error('invalid_client client_secret=client-secret refresh_token=refresh-token'); + Object.assign(error, { + response: { + status: 401, + body: 'client_secret=client-secret refresh_token=refresh-token', + }, + }); + throw error; + }); + + const response = await GET(createRequest()); + const location = response.headers.get('location'); + + expect(location).toBeTruthy(); + expect(location).toContain('/settings/mcpServers'); + expect(location).toContain('status=error'); + expect(new URL(location ?? '').searchParams.get('message')).toContain('Please reconnect the server'); + expect(mocks.unsafePrisma.userMcpServer.findFirst).toHaveBeenCalledWith({ + where: { + state: 'state-1', + userId: 'user-1', + }, + select: { + serverId: true, + server: { + select: { + orgId: true, + name: true, + serverUrl: true, + }, + }, + }, + }); + expect(mocks.unsafePrisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + codeVerifier: null, + state: null, + }, + }); + expect(mocks.logger.warn).toHaveBeenCalledWith('Failed to authorize MCP server.', { + serverId: 'server-1', + orgId: 1, + error: { + errorClass: 'Error', + oauthError: 'invalid_client', + statusCode: 401, + }, + }); + expect(JSON.stringify(mocks.logger.warn.mock.calls)).not.toContain('client-secret'); + expect(JSON.stringify(mocks.logger.warn.mock.calls)).not.toContain('refresh-token'); + }); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts new file mode 100644 index 000000000..30906ba32 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts @@ -0,0 +1,165 @@ +import { auth as mcpAuth } from '@ai-sdk/mcp'; +import { apiHandler } from '@/lib/apiHandler'; +import { env, createLogger } from '@sourcebot/shared'; +import { hasEntitlement } from '@/lib/entitlements'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +// Note: We use the raw (unscoped) prisma client here because this route handles OAuth +// redirect callbacks from external providers, so it can't go through withAuth. Session +// identity is verified via NextAuth's auth() instead, and all queries filter by userId. +import { __unsafePrisma as prisma } from '@/prisma'; +import { auth } from '@/auth'; +import { NextRequest, NextResponse } from 'next/server'; +import { getExternalMcpErrorLogFields } from '@/ee/features/mcp/externalMcpError'; +import { getMcpOAuthReturnToFromState } from '@/features/mcp/mcpOAuthReturnTo'; + +const logger = createLogger('mcp-oauth-callback'); +const reconnectMessage = 'This MCP server authorization could not be completed. Please reconnect the server.'; +const defaultMcpOAuthReturnTo = '/settings/mcpServers'; + +function createMcpOAuthRedirectUrl(returnTo: string | undefined): URL { + return new URL(returnTo ?? defaultMcpOAuthReturnTo, env.AUTH_URL); +} + +function setMcpOAuthStatusParams(url: URL, params: { status: 'connected' | 'error'; server?: string; message?: string }) { + url.searchParams.set('status', params.status); + if (params.server) { + url.searchParams.set('server', params.server); + } + if (params.message) { + url.searchParams.set('message', params.message); + } +} + +function redirectToCallbackError(message: string, returnTo?: string) { + const url = createMcpOAuthRedirectUrl(returnTo); + setMcpOAuthStatusParams(url, { status: 'error', message }); + return NextResponse.redirect(url); +} + +// eslint-disable-next-line authz/require-auth-wrapper -- OAuth redirect callback validates the active session with auth() and filters all queries by userId. +export const GET = apiHandler(async (request: NextRequest) => { + if (!(await hasEntitlement('oauth'))) { + return Response.json( + { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, + { status: 403 } + ); + } + + const session = await auth(); + if (!session?.user?.id) { + return Response.json( + { error: 'unauthorized', error_description: 'You must be logged in.' }, + { status: 401 } + ); + } + + const { searchParams } = request.nextUrl; + const oauthError = searchParams.get('error'); + const code = searchParams.get('code'); + const state = searchParams.get('state'); + const callbackReturnTo = getMcpOAuthReturnToFromState(state); + + // Handle OAuth errors (e.g., user cancelled the authorization flow). + if (oauthError) { + const url = createMcpOAuthRedirectUrl(callbackReturnTo); + const errorDescription = searchParams.get('error_description') ?? 'Authorization was cancelled or denied.'; + setMcpOAuthStatusParams(url, { status: 'error', message: errorDescription }); + return NextResponse.redirect(url); + } + + if (!code || !state) { + return Response.json( + { error: 'invalid_request', error_description: 'Missing required parameters: code, state.' }, + { status: 400 } + ); + } + + const userServer = await prisma.userMcpServer.findFirst({ + where: { + state, + userId: session.user.id, + }, + select: { + serverId: true, + server: { + select: { + orgId: true, + name: true, + serverUrl: true, + }, + }, + }, + }); + + if (!userServer) { + return Response.json( + { error: 'invalid_state', error_description: 'No pending authorization found for this state.' }, + { status: 400 } + ); + } + + const orgMembership = await prisma.userToOrg.findUnique({ + where: { + orgId_userId: { + orgId: userServer.server.orgId, + userId: session.user.id, + }, + }, + }); + + if (!orgMembership) { + return Response.json( + { error: 'forbidden', error_description: 'You do not have access to this MCP server.' }, + { status: 403 } + ); + } + + const provider = new PrismaOAuthClientProvider({ + prisma, + serverId: userServer.serverId, + orgId: userServer.server.orgId, + userId: session.user.id, + callbackUrl: `${env.AUTH_URL}/api/ee/askmcp/callback`, + }); + + let result: Awaited>; + + try { + result = await mcpAuth(provider, { + serverUrl: new URL(userServer.server.serverUrl), + authorizationCode: code, + callbackState: state, + }); + } catch (error) { + logger.warn('Failed to authorize MCP server.', { + serverId: userServer.serverId, + orgId: userServer.server.orgId, + error: getExternalMcpErrorLogFields(error), + }); + try { + await provider.invalidateCredentials('verifier'); + } catch (cleanupError) { + logger.warn(`Failed to clear MCP OAuth verifier for user ${session.user.id}:`, cleanupError); + } + return redirectToCallbackError(reconnectMessage, callbackReturnTo); + } + + // Always clear ephemeral PKCE/state regardless of outcome to prevent replay. + try { + await provider.invalidateCredentials('verifier'); + } catch (cleanupError) { + logger.warn(`Failed to clear MCP OAuth verifier for user ${session.user.id}:`, cleanupError); + } + + if (result === 'AUTHORIZED') { + const displayName = userServer.server.name || userServer.server.serverUrl; + logger.info(`Successfully authorized MCP server ${displayName} for user ${session.user.id}.`); + const url = createMcpOAuthRedirectUrl(callbackReturnTo); + setMcpOAuthStatusParams(url, { status: 'connected', server: displayName }); + return NextResponse.redirect(url); + } + + // If auth() didn't return AUTHORIZED, something went wrong + return redirectToCallbackError('Token exchange failed', callbackReturnTo); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts new file mode 100644 index 000000000..a89f382ef --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts @@ -0,0 +1,214 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; +import { OrgRole } from '@sourcebot/db'; +import { ErrorCode } from '@/lib/errorCodes'; + +const mocks = vi.hoisted(() => ({ + authContext: undefined as unknown, + hasEntitlement: vi.fn(), + withAuth: vi.fn(), + unsafePrisma: { + userMcpServer: { + groupBy: vi.fn(), + }, + }, +})); + +vi.mock('@/lib/posthog', () => ({ + captureEvent: vi.fn(), +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/middleware/withAuth', () => ({ + withAuth: mocks.withAuth, +})); +vi.mock('@/prisma', () => ({ + __unsafePrisma: mocks.unsafePrisma, +})); + +const { GET } = await import('./route'); + +function createRequest() { + return new NextRequest('http://localhost/api/ee/askmcp/configuration', { method: 'GET' }); +} + +function createPrismaMock() { + return { + mcpServer: { + findMany: vi.fn().mockResolvedValue([ + { + id: 'server-1', + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + }, + { + id: 'server-2', + name: 'Sentry', + sanitizedName: 'sentry', + serverUrl: 'https://mcp.sentry.dev/mcp', + }, + ]), + }, + }; +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.withAuth.mockImplementation((callback: (context: unknown) => unknown) => callback(mocks.authContext)); + mocks.unsafePrisma.userMcpServer.groupBy.mockResolvedValue([ + { + serverId: 'server-1', + _count: { _all: 2 }, + }, + ]); +}); + +describe('GET /api/ee/askmcp/configuration', () => { + test('lists approved servers with current-member saved connection counts', async () => { + const prisma = createPrismaMock(); + mocks.authContext = { + org: { id: 1 }, + role: OrgRole.OWNER, + prisma, + }; + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(prisma.mcpServer.findMany).toHaveBeenCalledWith({ + where: { orgId: 1 }, + orderBy: { createdAt: 'desc' }, + select: { + id: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }); + expect(mocks.unsafePrisma.userMcpServer.groupBy).toHaveBeenCalledWith({ + by: ['serverId'], + where: { + serverId: { in: ['server-1', 'server-2'] }, + tokens: { not: null }, + server: { orgId: 1 }, + user: { + orgs: { + some: { orgId: 1 }, + }, + }, + }, + _count: { _all: true }, + }); + expect(body).toMatchObject({ + totalSavedConnectionCount: 2, + allowedMode: 'approved_only', + isOAuthAvailable: true, + servers: [ + { + id: 'server-1', + name: 'Linear', + savedConnectionCount: 2, + }, + { + id: 'server-2', + name: 'Sentry', + savedConnectionCount: 0, + }, + ], + }); + }); + + test('rejects non-owners before the unsafe aggregate query', async () => { + const prisma = createPrismaMock(); + mocks.authContext = { + org: { id: 1 }, + role: OrgRole.MEMBER, + prisma, + }; + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(403); + expect(body).toMatchObject({ + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(prisma.mcpServer.findMany).not.toHaveBeenCalled(); + expect(mocks.hasEntitlement).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.userMcpServer.groupBy).not.toHaveBeenCalled(); + }); + + test('rejects unauthenticated callers before checking OAuth entitlement', async () => { + mocks.withAuth.mockResolvedValue({ + statusCode: 401, + errorCode: ErrorCode.NOT_AUTHENTICATED, + message: 'Not authenticated', + }); + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(401); + expect(body).toMatchObject({ + errorCode: ErrorCode.NOT_AUTHENTICATED, + }); + expect(mocks.hasEntitlement).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.userMcpServer.groupBy).not.toHaveBeenCalled(); + }); + + test('allows entitled owners to list cleanup data when OAuth is unsupported', async () => { + const prisma = createPrismaMock(); + mocks.authContext = { + org: { id: 1 }, + role: OrgRole.OWNER, + prisma, + }; + mocks.hasEntitlement.mockResolvedValue(false); + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(200); + expect(body).toMatchObject({ + isOAuthAvailable: false, + totalSavedConnectionCount: 2, + servers: [ + { + id: 'server-1', + savedConnectionCount: 2, + }, + { + id: 'server-2', + savedConnectionCount: 0, + }, + ], + }); + expect(mocks.withAuth).toHaveBeenCalled(); + expect(prisma.mcpServer.findMany).toHaveBeenCalled(); + expect(mocks.unsafePrisma.userMcpServer.groupBy).toHaveBeenCalled(); + }); + + test('skips the unsafe aggregate query when there are no approved servers', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.findMany.mockResolvedValue([]); + mocks.authContext = { + org: { id: 1 }, + role: OrgRole.OWNER, + prisma, + }; + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(mocks.unsafePrisma.userMcpServer.groupBy).not.toHaveBeenCalled(); + expect(body).toEqual({ + servers: [], + totalSavedConnectionCount: 0, + allowedMode: 'approved_only', + isOAuthAvailable: true, + }); + }); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts new file mode 100644 index 000000000..303418a82 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts @@ -0,0 +1,72 @@ +import { apiHandler } from '@/lib/apiHandler'; +import { serviceErrorResponse } from '@/lib/serviceError'; +import { isServiceError } from '@/lib/utils'; +import { hasEntitlement } from '@/lib/entitlements'; +import { withAuth } from '@/middleware/withAuth'; +import { withMinimumOrgRole } from '@/middleware/withMinimumOrgRole'; +import { __unsafePrisma } from '@/prisma'; +import { getMcpFaviconUrl } from '@/ee/features/mcp/utils'; +import type { GetMcpConfigurationResponse } from '@/ee/features/mcp/types'; +import { OrgRole } from '@sourcebot/db'; +import type { NextRequest } from 'next/server'; + +export const GET = apiHandler(async (_request: NextRequest) => { + const result = await withAuth(async ({ org, role, prisma }) => + withMinimumOrgRole(role, OrgRole.OWNER, async (): Promise => { + const isOAuthAvailable = await hasEntitlement('oauth'); + + const orgServers = await prisma.mcpServer.findMany({ + where: { orgId: org.id }, + orderBy: { createdAt: 'desc' }, + select: { + id: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }); + + const serverIds = orgServers.map((server) => server.id); + const connectionCounts = serverIds.length === 0 + ? [] + : await __unsafePrisma.userMcpServer.groupBy({ + by: ['serverId'], + where: { + serverId: { in: serverIds }, + tokens: { not: null }, + server: { orgId: org.id }, + user: { + orgs: { + some: { orgId: org.id }, + }, + }, + }, + _count: { _all: true }, + }); + const countByServerId = new Map( + connectionCounts.map((row) => [row.serverId, row._count._all]), + ); + + const servers = orgServers.map((server) => { + const savedConnectionCount = countByServerId.get(server.id) ?? 0; + return { + ...server, + faviconUrl: getMcpFaviconUrl(server.serverUrl, server.name), + savedConnectionCount, + }; + }); + + return { + servers, + totalSavedConnectionCount: servers.reduce((total, server) => total + server.savedConnectionCount, 0), + allowedMode: 'approved_only', + isOAuthAvailable, + }; + })); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return Response.json(result); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts new file mode 100644 index 000000000..6a379c6de --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts @@ -0,0 +1,250 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; +import { McpServerClientInfoSource } from '@sourcebot/db'; + +const mocks = vi.hoisted(() => ({ + authContext: undefined as unknown, + hasEntitlement: vi.fn(), + logger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, + mcpAuth: vi.fn(), + unsafePrisma: { + $transaction: vi.fn(), + }, +})); + +vi.mock('server-only', () => ({})); +vi.mock('@/lib/posthog', () => ({ + captureEvent: vi.fn(), +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/middleware/withAuth', () => ({ + withAuth: vi.fn((callback: (context: unknown) => unknown) => callback(mocks.authContext)), +})); +vi.mock('@/prisma', () => ({ + __unsafePrisma: mocks.unsafePrisma, +})); +vi.mock('@sourcebot/shared', () => ({ + env: { + AUTH_URL: 'https://sourcebot.example.com', + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: 5000, + }, + createLogger: () => mocks.logger, + encryptOAuthToken: vi.fn((text: string | null | undefined) => text ? `encrypted:${text}` : undefined), + decryptOAuthToken: vi.fn((text: string | null | undefined) => text?.startsWith('encrypted:') ? text.slice('encrypted:'.length) : text), +})); +vi.mock('@ai-sdk/mcp', () => ({ + auth: mocks.mcpAuth, +})); + +const { POST } = await import('./route'); +const { getMcpOAuthReturnToFromState } = await import('@/features/mcp/mcpOAuthReturnTo'); + +function createRequest(body: { serverId: string; returnTo?: string } = { serverId: 'server-1' }) { + return new NextRequest('http://localhost/api/ee/askmcp/connect', { + method: 'POST', + headers: { 'content-type': 'application/json' }, + body: JSON.stringify(body), + }); +} + +function createPrismaMock() { + return { + mcpServer: { + findFirst: vi.fn().mockResolvedValue({ + id: 'server-1', + serverUrl: 'https://mcp.linear.app/mcp', + }), + }, + userMcpServer: { + upsert: vi.fn().mockResolvedValue({ userId: 'user-1', serverId: 'server-1' }), + }, + }; +} + +function createTransactionMock() { + return { + $queryRaw: vi.fn().mockResolvedValue([{ id: 'server-1' }]), + mcpServer: { + findFirst: vi.fn(), + updateMany: vi.fn().mockResolvedValue({ count: 1 }), + }, + userMcpServer: { + findUnique: vi.fn(), + update: vi.fn(), + updateMany: vi.fn(), + }, + }; +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); +}); + +describe('POST /api/ee/askmcp/connect', () => { + test('upserts a nameless user row and performs DCR-capable auth under a row lock', async () => { + const prisma = createPrismaMock(); + const tx = createTransactionMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + mocks.unsafePrisma.$transaction.mockImplementation(async (callback, _options) => callback(tx)); + mocks.mcpAuth.mockImplementation(async (provider, options) => { + expect('saveClientInformation' in provider).toBe(true); + expect(provider.saveClientInformation).toEqual(expect.any(Function)); + expect(options.fetchFn).toEqual(expect.any(Function)); + + await provider.saveClientInformation({ client_id: 'client-1' }); + provider.authorizationUrl = 'https://oauth.example.com/authorize'; + return 'REDIRECT'; + }); + + const response = await POST(createRequest()); + const body = await response.json(); + + expect(prisma.userMcpServer.upsert).toHaveBeenCalledWith({ + where: { + userId_serverId: { + userId: 'user-1', + serverId: 'server-1', + }, + }, + create: { + userId: 'user-1', + serverId: 'server-1', + }, + update: {}, + }); + expect(mocks.unsafePrisma.$transaction).toHaveBeenCalledWith( + expect.any(Function), + { + maxWait: 10000, + timeout: 10000, + }, + ); + expect(tx.$queryRaw).toHaveBeenCalledOnce(); + expect(tx.mcpServer.updateMany).toHaveBeenCalledWith({ + where: { id: 'server-1', orgId: 1 }, + data: { + clientInfo: 'encrypted:{"client_id":"client-1"}', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + }); + expect(body).toEqual({ authorizationUrl: 'https://oauth.example.com/authorize' }); + }); + + test('encodes a safe return path into OAuth state', async () => { + const prisma = createPrismaMock(); + const tx = createTransactionMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + mocks.unsafePrisma.$transaction.mockImplementation(async (callback, _options) => callback(tx)); + mocks.mcpAuth.mockImplementation(async (provider) => { + const state = await provider.state(); + expect(getMcpOAuthReturnToFromState(state)).toBe('/chat'); + await provider.saveState(state); + + provider.authorizationUrl = 'https://oauth.example.com/authorize'; + return 'REDIRECT'; + }); + + const response = await POST(createRequest({ serverId: 'server-1', returnTo: '/chat' })); + const body = await response.json(); + + expect(body).toEqual({ authorizationUrl: 'https://oauth.example.com/authorize' }); + expect(tx.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + state: expect.stringContaining('sourcebot_mcp.'), + }, + }); + }); + + test('ignores unsafe return paths', async () => { + const prisma = createPrismaMock(); + const tx = createTransactionMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + mocks.unsafePrisma.$transaction.mockImplementation(async (callback, _options) => callback(tx)); + mocks.mcpAuth.mockImplementation(async (provider) => { + const state = await provider.state(); + expect(getMcpOAuthReturnToFromState(state)).toBeUndefined(); + await provider.saveState(state); + + provider.authorizationUrl = 'https://oauth.example.com/authorize'; + return 'REDIRECT'; + }); + + const response = await POST(createRequest({ serverId: 'server-1', returnTo: 'https://evil.example.com/chat' })); + const body = await response.json(); + + expect(body).toEqual({ authorizationUrl: 'https://oauth.example.com/authorize' }); + expect(tx.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + state: expect.not.stringContaining('sourcebot_mcp.'), + }, + }); + }); + + test('sanitizes external OAuth errors before logging', async () => { + const prisma = createPrismaMock(); + const tx = createTransactionMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + mocks.unsafePrisma.$transaction.mockImplementation(async (callback, _options) => callback(tx)); + mocks.mcpAuth.mockImplementation(async () => { + const error = new Error('invalid_client client_secret=client-secret refresh_token=refresh-token'); + Object.assign(error, { + response: { + status: 400, + body: 'client_secret=client-secret refresh_token=refresh-token', + }, + }); + throw error; + }); + + const response = await POST(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(502); + expect(body).toMatchObject({ + message: 'Could not start MCP authorization.', + }); + expect(mocks.logger.warn).toHaveBeenCalledWith('Failed to start MCP authorization.', { + serverId: 'server-1', + orgId: 1, + error: { + errorClass: 'Error', + oauthError: 'invalid_client', + statusCode: 400, + }, + }); + expect(JSON.stringify(mocks.logger.warn.mock.calls)).not.toContain('client-secret'); + expect(JSON.stringify(mocks.logger.warn.mock.calls)).not.toContain('refresh-token'); + expect(JSON.stringify(mocks.logger.error.mock.calls)).not.toContain('client-secret'); + expect(JSON.stringify(mocks.logger.error.mock.calls)).not.toContain('refresh-token'); + }); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts new file mode 100644 index 000000000..89f02381a --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts @@ -0,0 +1,154 @@ +import { auth as mcpAuth } from '@ai-sdk/mcp'; +import { apiHandler } from '@/lib/apiHandler'; +import { withAuth } from '@/middleware/withAuth'; +import { sew } from '@/middleware/sew'; +import { isServiceError } from '@/lib/utils'; +import { serviceErrorResponse, notFound, requestBodySchemaValidationError, ServiceErrorException } from '@/lib/serviceError'; +import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +import { NextRequest } from 'next/server'; +import { z } from 'zod'; +import { hasEntitlement } from '@/lib/entitlements'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import { ConnectMcpResponse } from "@/app/api/(server)/ee/askmcp/connect/types"; +import { createLogger, env } from "@sourcebot/shared"; +import { __unsafePrisma } from '@/prisma'; +import { getExternalMcpErrorLogFields } from '@/ee/features/mcp/externalMcpError'; +import { ErrorCode } from '@/lib/errorCodes'; +import { StatusCodes } from 'http-status-codes'; +import { normalizeMcpOAuthReturnTo } from '@/features/mcp/mcpOAuthReturnTo'; + +const bodySchema = z.object({ + serverId: z.string(), + returnTo: z.string().optional(), +}); +const logger = createLogger('mcp-connect'); +const MCP_AUTH_FETCH_TIMEOUT_MS = Math.min(env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS, 30000); +const MCP_AUTH_TRANSACTION_MAX_WAIT_MS = 10000; +const MCP_AUTH_TRANSACTION_TIMEOUT_MS = MCP_AUTH_FETCH_TIMEOUT_MS + 5000; + +function createTimeoutFetch(timeoutMs: number): typeof fetch { + return async (input, init) => { + const timeoutSignal = AbortSignal.timeout(timeoutMs); + const signal = init?.signal + ? AbortSignal.any([init.signal, timeoutSignal]) + : timeoutSignal; + + return fetch(input, { + ...init, + signal, + }); + }; +} + +export const POST = apiHandler(async (request: NextRequest) => { + if (!(await hasEntitlement('oauth'))) { + return Response.json( + { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, + { status: 403 } + ); + } + + const body = await request.json(); + const parsed = bodySchema.safeParse(body); + if (!parsed.success) { + return serviceErrorResponse(requestBodySchemaValidationError(parsed.error)); + } + + const result = await sew(() => + withAuth(async ({ user, org, prisma }) => { + const callbackReturnTo = normalizeMcpOAuthReturnTo(parsed.data.returnTo); + const mcpServer = await prisma.mcpServer.findFirst({ + where: { id: parsed.data.serverId, orgId: org.id }, + select: { + id: true, + serverUrl: true, + }, + }); + if (!mcpServer) { + return notFound('MCP server not found'); + } + + await prisma.userMcpServer.upsert({ + where: { + userId_serverId: { + userId: user.id, + serverId: mcpServer.id, + }, + }, + create: { + userId: user.id, + serverId: mcpServer.id, + }, + update: {}, + }); + + const connectResult = await __unsafePrisma.$transaction(async (tx) => { + const lockedRows = await tx.$queryRaw<{ id: string }[]>` + SELECT id + FROM "McpServer" + WHERE id = ${mcpServer.id} AND "orgId" = ${org.id} + FOR UPDATE + `; + + if (lockedRows.length === 0) { + throw new ServiceErrorException(notFound('MCP server not found')); + } + + const provider = new PrismaOAuthClientProvider({ + prisma: tx, + clientInvalidationPrisma: tx, + serverId: mcpServer.id, + orgId: org.id, + userId: user.id, + callbackUrl: `${env.AUTH_URL}/api/ee/askmcp/callback`, + callbackReturnTo, + allowClientRegistration: true, + }); + + let authResult: Awaited>; + try { + authResult = await mcpAuth(provider, { + serverUrl: new URL(mcpServer.serverUrl), + fetchFn: createTimeoutFetch(MCP_AUTH_FETCH_TIMEOUT_MS), + }); + } catch (error) { + logger.warn('Failed to start MCP authorization.', { + serverId: mcpServer.id, + orgId: org.id, + error: getExternalMcpErrorLogFields(error), + }); + throw new ServiceErrorException({ + statusCode: StatusCodes.BAD_GATEWAY, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: 'Could not start MCP authorization.', + }); + } + + return { + authResult, + authorizationUrl: provider.authorizationUrl ?? null, + }; + }, { + maxWait: MCP_AUTH_TRANSACTION_MAX_WAIT_MS, + timeout: MCP_AUTH_TRANSACTION_TIMEOUT_MS, + }); + + if (connectResult.authResult === 'AUTHORIZED') { + // Already has valid tokens (e.g., refreshed) + return { authorizationUrl: null } satisfies ConnectMcpResponse; + } + + if (!connectResult.authorizationUrl) { + throw new Error('MCP auth returned REDIRECT without an authorization URL'); + } + + return { authorizationUrl: connectResult.authorizationUrl } satisfies ConnectMcpResponse; + }) + ); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return Response.json(result); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/types.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/types.ts new file mode 100644 index 000000000..80281ae17 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/types.ts @@ -0,0 +1,4 @@ +export interface ConnectMcpResponse { + /** The external OAuth authorization URL the browser should navigate to. Null if already authorized. */ + authorizationUrl: string | null; +} \ No newline at end of file diff --git a/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts new file mode 100644 index 000000000..5fe917f02 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts @@ -0,0 +1,128 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; + +const mocks = vi.hoisted(() => ({ + authContext: undefined as unknown, + hasEntitlement: vi.fn(), +})); + +vi.mock('@/lib/posthog', () => ({ + captureEvent: vi.fn(), +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/middleware/withAuth', () => ({ + withAuth: vi.fn((callback: (context: unknown) => unknown) => callback(mocks.authContext)), +})); +vi.mock('@sourcebot/shared', () => ({ + decryptOAuthToken: vi.fn((value: string) => value), +})); + +const { GET } = await import('./route'); + +function createRequest() { + return new NextRequest('http://localhost/api/ee/askmcp/servers', { method: 'GET' }); +} + +function createPrismaMock() { + return { + mcpServer: { + findMany: vi.fn().mockResolvedValue([ + { + id: 'server-1', + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + }, + { + id: 'server-2', + name: 'Sentry', + sanitizedName: 'sentry', + serverUrl: 'https://mcp.sentry.dev/mcp', + }, + { + id: 'server-3', + name: 'GitHub', + sanitizedName: 'github', + serverUrl: 'https://api.githubcopilot.com/mcp', + }, + ]), + }, + userMcpServer: { + findMany: vi.fn().mockResolvedValue([ + { + serverId: 'server-1', + tokens: JSON.stringify({ access_token: 'token', token_type: 'Bearer' }), + tokensExpiresAt: null, + }, + { + serverId: 'server-3', + tokens: JSON.stringify({ access_token: 'expired-token', token_type: 'Bearer' }), + tokensExpiresAt: new Date('2020-01-01T00:00:00.000Z'), + }, + ]), + }, + }; +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); +}); + +describe('GET /api/ee/askmcp/servers', () => { + test('lists org servers and merges only the caller token status', async () => { + const prisma = createPrismaMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(prisma.mcpServer.findMany).toHaveBeenCalledWith({ + where: { orgId: 1 }, + orderBy: { createdAt: 'desc' }, + select: { + id: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }); + expect(prisma.userMcpServer.findMany).toHaveBeenCalledWith({ + where: { userId: 'user-1' }, + select: { + serverId: true, + tokens: true, + tokensExpiresAt: true, + }, + }); + expect(body).toMatchObject([ + { + id: 'server-1', + name: 'Linear', + sanitizedName: 'linear', + isConnected: true, + isAuthExpired: false, + }, + { + id: 'server-2', + name: 'Sentry', + sanitizedName: 'sentry', + isConnected: false, + isAuthExpired: false, + }, + { + id: 'server-3', + name: 'GitHub', + sanitizedName: 'github', + isConnected: false, + isAuthExpired: true, + }, + ]); + }); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts new file mode 100644 index 000000000..8fe277379 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts @@ -0,0 +1,96 @@ +import { apiHandler } from '@/lib/apiHandler'; +import { serviceErrorResponse } from '@/lib/serviceError'; +import { isServiceError } from '@/lib/utils'; +import { withAuth } from '@/middleware/withAuth'; +import { hasEntitlement } from '@/lib/entitlements'; +import { decryptOAuthToken } from '@sourcebot/shared'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import type { OAuthTokens } from '@ai-sdk/mcp'; +import { getMcpFaviconUrl } from '@/ee/features/mcp/utils'; +import type { NextRequest } from 'next/server'; + +export interface McpServerWithStatus { + id: string; + name: string; + serverUrl: string; + sanitizedName: string; + faviconUrl: string | undefined; + isConnected: boolean; + isAuthExpired: boolean; +} + +export type GetMcpServersResponse = McpServerWithStatus[]; + +export const GET = apiHandler(async (_request: NextRequest) => { + if (!(await hasEntitlement('oauth'))) { + return Response.json( + { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, + { status: 403 } + ); + } + + const result = await withAuth(async ({ org, user, prisma }) => { + const orgServers = await prisma.mcpServer.findMany({ + where: { orgId: org.id }, + orderBy: { createdAt: 'desc' }, + select: { + id: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }); + + const userServers = await prisma.userMcpServer.findMany({ + where: { userId: user.id }, + select: { + serverId: true, + tokens: true, + tokensExpiresAt: true, + }, + }); + const userServerByServerId = new Map(userServers.map((us) => [us.serverId, us])); + + return orgServers.map((server): McpServerWithStatus => { + const userServer = userServerByServerId.get(server.id); + const faviconUrl = getMcpFaviconUrl(server.serverUrl, server.name); + + let isConnected = false; + let isAuthExpired = false; + + if (userServer?.tokens) { + try { + const decrypted = decryptOAuthToken(userServer.tokens); + if (decrypted) { + const tokens: OAuthTokens = JSON.parse(decrypted); + if (tokens.refresh_token || !userServer.tokensExpiresAt) { + isConnected = true; + } else if (new Date() > userServer.tokensExpiresAt) { + isAuthExpired = true; + } else { + isConnected = true; + } + } + } catch { + // treat as not connected if decryption fails + } + } + + return { + id: server.id, + name: server.name, + serverUrl: server.serverUrl, + sanitizedName: server.sanitizedName, + faviconUrl, + isConnected, + isAuthExpired, + }; + }); + }); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return Response.json(result); +}); diff --git a/packages/web/src/ee/features/mcp/actions.test.ts b/packages/web/src/ee/features/mcp/actions.test.ts new file mode 100644 index 000000000..a37a84e21 --- /dev/null +++ b/packages/web/src/ee/features/mcp/actions.test.ts @@ -0,0 +1,386 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { McpServerClientInfoSource, OrgRole } from '@sourcebot/db'; +import { ErrorCode } from '@/lib/errorCodes'; + +const mocks = vi.hoisted(() => ({ + authContext: undefined as unknown, + hasEntitlement: vi.fn(), + headers: vi.fn(async () => new Headers({ + host: 'sourcebot.example.com', + origin: 'https://sourcebot.example.com', + 'x-forwarded-proto': 'https', + })), + encryptOAuthToken: vi.fn((text: string | null | undefined) => text ? `encrypted:${text}` : undefined), + env: { + AUTH_URL: 'https://sourcebot.example.com', + NODE_ENV: 'production', + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: 5000, + }, + logger: { + error: vi.fn(), + }, + unsafePrisma: { + mcpServer: { + deleteMany: vi.fn(), + }, + }, +})); + +vi.mock('server-only', () => ({})); +vi.mock('next/headers', () => ({ + headers: mocks.headers, +})); +vi.mock('@/middleware/withAuth', () => ({ + withAuth: vi.fn((callback: (context: unknown) => unknown) => callback(mocks.authContext)), +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/prisma', () => ({ + __unsafePrisma: mocks.unsafePrisma, +})); +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => mocks.logger, + encryptOAuthToken: mocks.encryptOAuthToken, + env: mocks.env, +})); + +const { createMcpServer, createStaticOAuthMcpServer, deleteMcpServer } = await import('./actions'); + +function createPrismaMock() { + return { + mcpServer: { + findUnique: vi.fn().mockResolvedValue(null), + findFirst: vi.fn().mockResolvedValue(null), + create: vi.fn().mockImplementation(async ({ data }) => ({ + id: 'server-1', + name: data.name, + sanitizedName: data.sanitizedName, + serverUrl: data.serverUrl, + })), + }, + }; +} + +function setAuthContext(role: OrgRole, prisma = createPrismaMock()) { + mocks.authContext = { + org: { id: 1 }, + role, + prisma, + }; + return prisma; +} + +function createStaticOAuthRequest(overrides: Partial<{ + name: string; + serverUrl: string; + clientId: string; + clientSecret: string; +}> = {}) { + return { + name: 'Slack', + serverUrl: 'https://mcp.slack.com/mcp', + clientId: 'client-id', + clientSecret: 'client-secret', + ...overrides, + }; +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.headers.mockResolvedValue(new Headers({ + host: 'sourcebot.example.com', + origin: 'https://sourcebot.example.com', + 'x-forwarded-proto': 'https', + })); + mocks.encryptOAuthToken.mockImplementation((text: string | null | undefined) => text ? `encrypted:${text}` : undefined); + mocks.env.AUTH_URL = 'https://sourcebot.example.com'; + mocks.env.NODE_ENV = 'production'; + mocks.env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS = 5000; +}); + +describe('createMcpServer', () => { + test('owners add an org MCP server without dynamic client information', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createMcpServer(' Linear ', ' https://mcp.linear.app/mcp '); + + expect(result).toEqual({ + id: 'server-1', + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + }); + expect(prisma.mcpServer.create).toHaveBeenCalledWith({ + data: { + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + clientInfo: null, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + orgId: 1, + }, + }); + }); + + test('members cannot add org MCP servers', async () => { + const prisma = setAuthContext(OrgRole.MEMBER); + + const result = await createMcpServer('Linear', 'https://mcp.linear.app/mcp'); + + expect(result).toMatchObject({ + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + }); + + test('owners cannot add org MCP servers when OAuth is unsupported', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + mocks.hasEntitlement.mockResolvedValue(false); + + const result = await createMcpServer('Linear', 'https://mcp.linear.app/mcp'); + + expect(result).toMatchObject({ + statusCode: 403, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + }); +}); + +describe('createStaticOAuthMcpServer', () => { + test('owners add a static OAuth MCP server with encrypted client information', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createStaticOAuthMcpServer({ + name: ' Slack ', + serverUrl: 'https://mcp.slack.com/mcp', + clientId: ' client-id ', + clientSecret: ' client-secret ', + }); + + expect(mocks.encryptOAuthToken).toHaveBeenCalledWith(JSON.stringify({ + client_id: 'client-id', + client_secret: 'client-secret', + })); + expect(prisma.mcpServer.create).toHaveBeenCalledWith({ + data: { + name: 'Slack', + sanitizedName: 'slack', + serverUrl: 'https://mcp.slack.com/mcp', + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.STATIC, + orgId: 1, + }, + }); + expect(JSON.stringify(result)).not.toContain('client-secret'); + expect(result).toEqual({ + id: 'server-1', + name: 'Slack', + sanitizedName: 'slack', + serverUrl: 'https://mcp.slack.com/mcp', + }); + }); + + test('members cannot add static OAuth MCP servers', async () => { + const prisma = setAuthContext(OrgRole.MEMBER); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest()); + + expect(result).toMatchObject({ + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + }); + + test('rejects static OAuth credentials when production AUTH_URL is not HTTPS', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + mocks.env.AUTH_URL = 'http://sourcebot.example.com'; + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest()); + + expect(result).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Static OAuth client credentials require HTTPS in production.', + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('rejects static OAuth credentials over insecure production requests', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + mocks.headers.mockResolvedValue(new Headers({ + host: 'sourcebot.example.com', + origin: 'http://sourcebot.example.com', + 'x-forwarded-proto': 'http', + })); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest()); + + expect(result).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Static OAuth client credentials require HTTPS in production.', + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('does not echo client secrets in validation errors', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createStaticOAuthMcpServer({ + name: 'Slack', + serverUrl: 'not-a-url', + clientId: 'client-id', + clientSecret: 'client-secret', + }); + + expect(result).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + }); + expect(JSON.stringify(result)).not.toContain('client-secret'); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + }); + + test('rejects static OAuth servers with non-HTTPS server URLs', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest({ + serverUrl: 'http://mcp.slack.com/mcp', + })); + + expect(result).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Invalid server URL. Must be a valid HTTPS URL.', + }); + expect(prisma.mcpServer.findUnique).not.toHaveBeenCalled(); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(mocks.encryptOAuthToken).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('rejects static OAuth servers with fewer than 3 alphanumeric name characters', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest({ + name: '!!a!', + })); + + expect(result).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Server name must contain at least 3 alphanumeric characters.', + }); + expect(prisma.mcpServer.findUnique).not.toHaveBeenCalled(); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(mocks.encryptOAuthToken).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('rejects static OAuth servers with a duplicate URL', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + prisma.mcpServer.findUnique.mockResolvedValue({ id: 'existing-server' }); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest()); + + expect(result).toMatchObject({ + statusCode: 409, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: 'An MCP server with URL "https://mcp.slack.com/mcp" already exists.', + }); + expect(prisma.mcpServer.findFirst).not.toHaveBeenCalled(); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(mocks.encryptOAuthToken).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('rejects static OAuth servers with a duplicate sanitized name', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + prisma.mcpServer.findFirst.mockResolvedValue({ id: 'existing-server' }); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest({ + name: 'Slack!!!', + })); + + expect(result).toMatchObject({ + statusCode: 409, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: 'An MCP server with a similar name already exists. Please choose a more distinct name.', + }); + expect(prisma.mcpServer.findUnique).toHaveBeenCalledWith({ + where: { + serverUrl_orgId: { + serverUrl: 'https://mcp.slack.com/mcp', + orgId: 1, + }, + }, + select: { id: true }, + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(mocks.encryptOAuthToken).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('rejects static OAuth servers when client credential encryption fails', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + mocks.encryptOAuthToken.mockReturnValue(undefined); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest()); + + expect(result).toMatchObject({ + statusCode: 500, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: 'Failed to store OAuth client credentials.', + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); +}); + +describe('deleteMcpServer', () => { + test('owners delete through the narrowly scoped unsafe client', async () => { + setAuthContext(OrgRole.OWNER); + mocks.unsafePrisma.mcpServer.deleteMany.mockResolvedValue({ count: 1 }); + + await expect(deleteMcpServer('server-1')).resolves.toEqual({ success: true }); + expect(mocks.unsafePrisma.mcpServer.deleteMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + }, + }); + expect(mocks.hasEntitlement).not.toHaveBeenCalled(); + }); + + test('members cannot delete org MCP servers', async () => { + setAuthContext(OrgRole.MEMBER); + + const result = await deleteMcpServer('server-1'); + + expect(result).toMatchObject({ + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(mocks.unsafePrisma.mcpServer.deleteMany).not.toHaveBeenCalled(); + }); + + test('owners can delete org MCP servers when OAuth is unsupported', async () => { + setAuthContext(OrgRole.OWNER); + mocks.hasEntitlement.mockResolvedValue(false); + mocks.unsafePrisma.mcpServer.deleteMany.mockResolvedValue({ count: 1 }); + + await expect(deleteMcpServer('server-1')).resolves.toEqual({ success: true }); + + expect(mocks.hasEntitlement).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.mcpServer.deleteMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + }, + }); + }); +}); diff --git a/packages/web/src/ee/features/mcp/actions.ts b/packages/web/src/ee/features/mcp/actions.ts new file mode 100644 index 000000000..ebe2470c6 --- /dev/null +++ b/packages/web/src/ee/features/mcp/actions.ts @@ -0,0 +1,356 @@ +'use server'; + +import { sew } from '@/middleware/sew'; +import { ErrorCode } from '@/lib/errorCodes'; +import { requestBodySchemaValidationError, ServiceError } from '@/lib/serviceError'; +import { withAuth } from '@/middleware/withAuth'; +import { withMinimumOrgRole } from '@/middleware/withMinimumOrgRole'; +import { __unsafePrisma } from '@/prisma'; +import { isServiceError } from '@/lib/utils'; +import { McpServerClientInfoSource, OrgRole, type PrismaClient } from '@sourcebot/db'; +import { StatusCodes } from 'http-status-codes'; +import { z } from 'zod'; +import { sanitizeMcpServerName } from './utils'; +import { hasEntitlement } from '@/lib/entitlements'; +import { oauthNotSupported } from './errors'; +import { checkMcpServerDcrSupport } from './dcrDiscovery'; +import { encryptOAuthToken, env } from '@sourcebot/shared'; +import { headers } from 'next/headers'; + +const MCP_DCR_DISCOVERY_TIMEOUT_MS = Math.min(env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS, 10000); +const createStaticOAuthMcpServerSchema = z.object({ + name: z.string().trim().min(1), + serverUrl: z.string().trim().url(), + clientId: z.string().trim().min(1), + clientSecret: z.string().trim().min(1), +}); + +export type CreateStaticOAuthMcpServerRequest = z.infer; + +export interface CreateStaticOAuthMcpServerResponse { + id: string; + name: string; + sanitizedName: string; + serverUrl: string; +} + +type McpServerPrismaClient = Pick; + +interface PreparedMcpServerCreate { + displayName: string; + normalizedServerUrl: string; + sanitizedName: string; +} + +function createTimeoutFetch(timeoutMs: number): typeof fetch { + return async (input, init) => { + const timeoutSignal = AbortSignal.timeout(timeoutMs); + const signal = init?.signal + ? AbortSignal.any([init.signal, timeoutSignal]) + : timeoutSignal; + + return fetch(input, { + ...init, + signal, + }); + }; +} + +function invalidRequest(message: string): ServiceError { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message, + }; +} + +function getFirstHeaderValue(value: string | null): string | undefined { + return value?.split(',')[0]?.trim().toLowerCase(); +} + +function getHeaderUrlProtocol(value: string | null, host: string | undefined): string | undefined { + if (!value || !host) { + return undefined; + } + + try { + const url = new URL(value); + return url.host === host ? url.protocol : undefined; + } catch { + return undefined; + } +} + +async function assertHttpsInProduction(): Promise { + if (env.NODE_ENV !== 'production') { + return undefined; + } + + const requestHeaders = await headers(); + const publicAuthUrlIsHttps = new URL(env.AUTH_URL).protocol === 'https:'; + const host = getFirstHeaderValue(requestHeaders.get('x-forwarded-host')) + ?? getFirstHeaderValue(requestHeaders.get('host')); + const originProtocol = getHeaderUrlProtocol(requestHeaders.get('origin'), host); + const refererProtocol = getHeaderUrlProtocol(requestHeaders.get('referer'), host); + const requestIsHttps = getFirstHeaderValue(requestHeaders.get('x-forwarded-proto')) === 'https' + || getFirstHeaderValue(requestHeaders.get('x-forwarded-ssl')) === 'on' + || originProtocol === 'https:' + || refererProtocol === 'https:'; + + if (publicAuthUrlIsHttps && requestIsHttps) { + return undefined; + } + + return invalidRequest('Static OAuth client credentials require HTTPS in production.'); +} + +async function prepareMcpServerCreate({ + prisma, + orgId, + name, + serverUrl, +}: { + prisma: McpServerPrismaClient; + orgId: number; + name: string; + serverUrl: string; +}): Promise { + const displayName = name.trim(); + const normalizedServerUrl = serverUrl.trim(); + const urlResult = z.string().url().safeParse(normalizedServerUrl); + const protocol = urlResult.success ? new URL(normalizedServerUrl).protocol : undefined; + if (!urlResult.success || protocol !== 'https:') { + return invalidRequest('Invalid server URL. Must be a valid HTTPS URL.'); + } + + const sanitizedName = sanitizeMcpServerName(displayName); + const alphanumericCount = (sanitizedName.match(/[a-z0-9]/g) ?? []).length; + if (alphanumericCount < 3) { + return invalidRequest('Server name must contain at least 3 alphanumeric characters.'); + } + + const existingServer = await prisma.mcpServer.findUnique({ + where: { + serverUrl_orgId: { + serverUrl: normalizedServerUrl, + orgId, + }, + }, + select: { id: true }, + }); + if (existingServer) { + return { + statusCode: StatusCodes.CONFLICT, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: `An MCP server with URL "${normalizedServerUrl}" already exists.`, + } satisfies ServiceError; + } + + const existingName = await prisma.mcpServer.findFirst({ + where: { + orgId, + sanitizedName, + }, + select: { id: true }, + }); + if (existingName) { + return { + statusCode: StatusCodes.CONFLICT, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: 'An MCP server with a similar name already exists. Please choose a more distinct name.', + } satisfies ServiceError; + } + + return { + displayName, + normalizedServerUrl, + sanitizedName, + }; +} + +export const checkMcpServerDynamicClientRegistration = async (serverUrl: string) => sew(() => + withAuth(async ({ role }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + if (!(await hasEntitlement('oauth'))) { + return oauthNotSupported(); + } + + const normalizedServerUrl = serverUrl.trim(); + const urlResult = z.string().url().safeParse(normalizedServerUrl); + const protocol = urlResult.success ? new URL(normalizedServerUrl).protocol : undefined; + if (!urlResult.success || protocol !== 'https:') { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Invalid server URL. Must be a valid HTTPS URL.', + } satisfies ServiceError; + } + + try { + return await checkMcpServerDcrSupport( + normalizedServerUrl, + createTimeoutFetch(MCP_DCR_DISCOVERY_TIMEOUT_MS), + ); + } catch { + return { + statusCode: StatusCodes.BAD_GATEWAY, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: 'Could not check whether this MCP server supports dynamic client registration.', + } satisfies ServiceError; + } + }))); + +export const createStaticOAuthMcpServer = async ( + body: CreateStaticOAuthMcpServerRequest, +) => { + const parsed = createStaticOAuthMcpServerSchema.safeParse(body); + if (!parsed.success) { + return requestBodySchemaValidationError(parsed.error); + } + + return sew(() => + withAuth(async ({ org, role, prisma }) => + withMinimumOrgRole(role, OrgRole.OWNER, async (): Promise => { + if (!(await hasEntitlement('oauth'))) { + return oauthNotSupported(); + } + + const httpsError = await assertHttpsInProduction(); + if (httpsError) { + return httpsError; + } + + const preparedServer = await prepareMcpServerCreate({ + prisma, + orgId: org.id, + name: parsed.data.name, + serverUrl: parsed.data.serverUrl, + }); + if (isServiceError(preparedServer)) { + return preparedServer; + } + + const clientInfo = encryptOAuthToken(JSON.stringify({ + client_id: parsed.data.clientId, + client_secret: parsed.data.clientSecret, + })); + if (!clientInfo) { + return { + statusCode: StatusCodes.INTERNAL_SERVER_ERROR, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: 'Failed to store OAuth client credentials.', + } satisfies ServiceError; + } + + const mcpServer = await prisma.mcpServer.create({ + data: { + name: preparedServer.displayName, + sanitizedName: preparedServer.sanitizedName, + serverUrl: preparedServer.normalizedServerUrl, + clientInfo, + clientInfoSource: McpServerClientInfoSource.STATIC, + orgId: org.id, + }, + }); + + return { + id: mcpServer.id, + name: preparedServer.displayName, + sanitizedName: preparedServer.sanitizedName, + serverUrl: mcpServer.serverUrl, + }; + }))); +} + +export const createMcpServer = async (name: string, serverUrl: string) => sew(() => + withAuth(async ({ org, role, prisma }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + if (!(await hasEntitlement('oauth'))) { + return oauthNotSupported(); + } + + const preparedServer = await prepareMcpServerCreate({ + prisma, + orgId: org.id, + name, + serverUrl, + }); + if (isServiceError(preparedServer)) { + return preparedServer; + } + + const mcpServer = await prisma.mcpServer.create({ + data: { + name: preparedServer.displayName, + sanitizedName: preparedServer.sanitizedName, + serverUrl: preparedServer.normalizedServerUrl, + clientInfo: null, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + orgId: org.id, + }, + }); + + return { + id: mcpServer.id, + name: preparedServer.displayName, + sanitizedName: preparedServer.sanitizedName, + serverUrl: mcpServer.serverUrl, + }; + }))); + +export const deleteMcpServer = async (serverId: string) => sew(() => + withAuth(async ({ org, role }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + const result = await __unsafePrisma.mcpServer.deleteMany({ + where: { + id: serverId, + orgId: org.id, + }, + }); + + if (result.count === 0) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.MCP_SERVER_NOT_FOUND, + message: 'MCP server not found', + } satisfies ServiceError; + } + + return { success: true }; + }))); + +export const disconnectMcpServer = async (serverId: string) => sew(() => + withAuth(async ({ org, user }) => { + const server = await __unsafePrisma.mcpServer.findFirst({ + where: { + id: serverId, + orgId: org.id, + }, + select: { id: true }, + }); + + if (!server) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.MCP_SERVER_NOT_FOUND, + message: 'MCP server not found', + } satisfies ServiceError; + } + + const result = await __unsafePrisma.userMcpServer.deleteMany({ + where: { + serverId, + userId: user.id, + }, + }); + + if (result.count === 0) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.MCP_SERVER_NOT_FOUND, + message: 'No connection found for this MCP server.', + } satisfies ServiceError; + } + + return { success: true }; + })); diff --git a/packages/web/src/ee/features/mcp/components/connectMcpButton.tsx b/packages/web/src/ee/features/mcp/components/connectMcpButton.tsx new file mode 100644 index 000000000..417392734 --- /dev/null +++ b/packages/web/src/ee/features/mcp/components/connectMcpButton.tsx @@ -0,0 +1,35 @@ +'use client'; + +import { LoadingButton } from '@/components/ui/loading-button'; +import { ExternalLink, PlusIcon } from 'lucide-react'; +import type { ButtonProps } from '@/components/ui/button'; +import { useConnectMcp } from '@/ee/features/mcp/hooks/useConnectMcp'; + +interface ConnectMcpButtonProps { + serverId: string; + isConnected?: boolean; + isAuthExpired?: boolean; + size?: ButtonProps['size']; +} + +export function ConnectMcpButton({ serverId, isConnected, isAuthExpired, size }: ConnectMcpButtonProps) { + const { connect, loadingServerId } = useConnectMcp(); + const loading = loadingServerId === serverId; + + const isSuggested = !isConnected && !isAuthExpired; + const buttonLabel = isSuggested ? "Connect" : "Reconnect"; + const buttonVariant = isConnected ? "outline" as const : undefined; + + return ( + connect(serverId)} + loading={loading} + variant={buttonVariant} + size={size} + > + {isSuggested && } + {buttonLabel} + {!isSuggested && } + + ); +} diff --git a/packages/web/src/ee/features/mcp/components/mcpFavicon.tsx b/packages/web/src/ee/features/mcp/components/mcpFavicon.tsx new file mode 100644 index 000000000..2220fc516 --- /dev/null +++ b/packages/web/src/ee/features/mcp/components/mcpFavicon.tsx @@ -0,0 +1,24 @@ +'use client'; + +import { Plug } from "lucide-react"; +import { useState } from "react"; + +interface McpFaviconProps { + faviconUrl: string | undefined; + className?: string; +} + +export const McpFavicon = ({ faviconUrl, className = "w-4 h-4" }: McpFaviconProps) => { + const [failed, setFailed] = useState(false); + if (faviconUrl && !failed) { + return ( + setFailed(true)} + className={`${className} flex-shrink-0`} + alt="" + /> + ); + } + return ; +}; \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/dcrDiscovery.test.ts b/packages/web/src/ee/features/mcp/dcrDiscovery.test.ts new file mode 100644 index 000000000..194a2a815 --- /dev/null +++ b/packages/web/src/ee/features/mcp/dcrDiscovery.test.ts @@ -0,0 +1,217 @@ +import { describe, expect, test, vi } from 'vitest'; +import { checkMcpServerDcrSupport } from './dcrDiscovery'; + +function jsonResponse(body: unknown) { + return new Response(JSON.stringify(body), { + status: 200, + headers: { 'content-type': 'application/json' }, + }); +} + +function notFoundResponse() { + return new Response('Not found', { status: 404 }); +} + +function deferredResponse() { + let resolve!: (response: Response) => void; + const promise = new Promise((resolvePromise) => { + resolve = resolvePromise; + }); + + return { promise, resolve }; +} + +describe('checkMcpServerDcrSupport', () => { + test('returns supported when authorization server metadata advertises a registration endpoint', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url === 'https://mcp.example.com/.well-known/oauth-protected-resource/mcp') { + return jsonResponse({ authorization_servers: ['https://auth.example.com'] }); + } + if (url === 'https://auth.example.com/.well-known/oauth-authorization-server') { + return jsonResponse({ registration_endpoint: 'https://auth.example.com/register' }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: true, + isKnown: true, + authorizationServerUrl: 'https://auth.example.com', + registrationEndpoint: 'https://auth.example.com/register', + }); + }); + + test('returns unsupported when authorization server metadata does not advertise a registration endpoint', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url === 'https://mcp.slack.com/.well-known/oauth-protected-resource') { + return jsonResponse({ authorization_servers: ['https://mcp.slack.com'] }); + } + if (url === 'https://mcp.slack.com/.well-known/oauth-authorization-server') { + return jsonResponse({ + authorization_endpoint: 'https://slack.com/oauth/v2_user/authorize', + token_endpoint: 'https://slack.com/api/oauth.v2.user.access', + }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.slack.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: false, + isKnown: true, + authorizationServerUrl: 'https://mcp.slack.com', + }); + }); + + test('falls back to the resource metadata URL from a bearer challenge', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url === 'https://auth.example.com/.well-known/oauth-authorization-server') { + return jsonResponse({ registration_endpoint: 'https://auth.example.com/register' }); + } + if (url.includes('/.well-known/')) { + return notFoundResponse(); + } + if (url === 'https://mcp.example.com/mcp') { + return new Response('', { + status: 401, + headers: { + 'www-authenticate': 'Bearer resource_metadata="https://metadata.example.com/oauth-protected-resource"', + }, + }); + } + if (url === 'https://metadata.example.com/oauth-protected-resource') { + return jsonResponse({ authorization_servers: ['https://auth.example.com'] }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + const result = await checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock); + + expect(result.supportsDcr).toBe(true); + expect(result.isKnown).toBe(true); + }); + + test('ignores non-bearer authenticate challenges', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url.includes('/.well-known/')) { + return notFoundResponse(); + } + if (url === 'https://mcp.example.com/mcp') { + return new Response('', { + status: 401, + headers: { + 'www-authenticate': 'Basic realm="mcp"', + }, + }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: true, + isKnown: false, + authorizationServerUrl: 'https://mcp.example.com/mcp', + }); + }); + + test('ignores malformed bearer resource metadata URLs', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url.includes('/.well-known/')) { + return notFoundResponse(); + } + if (url === 'https://mcp.example.com/mcp') { + return new Response('', { + status: 401, + headers: { + 'www-authenticate': 'Bearer resource_metadata="not a url"', + }, + }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: true, + isKnown: false, + authorizationServerUrl: 'https://mcp.example.com/mcp', + }); + }); + + test('ignores bearer resource metadata parameters without quotes', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url.includes('/.well-known/')) { + return notFoundResponse(); + } + if (url === 'https://mcp.example.com/mcp') { + return new Response('', { + status: 401, + headers: { + 'www-authenticate': 'Bearer resource_metadata=https://metadata.example.com/oauth-protected-resource', + }, + }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: true, + isKnown: false, + authorizationServerUrl: 'https://mcp.example.com/mcp', + }); + }); + + test('starts authorization server metadata candidate requests concurrently while preserving priority', async () => { + const pathScopedOAuthMetadata = deferredResponse(); + const rootOAuthMetadata = deferredResponse(); + const pathScopedOidcMetadata = deferredResponse(); + const nestedOidcMetadata = deferredResponse(); + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url === 'https://mcp.example.com/.well-known/oauth-protected-resource/mcp') { + return jsonResponse({ authorization_servers: ['https://auth.example.com/tenant'] }); + } + if (url === 'https://auth.example.com/.well-known/oauth-authorization-server/tenant') { + return pathScopedOAuthMetadata.promise; + } + if (url === 'https://auth.example.com/.well-known/oauth-authorization-server') { + return rootOAuthMetadata.promise; + } + if (url === 'https://auth.example.com/.well-known/openid-configuration/tenant') { + return pathScopedOidcMetadata.promise; + } + if (url === 'https://auth.example.com/tenant/.well-known/openid-configuration') { + return nestedOidcMetadata.promise; + } + return notFoundResponse(); + }); + + const resultPromise = checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock as unknown as typeof fetch); + await vi.waitFor(() => { + const requestedUrls = fetchMock.mock.calls.map(([input]) => input.toString()); + + expect(requestedUrls).toContain('https://auth.example.com/.well-known/oauth-authorization-server/tenant'); + expect(requestedUrls).toContain('https://auth.example.com/.well-known/oauth-authorization-server'); + expect(requestedUrls).toContain('https://auth.example.com/.well-known/openid-configuration/tenant'); + expect(requestedUrls).toContain('https://auth.example.com/tenant/.well-known/openid-configuration'); + }); + + rootOAuthMetadata.resolve(jsonResponse({ registration_endpoint: 'https://auth.example.com/register' })); + pathScopedOidcMetadata.resolve(notFoundResponse()); + nestedOidcMetadata.resolve(notFoundResponse()); + await Promise.resolve(); + + pathScopedOAuthMetadata.resolve(notFoundResponse()); + + await expect(resultPromise).resolves.toEqual({ + supportsDcr: true, + isKnown: true, + authorizationServerUrl: 'https://auth.example.com/tenant', + registrationEndpoint: 'https://auth.example.com/register', + }); + }); +}); diff --git a/packages/web/src/ee/features/mcp/dcrDiscovery.ts b/packages/web/src/ee/features/mcp/dcrDiscovery.ts new file mode 100644 index 000000000..286883d50 --- /dev/null +++ b/packages/web/src/ee/features/mcp/dcrDiscovery.ts @@ -0,0 +1,206 @@ +import { z } from 'zod'; + +const MCP_PROTOCOL_VERSION = '2025-11-25'; + +const protectedResourceMetadataSchema = z.object({ + authorization_servers: z.array(z.string().url()).optional(), +}).passthrough(); + +const authorizationServerMetadataSchema = z.object({ + registration_endpoint: z.string().url().optional(), +}).passthrough(); + +export interface McpServerDcrSupport { + supportsDcr: boolean; + isKnown: boolean; + authorizationServerUrl?: string; + registrationEndpoint?: string; +} + +function getMetadataHeaders() { + return { + Accept: 'application/json', + 'MCP-Protocol-Version': MCP_PROTOCOL_VERSION, + }; +} + +function buildProtectedResourceMetadataUrls(serverUrl: URL): URL[] { + const urls: URL[] = []; + const pathname = serverUrl.pathname.endsWith('/') + ? serverUrl.pathname.slice(0, -1) + : serverUrl.pathname; + + if (pathname && pathname !== '/') { + urls.push(new URL(`/.well-known/oauth-protected-resource${pathname}`, serverUrl.origin)); + } + + urls.push(new URL('/.well-known/oauth-protected-resource', serverUrl.origin)); + return urls; +} + +function buildAuthorizationServerMetadataUrls(authorizationServerUrl: URL): URL[] { + const hasPath = authorizationServerUrl.pathname !== '/'; + + if (!hasPath) { + return [ + new URL('/.well-known/oauth-authorization-server', authorizationServerUrl.origin), + new URL('/.well-known/openid-configuration', authorizationServerUrl.origin), + ]; + } + + const pathname = authorizationServerUrl.pathname.endsWith('/') + ? authorizationServerUrl.pathname.slice(0, -1) + : authorizationServerUrl.pathname; + + return [ + new URL(`/.well-known/oauth-authorization-server${pathname}`, authorizationServerUrl.origin), + new URL('/.well-known/oauth-authorization-server', authorizationServerUrl.origin), + new URL(`/.well-known/openid-configuration${pathname}`, authorizationServerUrl.origin), + new URL(`${pathname}/.well-known/openid-configuration`, authorizationServerUrl.origin), + ]; +} + +function normalizeUrlForOutput(url: URL): string { + return url.toString().replace(/\/$/, ''); +} + +function extractResourceMetadataUrl(response: Response): URL | undefined { + const header = response.headers.get('www-authenticate'); + if (!header) { + return undefined; + } + + if (!header.toLowerCase().startsWith('bearer ')) { + return undefined; + } + + const match = header.match(/resource_metadata="([^"]+)"/); + if (!match) { + return undefined; + } + + try { + return new URL(match[1]); + } catch { + return undefined; + } +} + +async function fetchJson(url: URL, fetchFn: typeof fetch): Promise { + const response = await fetchFn(url, { headers: getMetadataHeaders() }); + + if (!response.ok) { + return undefined; + } + + return response.json(); +} + +async function fetchMetadataByPriority( + urls: URL[], + fetchFn: typeof fetch, + schema: z.ZodType, +): Promise { + const metadataPromises = urls.map(async (url) => { + try { + const json = await fetchJson(url, fetchFn); + const metadata = schema.safeParse(json); + return metadata.success ? metadata.data : undefined; + } catch { + return undefined; + } + }); + + for (const metadataPromise of metadataPromises) { + const metadata = await metadataPromise; + if (metadata) { + return metadata; + } + } + + return undefined; +} + +async function discoverProtectedResourceMetadata(serverUrl: URL, fetchFn: typeof fetch) { + const challengeMetadataPromise = (async () => { + try { + const response = await fetchFn(serverUrl, { headers: getMetadataHeaders() }); + const resourceMetadataUrl = extractResourceMetadataUrl(response); + if (!resourceMetadataUrl) { + return undefined; + } + + const json = await fetchJson(resourceMetadataUrl, fetchFn); + const metadata = protectedResourceMetadataSchema.safeParse(json); + return metadata.success ? metadata.data : undefined; + } catch { + return undefined; + } + })(); + + const wellKnownMetadata = await fetchMetadataByPriority( + buildProtectedResourceMetadataUrls(serverUrl), + fetchFn, + protectedResourceMetadataSchema, + ); + if (wellKnownMetadata) { + return wellKnownMetadata; + } + + return challengeMetadataPromise; +} + +async function discoverAuthorizationServerMetadata(authorizationServerUrl: URL, fetchFn: typeof fetch) { + return fetchMetadataByPriority( + buildAuthorizationServerMetadataUrls(authorizationServerUrl), + fetchFn, + authorizationServerMetadataSchema, + ); +} + +export async function checkMcpServerDcrSupport(serverUrl: string, fetchFn: typeof fetch = fetch): Promise { + const parsedServerUrl = new URL(serverUrl); + const protectedResourceMetadata = await discoverProtectedResourceMetadata(parsedServerUrl, fetchFn); + const authorizationServerUrls = protectedResourceMetadata?.authorization_servers?.length + ? protectedResourceMetadata.authorization_servers + : [parsedServerUrl.toString()]; + + let foundAuthorizationServerMetadata = false; + let firstAuthorizationServerUrl: URL | undefined; + for (const authorizationServer of authorizationServerUrls) { + const authorizationServerUrl = new URL(authorizationServer); + firstAuthorizationServerUrl ??= authorizationServerUrl; + const authorizationServerMetadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, fetchFn); + if (!authorizationServerMetadata) { + continue; + } + + foundAuthorizationServerMetadata = true; + if (authorizationServerMetadata.registration_endpoint) { + return { + supportsDcr: true, + isKnown: true, + authorizationServerUrl: normalizeUrlForOutput(authorizationServerUrl), + registrationEndpoint: authorizationServerMetadata.registration_endpoint, + }; + } + } + + if (foundAuthorizationServerMetadata) { + return { + supportsDcr: false, + isKnown: true, + authorizationServerUrl: firstAuthorizationServerUrl + ? normalizeUrlForOutput(firstAuthorizationServerUrl) + : undefined, + }; + } + + return { + supportsDcr: true, + isKnown: false, + authorizationServerUrl: firstAuthorizationServerUrl + ? normalizeUrlForOutput(firstAuthorizationServerUrl) + : undefined, + }; +} diff --git a/packages/web/src/ee/features/mcp/errors.ts b/packages/web/src/ee/features/mcp/errors.ts new file mode 100644 index 000000000..12a0c79a9 --- /dev/null +++ b/packages/web/src/ee/features/mcp/errors.ts @@ -0,0 +1,10 @@ +import { ErrorCode } from '@/lib/errorCodes'; +import { ServiceError } from '@/lib/serviceError'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import { StatusCodes } from 'http-status-codes'; + +export const oauthNotSupported = (): ServiceError => ({ + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + message: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE, +}); diff --git a/packages/web/src/ee/features/mcp/externalMcpError.test.ts b/packages/web/src/ee/features/mcp/externalMcpError.test.ts new file mode 100644 index 000000000..5f51433b5 --- /dev/null +++ b/packages/web/src/ee/features/mcp/externalMcpError.test.ts @@ -0,0 +1,66 @@ +import { describe, expect, test } from 'vitest'; +import { getExternalMcpErrorLogFields } from './externalMcpError'; + +describe('getExternalMcpErrorLogFields', () => { + test('does not include raw error messages or response bodies', () => { + class OAuthProviderError extends Error { + statusCode = 401; + response = { + status: 401, + body: JSON.stringify({ + error: 'invalid_client', + error_description: 'client_secret=client-secret refresh_token=refresh-token', + }), + }; + } + const error = new OAuthProviderError('invalid_client client_secret=client-secret'); + + const fields = getExternalMcpErrorLogFields(error); + + expect(fields).toEqual({ + errorClass: 'OAuthProviderError', + errorName: 'Error', + oauthError: 'invalid_client', + statusCode: 401, + }); + expect(JSON.stringify(fields)).not.toContain('client-secret'); + expect(JSON.stringify(fields)).not.toContain('refresh-token'); + }); + + test('drops unsafe custom names', () => { + const fields = getExternalMcpErrorLogFields({ + name: 'client_secret=client-secret', + status: 502, + }); + + expect(fields).toEqual({ + errorClass: 'Object', + statusCode: 502, + }); + expect(JSON.stringify(fields)).not.toContain('client-secret'); + }); + + test('preserves known safe diagnostic reasons without raw messages', () => { + const fields = getExternalMcpErrorLogFields( + new Error('Incompatible auth server: does not support dynamic client registration'), + ); + + expect(fields).toEqual({ + errorClass: 'Error', + reason: 'dynamic_client_registration_unsupported', + }); + expect(JSON.stringify(fields)).not.toContain('Incompatible auth server'); + }); + + test('finds allowlisted OAuth codes anywhere in a message', () => { + const fields = getExternalMcpErrorLogFields( + new Error('Request failed at invalid_grant after token exchange'), + ); + + expect(fields).toEqual({ + errorClass: 'Error', + oauthError: 'invalid_grant', + }); + expect(JSON.stringify(fields)).not.toContain('Request failed'); + }); +}); diff --git a/packages/web/src/ee/features/mcp/externalMcpError.ts b/packages/web/src/ee/features/mcp/externalMcpError.ts new file mode 100644 index 000000000..4894a317d --- /dev/null +++ b/packages/web/src/ee/features/mcp/externalMcpError.ts @@ -0,0 +1,174 @@ +interface SafeExternalMcpErrorFields { + errorClass: string; + errorName?: string; + oauthError?: string; + reason?: string; + statusCode?: number; +} + +const OAUTH_ERROR_CODES = new Set([ + 'invalid_request', + 'invalid_client', + 'invalid_grant', + 'unauthorized_client', + 'unsupported_grant_type', + 'invalid_scope', + 'server_error', + 'temporarily_unavailable', +]); + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} + +function safeIdentifier(value: unknown): string | undefined { + if (typeof value !== 'string') { + return undefined; + } + + if (!/^[A-Za-z0-9_.:-]{1,80}$/.test(value)) { + return undefined; + } + + return value; +} + +function numericStatus(value: unknown): number | undefined { + if (typeof value !== 'number' || !Number.isInteger(value)) { + return undefined; + } + + if (value < 100 || value > 599) { + return undefined; + } + + return value; +} + +function getStatusCode(error: unknown): number | undefined { + if (!isRecord(error)) { + return undefined; + } + + return numericStatus(error.statusCode) + ?? numericStatus(error.status) + ?? (isRecord(error.response) ? numericStatus(error.response.status) : undefined); +} + +function safeOAuthErrorCode(value: unknown): string | undefined { + const identifier = safeIdentifier(value); + if (!identifier) { + return undefined; + } + + const normalized = identifier.toLowerCase(); + return OAUTH_ERROR_CODES.has(normalized) ? normalized : undefined; +} + +function getErrorMessage(error: unknown): string | undefined { + if (error instanceof Error) { + return error.message; + } + + return isRecord(error) && typeof error.message === 'string' ? error.message : undefined; +} + +function getConstructorOAuthErrorCode(error: unknown): string | undefined { + if (!isRecord(error)) { + return undefined; + } + + const constructor = error.constructor; + if (!isRecord(constructor)) { + return undefined; + } + + return safeOAuthErrorCode(constructor.errorCode); +} + +function getBodyOAuthErrorCode(body: unknown): string | undefined { + if (typeof body !== 'string' || body.length > 4096) { + return undefined; + } + + try { + const parsed = JSON.parse(body); + return isRecord(parsed) ? safeOAuthErrorCode(parsed.error) : undefined; + } catch { + return undefined; + } +} + +function getMessageOAuthErrorCode(error: unknown): string | undefined { + const tokens = getErrorMessage(error)?.match(/\b[a-z_]{3,40}\b/g); + return tokens?.find((token) => OAUTH_ERROR_CODES.has(token)); +} + +function getOAuthErrorCode(error: unknown): string | undefined { + if (!isRecord(error)) { + return undefined; + } + + return safeOAuthErrorCode(error.error) + ?? safeOAuthErrorCode(error.code) + ?? safeOAuthErrorCode(error.errorCode) + ?? getConstructorOAuthErrorCode(error) + ?? getBodyOAuthErrorCode(error.body) + ?? (isRecord(error.response) ? getBodyOAuthErrorCode(error.response.body) : undefined) + ?? getMessageOAuthErrorCode(error); +} + +function getSafeReason(error: unknown): string | undefined { + const message = getErrorMessage(error)?.toLowerCase(); + if (!message) { + return undefined; + } + + if (message.includes('does not support dynamic client registration')) { + return 'dynamic_client_registration_unsupported'; + } + if (message.includes('does not support grant type')) { + return 'unsupported_grant_type'; + } + if (message.includes('does not support response type')) { + return 'unsupported_response_type'; + } + if (message.includes('does not support code challenge method') || message.includes('does not support s256 code challenge')) { + return 'unsupported_code_challenge_method'; + } + if (message.includes('oauth state parameter mismatch')) { + return 'oauth_state_mismatch'; + } + if (message.includes('oauth client information must be saveable') || message.includes('existing oauth client information is required')) { + return 'missing_oauth_client_information'; + } + + return undefined; +} + +/** + * Returns log-safe metadata for errors thrown by external MCP/OAuth libraries. + * + * Do not log raw error objects, messages, stacks, response bodies, request bodies, + * or causes from these boundaries. A malicious or misconfigured provider can echo + * client secrets or tokens into OAuth error bodies. + */ +export function getExternalMcpErrorLogFields(error: unknown): SafeExternalMcpErrorFields { + const errorClass = error instanceof Error + ? safeIdentifier(error.constructor.name) ?? 'Error' + : safeIdentifier(isRecord(error) ? error.constructor?.name : undefined) ?? 'UnknownExternalMcpError'; + const errorName = error instanceof Error + ? safeIdentifier(error.name) + : safeIdentifier(isRecord(error) ? error.name : undefined); + const oauthError = getOAuthErrorCode(error); + const reason = getSafeReason(error); + const statusCode = getStatusCode(error); + + return { + errorClass, + ...(errorName && errorName !== errorClass ? { errorName } : {}), + ...(oauthError ? { oauthError } : {}), + ...(reason ? { reason } : {}), + ...(statusCode ? { statusCode } : {}), + }; +} diff --git a/packages/web/src/ee/features/mcp/hooks/useConnectMcp.ts b/packages/web/src/ee/features/mcp/hooks/useConnectMcp.ts new file mode 100644 index 000000000..184a0d047 --- /dev/null +++ b/packages/web/src/ee/features/mcp/hooks/useConnectMcp.ts @@ -0,0 +1,39 @@ +'use client'; + +import { useState } from 'react'; +import { useToast } from '@/components/hooks/use-toast'; +import { useQueryClient } from '@tanstack/react-query'; +import { connectMcpToAsk } from '@/app/api/(client)/client'; +import { invalidateMcpConfigurationQueries } from '@/ee/features/mcp/queryKeys'; +import { isServiceError } from '@/lib/utils'; + +export function useConnectMcp() { + const [loadingServerId, setLoadingServerId] = useState(null); + const { toast } = useToast(); + const queryClient = useQueryClient(); + + const connect = async (serverId: string) => { + setLoadingServerId(serverId); + const result = await connectMcpToAsk({ serverId }); + + if (isServiceError(result)) { + toast({ + description: `Failed to connect MCP server. ${result.message}`, + }); + setLoadingServerId(null); + return; + } + + if (result.authorizationUrl) { + window.location.href = result.authorizationUrl; + } else { + toast({ + description: 'MCP server is already connected.', + }); + await invalidateMcpConfigurationQueries(queryClient); + setLoadingServerId(null); + } + }; + + return { connect, loadingServerId }; +} diff --git a/packages/web/src/ee/features/mcp/mcpClientFactory.test.ts b/packages/web/src/ee/features/mcp/mcpClientFactory.test.ts new file mode 100644 index 000000000..9d8f999e6 --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpClientFactory.test.ts @@ -0,0 +1,130 @@ +import { expect, test, describe, vi } from 'vitest'; +import { prisma } from '@/__mocks__/prisma'; +import type { OAuthTokens } from '@ai-sdk/mcp'; + +// --- Mocks --- + +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), + env: { AUTH_URL: 'http://localhost:3000' }, + decryptOAuthToken: vi.fn((s: string) => s), +})); + +vi.mock('server-only', () => ({ default: vi.fn() })); + +vi.mock('@/features/mcp/prismaOAuthClientProvider', () => ({ + PrismaOAuthClientProvider: vi.fn(), +})); + +vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({ + StreamableHTTPClientTransport: vi.fn(), +})); + +// Import after mocks are set up +const { isTokenExpiredWithNoRefresh, getConnectedMcpClients } = await import('./mcpClientFactory'); + +// --- Helpers --- + +const PAST = new Date('2020-01-01'); +const FUTURE = new Date('2099-01-01'); + +const TOKEN_NO_REFRESH: OAuthTokens = { access_token: 'tok', token_type: 'Bearer' }; +const TOKEN_WITH_REFRESH: OAuthTokens = { access_token: 'tok', token_type: 'Bearer', refresh_token: 'ref' }; + +function makeUserServer(overrides: { + tokens?: OAuthTokens; + tokensExpiresAt?: Date | null; + orgId?: number; +}) { + return { + serverId: 'srv-1', + userId: 'user-1', + tokens: JSON.stringify(overrides.tokens ?? TOKEN_NO_REFRESH), + tokensExpiresAt: overrides.tokensExpiresAt ?? null, + server: { + orgId: overrides.orgId ?? 1, + name: 'MyServer', + sanitizedName: 'myserver', + serverUrl: 'https://example.com/mcp', + }, + }; +} + +// --- isTokenExpiredWithNoRefresh --- + +describe('isTokenExpiredWithNoRefresh', () => { + test('returns true when access token is expired and no refresh token', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_NO_REFRESH, PAST)).toBe(true); + }); + + test('returns false when refresh_token is present even if access token is expired', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_WITH_REFRESH, PAST)).toBe(false); + }); + + test('returns false when tokensExpiresAt is null', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_NO_REFRESH, null)).toBe(false); + }); + + test('returns false when access token has not yet expired', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_NO_REFRESH, FUTURE)).toBe(false); + }); +}); + +// --- getConnectedMcpClients --- + +describe('getConnectedMcpClients', () => { + test('skips server when access token expired and no refresh token', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ tokens: TOKEN_NO_REFRESH, tokensExpiresAt: PAST }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result).toHaveLength(0); + }); + + test('includes server when refresh_token present even if access token expired', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ tokens: TOKEN_WITH_REFRESH, tokensExpiresAt: PAST }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result).toHaveLength(1); + }); + + test('includes server when tokensExpiresAt is null', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ tokensExpiresAt: null }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result).toHaveLength(1); + }); + + test('skips server belonging to a different org', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ orgId: 999 }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result).toHaveLength(0); + }); + + test('returns server metadata from the user MCP server row', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ tokens: TOKEN_WITH_REFRESH }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result[0]).toMatchObject({ + serverId: 'srv-1', + serverName: 'MyServer', + sanitizedName: 'myserver', + serverUrl: 'https://example.com/mcp', + }); + }); +}); diff --git a/packages/web/src/ee/features/mcp/mcpClientFactory.ts b/packages/web/src/ee/features/mcp/mcpClientFactory.ts new file mode 100644 index 000000000..996969529 --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpClientFactory.ts @@ -0,0 +1,115 @@ +import { createLogger, env, decryptOAuthToken } from '@sourcebot/shared'; +import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import type { OAuthTokens } from '@ai-sdk/mcp'; +import type { PrismaClient } from '@sourcebot/db'; +import { getExternalMcpErrorLogFields } from './externalMcpError'; + +const logger = createLogger('mcp-client-factory'); + +export interface McpToolSet { + serverId: string; + serverName: string; + sanitizedName: string; + serverUrl: string; + transport: StreamableHTTPClientTransport; +} + +/** + * Returns true if the access token is definitely expired and there is no refresh token to fall back on. + */ +export function isTokenExpiredWithNoRefresh(tokens: OAuthTokens, tokensExpiresAt: Date | null): boolean { + if (tokens.refresh_token) { + return false; + } + if (!tokensExpiresAt) { + return false; + } + return new Date() > tokensExpiresAt; +} + +/** + * Creates authenticated transports for all external MCP servers the user has valid credentials for. + * Skips servers with clearly expired tokens and no refresh token. + * Does NOT connect — connection is deferred to createMCPClient. + */ +export async function getConnectedMcpClients(prisma: PrismaClient, userId: string, orgId: number): Promise { + const userServers = await prisma.userMcpServer.findMany({ + where: { + userId, + tokens: { not: null }, + server: { + orgId, + clientInfo: { not: null }, + }, + }, + select: { + serverId: true, + tokens: true, + tokensExpiresAt: true, + server: { + select: { + orgId: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }, + }, + }); + + const clients: McpToolSet[] = []; + + for (const userServer of userServers) { + // Skip servers that don't belong to the current org. + if (userServer.server.orgId !== orgId) { + continue; + } + + const serverName = userServer.server.name; + + try { + const decrypted = decryptOAuthToken(userServer.tokens); + if (!decrypted) { + logger.warn(`Could not decrypt tokens for MCP server ${serverName}, skipping.`); + continue; + } + + const tokens: OAuthTokens = JSON.parse(decrypted); + + if (isTokenExpiredWithNoRefresh(tokens, userServer.tokensExpiresAt)) { + logger.warn(`Access token for MCP server ${serverName} is expired and has no refresh token. User ${userId} needs to re-authorize.`); + continue; + } + + const provider = new PrismaOAuthClientProvider({ + prisma, + serverId: userServer.serverId, + orgId, + userId, + callbackUrl: `${env.AUTH_URL}/api/ee/askmcp/callback`, + }); + + const transport = new StreamableHTTPClientTransport( + new URL(userServer.server.serverUrl), + { authProvider: provider }, + ); + + clients.push({ + serverId: userServer.serverId, + serverName, + sanitizedName: userServer.server.sanitizedName, + serverUrl: userServer.server.serverUrl, + transport, + }); + } catch (error) { + logger.error('Failed to prepare MCP server transport.', { + serverId: userServer.serverId, + sanitizedName: userServer.server.sanitizedName, + error: getExternalMcpErrorLogFields(error), + }); + } + } + + return clients; +} diff --git a/packages/web/src/ee/features/mcp/mcpToolRegistry.test.ts b/packages/web/src/ee/features/mcp/mcpToolRegistry.test.ts new file mode 100644 index 000000000..20918f066 --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpToolRegistry.test.ts @@ -0,0 +1,185 @@ +import { expect, test, describe } from 'vitest'; +import { buildMcpToolRegistry, searchMcpTools, McpToolRegistryEntry } from './mcpToolRegistry'; + +// Helper to create a mock tool record matching the MCPClient['tools'] return type. +function createToolRecord(tools: Record) { + const record: Record = {}; + for (const [name, tool] of Object.entries(tools)) { + record[name] = { + description: tool.description, + execute: tool.execute ?? (() => {}), + inputSchema: {}, + }; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return record as any; +} + +describe('buildMcpToolRegistry', () => { + test('extracts serverName from namespaced tool name', () => { + const tools = createToolRecord({ + 'mcp_linear__list_issues': { description: 'List issues' }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry).toEqual([ + { name: 'mcp_linear__list_issues', description: 'List issues', serverName: 'linear' }, + ]); + }); + + test('handles underscores in server name', () => { + const tools = createToolRecord({ + 'mcp_my_server__get_data': { description: 'Get data' }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry[0].serverName).toBe('my_server'); + }); + + test('defaults missing description to empty string', () => { + const tools = createToolRecord({ + 'mcp_linear__list_issues': { description: undefined }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry[0].description).toBe(''); + }); + + test('non-matching tool name yields empty serverName', () => { + const tools = createToolRecord({ + 'some_random_tool': { description: 'A tool' }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry[0].serverName).toBe(''); + }); + + test('empty tools record returns empty array', () => { + const registry = buildMcpToolRegistry(createToolRecord({})); + + expect(registry).toEqual([]); + }); +}); + +describe('searchMcpTools', () => { + // Shared registry for most tests. + const registry: McpToolRegistryEntry[] = [ + { name: 'mcp_linear__list_issues', description: 'List all issues in a project', serverName: 'linear' }, + { name: 'mcp_linear__create_issue', description: 'Create a new issue', serverName: 'linear' }, + { name: 'mcp_linear__update_issue', description: 'Update an existing issue', serverName: 'linear' }, + { name: 'mcp_github__search_repos', description: 'Search repositories on GitHub', serverName: 'github' }, + { name: 'mcp_pg__run_query', description: 'Run a database query', serverName: 'pg' }, + { name: 'mcp_slack__send_message', description: 'Send a message to a Slack channel', serverName: 'slack' }, + { name: 'mcp_jira__create_ticket', description: 'Create a new Jira ticket', serverName: 'jira' }, + ]; + + test('exact name match returns single result', () => { + const results = searchMcpTools('mcp_linear__list_issues', registry); + + expect(results).toEqual([ + { name: 'mcp_linear__list_issues', description: 'List all issues in a project', serverName: 'linear' }, + ]); + }); + + test('token matching on tool name', () => { + const results = searchMcpTools('list issues', registry); + + expect(results.length).toBeGreaterThan(0); + expect(results[0].name).toBe('mcp_linear__list_issues'); + }); + + test('synonym expansion: "find" matches tools with "list"', () => { + const results = searchMcpTools('find issues', registry); + + expect(results.length).toBeGreaterThan(0); + const names = results.map(r => r.name); + expect(names).toContain('mcp_linear__list_issues'); + }); + + test('synonym expansion: "add" matches tools with "create"', () => { + const results = searchMcpTools('add ticket', registry); + + expect(results.length).toBeGreaterThan(0); + const names = results.map(r => r.name); + expect(names).toContain('mcp_jira__create_ticket'); + }); + + test('reverse expansion: canonical "list" expands to synonyms', () => { + // "list" is canonical and expands to "find", "get", "fetch", "search", etc. + const results = searchMcpTools('list repos', registry); + + expect(results.length).toBeGreaterThan(0); + const names = results.map(r => r.name); + // "search_repos" should match because "list" expands to "search" + expect(names).toContain('mcp_github__search_repos'); + }); + + test('higher-scoring entries come first', () => { + // "create issue" should score higher for create_issue than for list_issues + const results = searchMcpTools('create issue', registry); + + expect(results.length).toBeGreaterThan(1); + // The first result should be the one that matches both tokens + expect(results[0].name).toBe('mcp_linear__create_issue'); + }); + + test('topK limits results', () => { + const results = searchMcpTools('issue', registry, 2); + + expect(results.length).toBeLessThanOrEqual(2); + }); + + test('default topK is 5', () => { + // All 7 entries match "mcp" as a substring, but we need tokens > 2 chars + // Use a query that matches many entries + const largeRegistry: McpToolRegistryEntry[] = Array.from({ length: 10 }, (_, i) => ({ + name: `mcp_server__tool_${i}`, + description: `Tool number ${i} for testing`, + serverName: 'server', + })); + + const results = searchMcpTools('tool testing', largeRegistry); + + expect(results.length).toBeLessThanOrEqual(5); + }); + + test('short/empty query fallback returns first topK entries', () => { + // "do it" — all tokens are <= 2 chars after filtering + const results = searchMcpTools('do it', registry); + + expect(results).toEqual(registry.slice(0, 5)); + }); + + test('empty string query fallback returns first topK entries', () => { + const results = searchMcpTools('', registry); + + expect(results).toEqual(registry.slice(0, 5)); + }); + + test('returns empty array when no tokens match', () => { + const results = searchMcpTools('xyznonexistent', registry); + + expect(results).toEqual([]); + }); + + test('search matches in description, not just name', () => { + const results = searchMcpTools('database', registry); + + expect(results.length).toBeGreaterThan(0); + expect(results[0].name).toBe('mcp_pg__run_query'); + }); + + test('tokens shorter than 3 chars are filtered out', () => { + // "do a list" → only "list" survives (length > 2) + const results = searchMcpTools('do a list', registry); + + expect(results.length).toBeGreaterThan(0); + // Should still find results via the "list" token + const names = results.map(r => r.name); + expect(names).toContain('mcp_linear__list_issues'); + }); +}); diff --git a/packages/web/src/ee/features/mcp/mcpToolRegistry.ts b/packages/web/src/ee/features/mcp/mcpToolRegistry.ts new file mode 100644 index 000000000..431710e9e --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpToolRegistry.ts @@ -0,0 +1,99 @@ +import type { MCPClient } from '@ai-sdk/mcp'; + +export interface McpToolRegistryEntry { + name: string; + description: string; + serverName: string; +} + +type McpToolRecord = Awaited>; + +// Synonym map for common action words. Expands query tokens so that e.g. +// "find tickets" matches a tool named "list_issues". +// Module-level constant — built once at server startup, never re-created. +const SYNONYM_MAP: Record = { + list: ['find', 'get', 'fetch', 'retrieve', 'search', 'show', 'query', 'read'], + create: ['make', 'add', 'post', 'open', 'new', 'submit', 'write'], + update: ['edit', 'modify', 'change', 'patch', 'set'], + delete: ['remove', 'destroy', 'archive', 'close'], + send: ['post', 'publish', 'notify', 'message'], + issue: ['ticket', 'bug', 'task', 'item', 'work'], + comment: ['note', 'reply', 'respond'], + user: ['member', 'person', 'assignee'], + project: ['repo', 'repository', 'workspace'], +}; + +// Reverse lookup: synonym → canonical token. Built once from SYNONYM_MAP. +const REVERSE_SYNONYMS: Record = {}; +for (const [canonical, synonyms] of Object.entries(SYNONYM_MAP)) { + for (const synonym of synonyms) { + REVERSE_SYNONYMS[synonym] = canonical; + } +} + +function expandTokens(tokens: string[]): string[] { + const expanded = new Set(tokens); + for (const token of tokens) { + const canonical = REVERSE_SYNONYMS[token]; + if (canonical) { + expanded.add(canonical); + } + const synonyms = SYNONYM_MAP[token]; + if (synonyms) { + for (const s of synonyms) { + expanded.add(s); + } + } + } + return Array.from(expanded); +} + +export function buildMcpToolRegistry(tools: McpToolRecord): McpToolRegistryEntry[] { + return Object.entries(tools).map(([name, tool]) => { + const match = name.match(/^mcp_(.+?)__/); + const serverName = match ? match[1] : ''; + return { + name, + description: tool.description ?? '', + serverName, + }; + }); +} + +export function searchMcpTools( + query: string, + registry: McpToolRegistryEntry[], + topK = 5, +): McpToolRegistryEntry[] { + // Fast path: if the query is an exact tool name, return it directly. + const exactMatch = registry.find(e => e.name === query); + if (exactMatch) { + return [exactMatch]; + } + + const rawTokens = query + .toLowerCase() + .split(/\W+/) + .filter(t => t.length > 2); + + // If no meaningful tokens remain (e.g. query is "do it" — all tokens <= 2 chars), + // fall back to returning the first topK tools rather than returning nothing. + // We could potentially return nothing or return another tool that will help search better + // in the future. + if (rawTokens.length === 0) { + return registry.slice(0, topK); + } + + const tokens = expandTokens(rawTokens); + + return registry + .map(entry => { + const haystack = `${entry.name} ${entry.description}`.toLowerCase(); + const score = tokens.filter(t => haystack.includes(t)).length; + return { entry, score }; + }) + .filter(({ score }) => score > 0) + .sort((a, b) => b.score - a.score) + .slice(0, topK) + .map(({ entry }) => entry); +} \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/mcpToolSets.test.ts b/packages/web/src/ee/features/mcp/mcpToolSets.test.ts new file mode 100644 index 000000000..ebbdbfacc --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpToolSets.test.ts @@ -0,0 +1,303 @@ +import { expect, test, describe, vi, beforeEach } from 'vitest'; +import type { McpToolSet } from './mcpClientFactory'; + +// --- Mocks --- + +const mockCreateMCPClient = vi.fn(); +const mockLogger = vi.hoisted(() => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), +})); + +vi.mock('@ai-sdk/mcp', () => ({ + createMCPClient: (...args: unknown[]) => mockCreateMCPClient(...args), +})); + +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => mockLogger, + env: { + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: 5000, + }, +})); + +vi.mock('ai', () => ({ + jsonSchema: vi.fn((schema: unknown, opts: unknown) => ({ schema, ...(opts as object) })), +})); + +// --- Helpers --- + +interface MockToolDef { + name: string; + description?: string; + inputSchema?: Record; + annotations?: Record; +} + +function createMockMcpClient(toolDefs: MockToolDef[]) { + const toolRecord: Record; description: string | undefined; inputSchema: unknown }> = {}; + for (const def of toolDefs) { + toolRecord[def.name] = { + execute: vi.fn().mockResolvedValue({ content: [{ type: 'text', text: 'result' }] }), + description: def.description, + inputSchema: def.inputSchema ?? {}, + }; + } + + return { + listTools: vi.fn().mockResolvedValue({ tools: toolDefs }), + toolsFromDefinitions: vi.fn().mockReturnValue(toolRecord), + close: vi.fn().mockResolvedValue(undefined), + tools: vi.fn().mockResolvedValue(toolRecord), + }; +} + +function createMockClient(overrides: Partial & { serverName: string }): McpToolSet { + return { + serverId: 'server-id', + sanitizedName: overrides.serverName.toLowerCase(), + serverUrl: `https://${overrides.serverName.toLowerCase()}.example.com/mcp`, + transport: {} as McpToolSet['transport'], + ...overrides, + }; +} + +// --- Tests --- + +// Import after mocks are set up +const { getMcpTools } = await import('./mcpToolSets'); + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe('getMcpTools', () => { + test('single server with single tool produces correctly namespaced key', async () => { + const mockClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + expect(Object.keys(result.tools)).toEqual(['mcp_linear__list_issues']); + expect(result.failedServers).toEqual([]); + }); + + test('multiple servers produce tools with distinct prefixes', async () => { + const linearClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues' }, + ]); + const githubClient = createMockMcpClient([ + { name: 'search_repos', description: 'Search repos' }, + ]); + + mockCreateMCPClient + .mockResolvedValueOnce(linearClient) + .mockResolvedValueOnce(githubClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + createMockClient({ serverName: 'GitHub' }), + ]); + + const toolNames = Object.keys(result.tools); + expect(toolNames).toContain('mcp_linear__list_issues'); + expect(toolNames).toContain('mcp_github__search_repos'); + }); + + test('read-only tool does NOT get needsApproval', async () => { + const mockClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues', annotations: { readOnlyHint: true } }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__list_issues']; + expect(tool).toBeDefined(); + expect('needsApproval' in tool).toBe(false); + }); + + test('non-read-only tool gets needsApproval: true', async () => { + const mockClient = createMockMcpClient([ + { name: 'create_issue', description: 'Create issue' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + expect(tool).toBeDefined(); + expect(tool).toHaveProperty('needsApproval', true); + }); + + test('failed server connection adds to failedServers array', async () => { + const error = new Error('Connection refused client_secret=client-secret access_token=access-token'); + Object.assign(error, { + response: { + status: 502, + body: 'client_secret=client-secret access_token=access-token', + }, + }); + mockCreateMCPClient.mockRejectedValue(error); + + const result = await getMcpTools([ + createMockClient({ serverName: 'BrokenServer' }), + ]); + + expect(result.failedServers).toEqual(['BrokenServer']); + expect(Object.keys(result.tools)).toEqual([]); + expect(mockLogger.error).toHaveBeenCalledWith('Failed to get tools from MCP server.', { + serverId: 'server-id', + sanitizedName: 'brokenserver', + error: { + errorClass: 'Error', + statusCode: 502, + }, + }); + expect(JSON.stringify(mockLogger.error.mock.calls)).not.toContain('client-secret'); + expect(JSON.stringify(mockLogger.error.mock.calls)).not.toContain('access-token'); + }); + + test('failed server does not prevent other servers from working', async () => { + const goodClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues' }, + ]); + + mockCreateMCPClient + .mockRejectedValueOnce(new Error('Connection refused')) + .mockResolvedValueOnce(goodClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'BrokenServer' }), + createMockClient({ serverName: 'Linear' }), + ]); + + expect(result.failedServers).toEqual(['BrokenServer']); + expect(Object.keys(result.tools)).toEqual(['mcp_linear__list_issues']); + }); + + test('generates favicon URL from server URL origin', async () => { + const mockClient = createMockMcpClient([ + { name: 'tool', description: 'A tool' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear', serverUrl: 'https://api.linear.app/mcp' }), + ]); + + expect(result.serverFaviconUrls['linear']).toBe( + 'https://www.google.com/s2/favicons?domain=https://api.linear.app&sz=32' + ); + }); + + test('cleanup function calls close on all clients', async () => { + const client1 = createMockMcpClient([{ name: 'tool1', description: 'Tool 1' }]); + const client2 = createMockMcpClient([{ name: 'tool2', description: 'Tool 2' }]); + + mockCreateMCPClient + .mockResolvedValueOnce(client1) + .mockResolvedValueOnce(client2); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Server1' }), + createMockClient({ serverName: 'Server2' }), + ]); + + await result.cleanup(); + + expect(client1.close).toHaveBeenCalledOnce(); + expect(client2.close).toHaveBeenCalledOnce(); + }); + + test('cleanup handles errors in close gracefully', async () => { + const client1 = createMockMcpClient([{ name: 'tool1', description: 'Tool 1' }]); + const client2 = createMockMcpClient([{ name: 'tool2', description: 'Tool 2' }]); + client1.close.mockRejectedValue(new Error('Close failed')); + + mockCreateMCPClient + .mockResolvedValueOnce(client1) + .mockResolvedValueOnce(client2); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Server1' }), + createMockClient({ serverName: 'Server2' }), + ]); + + // Should not throw + await expect(result.cleanup()).resolves.toBeUndefined(); + expect(client2.close).toHaveBeenCalledOnce(); + }); + + test('empty clients array returns empty result', async () => { + const result = await getMcpTools([]); + + expect(result.tools).toEqual({}); + expect(result.failedServers).toEqual([]); + expect(result.serverFaviconUrls).toEqual({}); + expect(typeof result.cleanup).toBe('function'); + }); + + test('tool schema validation rejects invalid input', async () => { + const mockClient = createMockMcpClient([ + { + name: 'create_issue', + description: 'Create issue', + inputSchema: { + type: 'object', + properties: { title: { type: 'string' } }, + }, + }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + // The inputSchema should have a validate function from our jsonSchema mock + const schema = tool.inputSchema as { validate?: (value: unknown) => Promise<{ success: boolean; error?: Error }> }; + expect(schema.validate).toBeDefined(); + + if (schema.validate) { + // Valid input + const validResult = await schema.validate({ title: 'My Issue' }); + expect(validResult.success).toBe(true); + + // Invalid input (extra property not allowed because additionalProperties: false) + const invalidResult = await schema.validate({ title: 'My Issue', bogus: 'field' }); + expect(invalidResult.success).toBe(false); + } + }); + + test('tool execute wrapper propagates non-timeout errors', async () => { + const originalError = new Error('External API failed'); + const mockClient = createMockMcpClient([ + { name: 'create_issue', description: 'Create issue' }, + ]); + // Override the execute to reject + const toolRecord = mockClient.toolsFromDefinitions(); + toolRecord['create_issue'].execute.mockRejectedValue(originalError); + + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + await expect( + tool.execute({}, { messages: [], toolCallId: 'test' }) + ).rejects.toThrow('External API failed'); + }); +}); diff --git a/packages/web/src/ee/features/mcp/mcpToolSets.ts b/packages/web/src/ee/features/mcp/mcpToolSets.ts new file mode 100644 index 000000000..febae502c --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpToolSets.ts @@ -0,0 +1,157 @@ +import { createMCPClient, type MCPClient } from '@ai-sdk/mcp'; +import { McpToolSet } from './mcpClientFactory'; +import { createLogger, env } from '@sourcebot/shared'; +import Ajv from 'ajv'; +import { jsonSchema, ToolExecutionOptions } from 'ai'; +import type { JSONSchema7, JSONSchema7Definition } from 'json-schema'; +import { getExternalMcpErrorLogFields } from './externalMcpError'; +import { getMcpFaviconUrl } from './utils'; + +const logger = createLogger('mcp-tool-sets'); +const ajv = new Ajv({ allErrors: true, strict: false }); + +class McpToolTimeoutError extends Error { + constructor(toolName: string, timeoutMs: number) { + super(`MCP tool "${toolName}" timed out after ${timeoutMs}ms`); + this.name = 'McpToolTimeoutError'; + } +} + +export interface McpToolsResult { + tools: Record>[string]>; + failedServers: string[]; + serverFaviconUrls: Record; + cleanup: () => Promise; +} + +/** + * Creates MCPClients from authenticated transports, retrieves their tools, + * and returns a namespaced tool record + cleanup function. + */ +export async function getMcpTools(clients: McpToolSet[]): Promise { + const allTools: McpToolsResult['tools'] = {}; + const failedServers: string[] = []; + const serverFaviconUrls: Record = {}; + const mcpClients: MCPClient[] = []; + + const connectionTimeoutMs = env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS; + + for (const { serverId, serverName, sanitizedName, serverUrl, transport } of clients) { + try { + const mcpClient = await Promise.race([ + createMCPClient({ transport }), + new Promise((_, reject) => + setTimeout(() => reject(new Error(`Connection to MCP server "${serverName}" timed out after ${connectionTimeoutMs}ms`)), connectionTimeoutMs) + ), + ]); + mcpClients.push(mcpClient); + + const toolDefinitions = await Promise.race([ + mcpClient.listTools(), + new Promise((_, reject) => + setTimeout(() => reject(new Error(`Listing tools from MCP server "${serverName}" timed out after ${connectionTimeoutMs}ms`)), connectionTimeoutMs) + ), + ]); + const tools = mcpClient.toolsFromDefinitions(toolDefinitions); + const prefix = `mcp_${sanitizedName}`; + + for (const [toolName, tool] of Object.entries(tools)) { + const def = toolDefinitions.tools.find(t => t.name === toolName); + const isReadOnly = (def?.annotations as Record | undefined)?.readOnlyHint === true; + + // The @ai-sdk/mcp library sets additionalProperties: false in the JSON schema + // sent to the model, but does NOT provide a validate function — so the AI SDK + // skips server-side validation entirely. We compile the schema with ajv to + // enforce parameter names at runtime, which allows experimental_repairToolCall + // to fire on InvalidToolInputError. + const rawSchema = def?.inputSchema ?? { type: 'object', properties: {} }; + const schema = { + ...rawSchema, + type: 'object' as const, + properties: (rawSchema.properties ?? {}) as Record, + additionalProperties: false, + } satisfies JSONSchema7; + const validate = ajv.compile(schema); + const validProperties = Object.keys(schema.properties); + const validatedInputSchema = jsonSchema(schema, { + validate: async (value: unknown) => { + if (validate(value)) { + return { success: true as const, value }; + } + return { + success: false as const, + error: new Error( + `${ajv.errorsText(validate.errors)}. The valid parameter names for this tool are: [${validProperties.join(', ')}]` + ), + }; + }, + }); + + const originalExecute = tool.execute; + const qualifiedName = `${prefix}__${toolName}`; + const timeoutMs = env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS; + + const executeWithTimeout = (async (input: unknown, options: ToolExecutionOptions) => { + const timeoutSignal = AbortSignal.timeout(timeoutMs); + const combinedSignal = options.abortSignal + ? AbortSignal.any([options.abortSignal, timeoutSignal]) + : timeoutSignal; + + try { + return await originalExecute(input, { + ...options, + abortSignal: combinedSignal, + }); + } catch (error) { + if (timeoutSignal.aborted) { + logger.warn(`MCP tool "${qualifiedName}" timed out after ${timeoutMs}ms`); + throw new McpToolTimeoutError(qualifiedName, timeoutMs); + } + throw error; + } + }) as typeof originalExecute; + + allTools[qualifiedName] = { + ...tool, + execute: executeWithTimeout, + // The @ai-sdk/mcp package bundles its own copy of @ai-sdk/provider-utils, + // so its Schema isn't structurally identical to the workspace copy. + // The runtime shape is the same; cast through `any` to bridge the duplicate + // type identity (the two FlexibleSchema types differ only by their internal + // schemaSymbol brand). + // eslint-disable-next-line @typescript-eslint/no-explicit-any + inputSchema: validatedInputSchema as any, + ...(isReadOnly ? {} : { needsApproval: true }), + }; + } + + const faviconUrl = getMcpFaviconUrl(serverUrl, serverName); + if (faviconUrl) { + serverFaviconUrls[sanitizedName] = faviconUrl; + } + } catch (error) { + logger.error('Failed to get tools from MCP server.', { + serverId, + sanitizedName, + error: getExternalMcpErrorLogFields(error), + }); + failedServers.push(serverName); + } + } + + const cleanup = async () => { + await Promise.allSettled( + mcpClients.map(async (client) => { + try { + await client.close(); + } catch (error) { + logger.error('Error closing MCP client.', { + error: getExternalMcpErrorLogFields(error), + }); + } + }) + ); + }; + + return { tools: allTools, failedServers, serverFaviconUrls, cleanup }; +} diff --git a/packages/web/src/ee/features/mcp/prefabMcpServers.test.ts b/packages/web/src/ee/features/mcp/prefabMcpServers.test.ts new file mode 100644 index 000000000..18abdb0a1 --- /dev/null +++ b/packages/web/src/ee/features/mcp/prefabMcpServers.test.ts @@ -0,0 +1,50 @@ +import { describe, expect, test } from 'vitest'; +import { + getAvailablePrefabMcpServers, + normalizeMcpServerUrlForComparison, + PREFAB_MCP_SERVERS, +} from './prefabMcpServers'; + +describe('prefab MCP servers', () => { + test('ships the supported prefab servers', () => { + expect(PREFAB_MCP_SERVERS).toEqual([ + { + id: 'atlassian', + name: 'Atlassian', + serverUrl: 'https://mcp.atlassian.com/v1/mcp/authv2', + }, + { + id: 'linear', + name: 'Linear', + serverUrl: 'https://mcp.linear.app/mcp', + }, + { + id: 'slack', + name: 'Slack', + serverUrl: 'https://mcp.slack.com/mcp', + }, + ]); + }); + + test('keeps prefab servers sorted alphabetically by name', () => { + const sortedNames = PREFAB_MCP_SERVERS.map((server) => server.name).sort((a, b) => a.localeCompare(b)); + + expect(PREFAB_MCP_SERVERS.map((server) => server.name)).toEqual(sortedNames); + }); + + test('hides already configured prefab servers after URL normalization', () => { + const availableServers = getAvailablePrefabMcpServers(['https://mcp.slack.com/mcp/']); + + expect(availableServers.map((server) => server.id)).toEqual(['atlassian', 'linear']); + }); + + test('hides the Atlassian prefab entry when the shared endpoint is configured', () => { + const availableServers = getAvailablePrefabMcpServers(['https://mcp.atlassian.com/v1/mcp/authv2/']); + + expect(availableServers.map((server) => server.id)).toEqual(['linear', 'slack']); + }); + + test('normalizes server URLs for duplicate comparisons', () => { + expect(normalizeMcpServerUrlForComparison(' HTTPS://MCP.SLACK.COM/mcp/#connect ')).toBe('https://mcp.slack.com/mcp'); + }); +}); diff --git a/packages/web/src/ee/features/mcp/prefabMcpServers.ts b/packages/web/src/ee/features/mcp/prefabMcpServers.ts new file mode 100644 index 000000000..22b60bd16 --- /dev/null +++ b/packages/web/src/ee/features/mcp/prefabMcpServers.ts @@ -0,0 +1,47 @@ +export interface PrefabMcpServer { + id: string; + name: string; + serverUrl: string; +} + +const prefabMcpServers = [ + { + id: "atlassian", + name: "Atlassian", + serverUrl: "https://mcp.atlassian.com/v1/mcp/authv2", + }, + { + id: "linear", + name: "Linear", + serverUrl: "https://mcp.linear.app/mcp", + }, + { + id: "slack", + name: "Slack", + serverUrl: "https://mcp.slack.com/mcp", + }, +] satisfies PrefabMcpServer[]; + +export const PREFAB_MCP_SERVERS = [...prefabMcpServers].sort((a, b) => a.name.localeCompare(b.name)); + +export function normalizeMcpServerUrlForComparison(serverUrl: string): string { + const trimmedServerUrl = serverUrl.trim(); + + try { + const url = new URL(trimmedServerUrl); + url.hash = ""; + return url.toString().replace(/\/$/, ""); + } catch { + return trimmedServerUrl.toLowerCase().replace(/\/$/, ""); + } +} + +export function getAvailablePrefabMcpServers(configuredServerUrls: string[]): PrefabMcpServer[] { + const configuredServerUrlSet = new Set( + configuredServerUrls.map((serverUrl) => normalizeMcpServerUrlForComparison(serverUrl)), + ); + + return PREFAB_MCP_SERVERS.filter((server) => ( + !configuredServerUrlSet.has(normalizeMcpServerUrlForComparison(server.serverUrl)) + )); +} diff --git a/packages/web/src/ee/features/mcp/queryKeys.test.ts b/packages/web/src/ee/features/mcp/queryKeys.test.ts new file mode 100644 index 000000000..f897f486a --- /dev/null +++ b/packages/web/src/ee/features/mcp/queryKeys.test.ts @@ -0,0 +1,16 @@ +import { describe, expect, test, vi } from 'vitest'; +import type { QueryClient } from '@tanstack/react-query'; +import { invalidateMcpConfigurationQueries, mcpQueryKeys } from './queryKeys'; + +describe('invalidateMcpConfigurationQueries', () => { + test('invalidates both admin configuration and account MCP server status', async () => { + const queryClient = { + invalidateQueries: vi.fn().mockResolvedValue(undefined), + } as unknown as QueryClient; + + await invalidateMcpConfigurationQueries(queryClient); + + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ queryKey: mcpQueryKeys.configuration }); + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ queryKey: mcpQueryKeys.serversWithStatus }); + }); +}); diff --git a/packages/web/src/ee/features/mcp/queryKeys.ts b/packages/web/src/ee/features/mcp/queryKeys.ts new file mode 100644 index 000000000..469c9fc04 --- /dev/null +++ b/packages/web/src/ee/features/mcp/queryKeys.ts @@ -0,0 +1,13 @@ +import type { QueryClient } from '@tanstack/react-query'; + +export const mcpQueryKeys = { + serversWithStatus: ['mcpServersWithStatus'] as const, + configuration: ['mcpConfiguration'] as const, +}; + +export async function invalidateMcpConfigurationQueries(queryClient: QueryClient) { + await Promise.all([ + queryClient.invalidateQueries({ queryKey: mcpQueryKeys.configuration }), + queryClient.invalidateQueries({ queryKey: mcpQueryKeys.serversWithStatus }), + ]); +} diff --git a/packages/web/src/ee/features/mcp/types.ts b/packages/web/src/ee/features/mcp/types.ts new file mode 100644 index 000000000..6ddff31e4 --- /dev/null +++ b/packages/web/src/ee/features/mcp/types.ts @@ -0,0 +1,17 @@ +export interface McpConfigurationServer { + id: string; + name: string; + serverUrl: string; + sanitizedName: string; + faviconUrl: string | undefined; + savedConnectionCount: number; +} + +export type McpConfigurationAllowedMode = 'approved_only'; + +export interface GetMcpConfigurationResponse { + servers: McpConfigurationServer[]; + totalSavedConnectionCount: number; + allowedMode: McpConfigurationAllowedMode; + isOAuthAvailable: boolean; +} diff --git a/packages/web/src/ee/features/mcp/utils.test.ts b/packages/web/src/ee/features/mcp/utils.test.ts new file mode 100644 index 000000000..d3c887fc7 --- /dev/null +++ b/packages/web/src/ee/features/mcp/utils.test.ts @@ -0,0 +1,50 @@ +import { expect, test, describe } from 'vitest'; +import { getMcpFaviconUrl, sanitizeMcpServerName } from './utils'; + +describe('sanitizeMcpServerName', () => { + test('lowercases ASCII letters', () => { + expect(sanitizeMcpServerName('MyServer')).toBe('myserver'); + }); + + test('replaces special characters with underscores', () => { + expect(sanitizeMcpServerName('My Server!')).toBe('my_server_'); + }); + + test('preserves digits', () => { + expect(sanitizeMcpServerName('server123')).toBe('server123'); + }); + + test('replaces spaces and hyphens', () => { + expect(sanitizeMcpServerName('my-cool server')).toBe('my_cool_server'); + }); + + test('handles empty string', () => { + expect(sanitizeMcpServerName('')).toBe(''); + }); + + test('replaces unicode characters with underscores', () => { + expect(sanitizeMcpServerName('Ñoño')).toBe('_o_o'); + }); + + test('replaces all special characters', () => { + expect(sanitizeMcpServerName('@#$%')).toBe('____'); + }); + + test('returns already sanitized name unchanged', () => { + expect(sanitizeMcpServerName('linear')).toBe('linear'); + }); +}); + +describe('getMcpFaviconUrl', () => { + test('returns a Google favicon URL for a valid server URL', () => { + expect(getMcpFaviconUrl('https://mcp.linear.app/mcp')).toBe('https://www.google.com/s2/favicons?domain=https://mcp.linear.app&sz=32'); + }); + + test('returns a local Atlassian icon for the Atlassian prefab server', () => { + expect(getMcpFaviconUrl('https://mcp.atlassian.com/v1/mcp/authv2', 'Atlassian')).toMatch(/^data:image\/svg\+xml,/); + }); + + test('returns undefined for a malformed server URL', () => { + expect(getMcpFaviconUrl('not a url')).toBeUndefined(); + }); +}); diff --git a/packages/web/src/ee/features/mcp/utils.ts b/packages/web/src/ee/features/mcp/utils.ts new file mode 100644 index 000000000..3cfd4dfeb --- /dev/null +++ b/packages/web/src/ee/features/mcp/utils.ts @@ -0,0 +1,51 @@ +/** + * Sanitizes an MCP server name into a lowercase alphanumeric string suitable + * for use as a tool-name prefix (e.g. "My Server!" → "my_server_"). + * + * This is used to namespace MCP tools (mcp_{sanitizedName}__{toolName}) and + * to key favicon maps. Must be kept consistent everywhere — collisions on + * this value are prevented at server-creation time. + */ +export function sanitizeMcpServerName(name: string): string { + return name.toLowerCase().replace(/[^a-z0-9]/g, '_'); +} + +function createMcpIconDataUri(svg: string): string { + return `data:image/svg+xml,${encodeURIComponent(svg)}`; +} + +const atlassianIconSvg = ` + + + + + + + + + + + + + +`; + +const knownMcpFaviconUrlsBySanitizedName: Record = { + atlassian: createMcpIconDataUri(atlassianIconSvg), +}; + +export function getMcpFaviconUrl(serverUrl: string, serverName?: string): string | undefined { + if (serverName) { + const knownFaviconUrl = knownMcpFaviconUrlsBySanitizedName[sanitizeMcpServerName(serverName)]; + if (knownFaviconUrl) { + return knownFaviconUrl; + } + } + + try { + const origin = new URL(serverUrl).origin; + return `https://www.google.com/s2/favicons?domain=${origin}&sz=32`; + } catch { + return undefined; + } +} diff --git a/packages/web/src/features/chat/agent.ts b/packages/web/src/features/chat/agent.ts index 0efb706fc..859f21428 100644 --- a/packages/web/src/features/chat/agent.ts +++ b/packages/web/src/features/chat/agent.ts @@ -4,19 +4,29 @@ import { getFileSource } from '@/features/git'; import { isServiceError } from "@/lib/utils"; import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider"; import { ProviderOptions } from "@ai-sdk/provider-utils"; +import type { PrismaClient } from "@sourcebot/db"; import { createLogger, env } from "@sourcebot/shared"; import { + convertToModelMessages, createUIMessageStream, JSONValue, LanguageModel, ModelMessage, StopCondition, streamText, StreamTextResult, UIMessageStreamOnFinishCallback, UIMessageStreamOptions, - UIMessageStreamWriter + UIMessageStreamWriter, + tool, + Tool, + NoSuchToolError, } from "ai"; +import { z } from "zod"; import { randomUUID } from "crypto"; import _dedent from "dedent"; import { ANSWER_TAG, FILE_REFERENCE_PREFIX } from "./constants"; import { Source } from "./types"; import { addLineNumbers, fileReferenceToString } from "./utils"; import { createTools } from "./tools"; +import { getConnectedMcpClients } from "@/ee/features/mcp/mcpClientFactory"; +import { getMcpTools, McpToolsResult } from "@/ee/features/mcp/mcpToolSets"; +import { buildMcpToolRegistry, McpToolRegistryEntry, searchMcpTools } from "@/ee/features/mcp/mcpToolRegistry"; +import { hasEntitlement } from '@/lib/entitlements'; const dedent = _dedent.withOptions({ alignValues: true }); @@ -36,6 +46,10 @@ interface CreateMessageStreamResponseProps { chatId: string; messages: SBChatMessage[]; selectedRepos: string[]; + prisma: PrismaClient; + // When undefined, MCP tools are disabled entirely (e.g. programmatic callers like askCodebase). + // When an array, MCP tools are enabled for all servers not in the list. + disabledMcpServerIds?: string[]; model: AISDKLanguageModelV3; modelName: string; onFinish: UIMessageStreamOnFinishCallback; @@ -43,6 +57,8 @@ interface CreateMessageStreamResponseProps { modelProviderOptions?: Record>; modelTemperature?: number; metadata?: Partial; + userId?: string; + orgId?: number; } export const createMessageStream = async ({ @@ -50,12 +66,16 @@ export const createMessageStream = async ({ messages, metadata, selectedRepos, + prisma, + disabledMcpServerIds, model, modelName, modelProviderOptions, modelTemperature, onFinish, onError, + userId, + orgId, }: CreateMessageStreamResponseProps) => { const latestMessage = messages[messages.length - 1]; const sources = latestMessage.parts @@ -66,7 +86,7 @@ export const createMessageStream = async ({ // Extract user messages and assistant answers. // We will use this as the context we carry between messages. - const messageHistory = + let messageHistory: ModelMessage[] = messages.map((message): ModelMessage | undefined => { if (message.role === 'user') { return { @@ -86,6 +106,28 @@ export const createMessageStream = async ({ } }).filter(message => message !== undefined); + // When the last assistant turn has approval responses (from the tool approval flow), + // the turn is incomplete — it has no answer text, only a pending tool call that was + // approved. We need to preserve the full tool call + approval so streamText can + // execute the approved tool and continue. + const lastMsg = messages[messages.length - 1]; + const hasApprovalResponses = lastMsg?.role === 'assistant' && + lastMsg.parts.some(p => p.type === 'dynamic-tool' && p.state === 'approval-responded'); + + // When continuing after tool approval, capture the prior turn's metadata + // so we can aggregate token counts and response times across phases. + const priorMetadata = hasApprovalResponses + ? (lastMsg.metadata as SBChatMessageMetadata | undefined) + : undefined; + + if (hasApprovalResponses) { + const fullLastTurn = await convertToModelMessages( + [lastMsg], + { ignoreIncompleteToolCalls: true } + ); + messageHistory = [...messageHistory, ...fullLastTurn]; + } + const stream = createUIMessageStream({ execute: async ({ writer }) => { writer.write({ @@ -101,17 +143,34 @@ export const createMessageStream = async ({ inputMessages: messageHistory, inputSources: sources, selectedRepos, + disabledMcpServerIds, onWriteSource: (source) => { writer.write({ type: 'data-source', data: source, }); }, + onMcpServerDiscovered: (sanitizedName, faviconUrl) => { + writer.write({ + type: 'data-mcp-server', + data: { sanitizedName, faviconUrl }, + }); + }, + onMcpServerFailed: (serverName) => { + writer.write({ + type: 'data-mcp-failed-server', + data: { serverName }, + }); + }, traceId, chatId, + prisma, + userId, + orgId, }); await mergeStreamAsync(researchStream, writer, { + originalMessages: messages, sendReasoning: true, sendStart: false, sendFinish: false, @@ -122,10 +181,10 @@ export const createMessageStream = async ({ writer.write({ type: 'message-metadata', messageMetadata: { - totalTokens: totalUsage.totalTokens, - totalInputTokens: totalUsage.inputTokens, - totalOutputTokens: totalUsage.outputTokens, - totalResponseTimeMs: new Date().getTime() - startTime.getTime(), + totalTokens: (priorMetadata?.totalTokens ?? 0) + (totalUsage.totalTokens ?? 0), + totalInputTokens: (priorMetadata?.totalInputTokens ?? 0) + (totalUsage.inputTokens ?? 0), + totalOutputTokens: (priorMetadata?.totalOutputTokens ?? 0) + (totalUsage.outputTokens ?? 0), + totalResponseTimeMs: (priorMetadata?.totalResponseTimeMs ?? 0) + (new Date().getTime() - startTime.getTime()), modelName, traceId, ...metadata, @@ -149,11 +208,17 @@ interface AgentOptions { providerOptions?: ProviderOptions; temperature?: number; selectedRepos: string[]; + disabledMcpServerIds?: string[]; inputMessages: ModelMessage[]; inputSources: Source[]; onWriteSource: (source: Source) => void; + onMcpServerDiscovered: (sanitizedName: string, faviconUrl: string) => void; + onMcpServerFailed: (serverName: string) => void; traceId: string; chatId: string; + prisma: PrismaClient; + userId?: string; + orgId?: number; } const createAgentStream = async ({ @@ -163,9 +228,15 @@ const createAgentStream = async ({ inputMessages, inputSources, selectedRepos, + disabledMcpServerIds, onWriteSource, + onMcpServerDiscovered, + onMcpServerFailed, traceId, - chatId, + chatId: _chatId, + prisma, + userId, + orgId, }: AgentOptions) => { // For every file source, resolve the source code so that we can include it in the system prompt. const fileSources = inputSources.filter((source) => source.type === 'file'); @@ -192,48 +263,162 @@ const createAgentStream = async ({ })) ).filter((source) => source !== undefined); + let mcpToolSetsObj: McpToolsResult = { tools: {}, failedServers: [], serverFaviconUrls: {}, cleanup: async () => {} }; + if (userId && orgId && await hasEntitlement('oauth') && disabledMcpServerIds !== undefined) { + try { + const allMcpClients = await getConnectedMcpClients(prisma, userId, orgId); + const mcpClients = allMcpClients.filter((c) => !disabledMcpServerIds.includes(c.serverId)); + mcpToolSetsObj = await getMcpTools(mcpClients); + + for (const [sanitizedName, faviconUrl] of Object.entries(mcpToolSetsObj.serverFaviconUrls)) { + onMcpServerDiscovered(sanitizedName, faviconUrl); + } + + if (mcpClients.length > 0) { + logger.info(`Connected to ${mcpClients.length} external MCP server(s): ${mcpClients.map(c => c.serverName).join(', ')}`); + } + } catch (error) { + logger.error('Failed to connect external MCP servers:', error); + } + } + + for (const serverName of mcpToolSetsObj.failedServers) { + onMcpServerFailed(serverName); + } + + const mcpRegistry = buildMcpToolRegistry(mcpToolSetsObj.tools); + const hasMcpTools = mcpRegistry.length > 0; + + const toolRequestActivation = tool({ + description: dedent` + Activate an MCP tool by name so it becomes callable on your next step. + You MUST pass an exact tool name from the tool registry in the system prompt. + Do NOT pass natural language descriptions or sentences. + If you need multiple tools, call this once per tool. + + Examples: + CORRECT: tool_to_activate_name="mcp_linear__save_comment" + CORRECT: tool_to_activate_name="mcp_linear__create_attachment" + INCORRECT: tool_to_activate_name="create a linear issue and update status" + INCORRECT: tool_to_activate_name="find tools for commenting on issues" + `, + inputSchema: z.object({ + tool_to_activate_name: z.string().describe('Exact tool name from the registry, e.g. "mcp_linear__save_comment"'), + }), + execute: async ({ tool_to_activate_name }) => { + const results = searchMcpTools(tool_to_activate_name, mcpRegistry); + return { + results: results.map(e => ({ name: e.name, description: e.description })), + }; + }, + }); + const systemPrompt = createPrompt({ repos: selectedRepos, files: resolvedFileSources, + mcpToolRegistry: mcpRegistry, }); - const stream = streamText({ - model, - providerOptions, - messages: inputMessages, - system: systemPrompt, - tools: createTools({ source: 'sourcebot-ask-agent', selectedRepos }), - temperature: temperature ?? env.SOURCEBOT_CHAT_MODEL_TEMPERATURE, - stopWhen: [ - stepCountIsGTE(env.SOURCEBOT_CHAT_MAX_STEP_COUNT), - ], - toolChoice: "auto", - onStepFinish: ({ toolResults }) => { - toolResults.forEach(({ output, dynamic }) => { - if (dynamic || isServiceError(output)) { - return; + const builtinTools = createTools({ source: 'sourcebot-ask-agent', selectedRepos }); + const builtinToolNames = Object.keys(builtinTools); + const allTools: Record = { + ...builtinTools, + ...(hasMcpTools ? { tool_request_activation: toolRequestActivation, ...mcpToolSetsObj.tools } : {}), + }; + + try { + const stream = streamText({ + model, + providerOptions, + messages: inputMessages, + system: systemPrompt, + tools: allTools, + activeTools: [ + ...builtinToolNames, + ...(hasMcpTools ? ['tool_request_activation'] : []), + ], + prepareStep: hasMcpTools ? ({ steps }) => { + const activated = new Set(); + for (const step of steps) { + for (const result of step.toolResults) { + if (!result || result.toolName !== 'tool_request_activation') { + continue; + } + const output = result.output as { results?: Array<{ name: string }> }; + for (const { name } of output?.results ?? []) { + if (name in mcpToolSetsObj.tools) { + activated.add(name); + } + } + } + } + return { + activeTools: [ + ...builtinToolNames, + 'tool_request_activation', + ...Array.from(activated), + ], + }; + } : undefined, + temperature: temperature ?? env.SOURCEBOT_CHAT_MODEL_TEMPERATURE, + stopWhen: [ + stepCountIsGTE(env.SOURCEBOT_CHAT_MAX_STEP_COUNT), + ], + toolChoice: "auto", + experimental_repairToolCall: async ({ toolCall, tools, error }) => { + // Fix case mismatches (e.g. model outputs "Mcp_Linear__Save_Comment" instead of "mcp_linear__save_comment") + if (NoSuchToolError.isInstance(error)) { + const lower = toolCall.toolName.toLowerCase(); + if (lower !== toolCall.toolName && lower in tools) { + return { ...toolCall, toolName: lower }; + } } - output.sources?.forEach(onWriteSource); - }); - }, - experimental_telemetry: { - isEnabled: env.SOURCEBOT_TELEMETRY_PII_COLLECTION_ENABLED === 'true', - metadata: { - langfuseTraceId: traceId, + // For anything we can't fix, return null. + // The AI SDK will mark the call as invalid and pass the error + // back to the model so it can retry with correct parameters. + logger.warn(`Tool call repair failed for "${toolCall.toolName}": ${error.message}`); + return null; }, - }, - onError: (error) => { - logger.error(error); - }, - }); + onStepFinish: ({ toolResults }) => { + toolResults.forEach(({ output, dynamic }) => { + if (dynamic || isServiceError(output)) { + return; + } - return stream; + output.sources?.forEach(onWriteSource); + }); + }, + experimental_telemetry: { + isEnabled: env.SOURCEBOT_TELEMETRY_PII_COLLECTION_ENABLED === 'true', + metadata: { + langfuseTraceId: traceId, + }, + }, + onError: (error) => { + logger.error(error); + }, + }); + + // Clean up MCP transport connections once the stream completes (success or failure). + stream.response.then( + () => mcpToolSetsObj.cleanup(), + () => mcpToolSetsObj.cleanup() + ); + return stream; + } catch (error) { + // If anything between MCP setup and stream return throws, ensure we + // still close the MCP transport connections to avoid leaking them. + await mcpToolSetsObj.cleanup(); + throw error; + } } + const createPrompt = ({ files, repos, + mcpToolRegistry, }: { files?: { path: string; @@ -243,6 +428,7 @@ const createPrompt = ({ revision: string; }[], repos: string[], + mcpToolRegistry: McpToolRegistryEntry[], }) => { return dedent` You are a powerful agentic AI code assistant built into Sourcebot, the world's best code-intelligence platform. Your job is to help developers understand and navigate their large codebases. @@ -287,6 +473,18 @@ const createPrompt = ({ `: ''} + ${(mcpToolRegistry.length > 0) ? dedent` + + External MCP tools are available but must first be activated via \`tool_request_activation\`. + + **CRITICAL**: The list below is the complete and authoritative inventory of all tools available to you: + ${mcpToolRegistry.map(e => `- ${e.name}: ${e.description}`).join('\n')} + + **How to use tool_request_activation**: Pass the exact tool name from the list above as the \`tool_to_activate_name\` parameter. Do NOT pass natural language descriptions or sentences. If you need multiple tools, call \`tool_request_activation\` once per tool. + Example: to activate the comment tool, call \`tool_request_activation\` with tool_to_activate_name="mcp_linear__save_comment", NOT tool_to_activate_name="save a comment on an issue". + + ` : ''} + When you have sufficient context, output your answer as a structured markdown response. diff --git a/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.test.ts b/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.test.ts new file mode 100644 index 000000000..5170d3c60 --- /dev/null +++ b/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.test.ts @@ -0,0 +1,17 @@ +import { describe, expect, test } from 'vitest'; +import { splitMcpServersForChatMenu } from './chatBoxPlusButton'; + +describe('splitMcpServersForChatMenu', () => { + test('keeps connected and expired servers separate from connectable approved servers', () => { + const servers = [ + { id: 'connected', isConnected: true, isAuthExpired: false }, + { id: 'expired', isConnected: false, isAuthExpired: true }, + { id: 'approved', isConnected: false, isAuthExpired: false }, + ]; + + const { connectedServers, connectableServers } = splitMcpServersForChatMenu(servers); + + expect(connectedServers.map((server) => server.id)).toEqual(['connected', 'expired']); + expect(connectableServers.map((server) => server.id)).toEqual(['approved']); + }); +}); diff --git a/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx b/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx new file mode 100644 index 000000000..6a484a083 --- /dev/null +++ b/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx @@ -0,0 +1,305 @@ +'use client'; + +import { Button } from "@/components/ui/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Switch } from "@/components/ui/switch"; +import { connectMcpToAsk, getMcpServersWithStatus } from "@/app/api/(client)/client"; +import { useToast } from "@/components/hooks/use-toast"; +import { McpFavicon } from "@/ee/features/mcp/components/mcpFavicon"; +import { mcpQueryKeys } from "@/ee/features/mcp/queryKeys"; +import { isServiceError } from "@/lib/utils"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { AlertTriangleIcon, Loader2Icon, PlusCircleIcon, PlusIcon, RefreshCwIcon, ServerIcon, SettingsIcon } from "lucide-react"; +import { PlusButtonInfoCard } from "./plusButtonInfoCard"; +import { useRouter } from "next/navigation"; +import { useEffect, useRef, useState } from "react"; +import { useSlate } from "slate-react"; +import { Editor } from "slate"; +import type { CustomEditor, SearchScope } from "@/features/chat/types"; +import { + clearMcpOAuthDraft, + consumeMcpOAuthDraftForPath, + createMcpOAuthDraftPath, + saveMcpOAuthDraft, +} from "@/features/chat/mcpOAuthDraft"; +import { clearEditorHistory, resetEditor } from "@/features/chat/utils"; + +interface ChatBoxPlusButtonProps { + selectedSearchScopes: SearchScope[]; + onSelectedSearchScopesChange: (items: SearchScope[]) => void; + disabledMcpServerIds: string[]; + onDisabledMcpServerIdsChange: (ids: string[]) => void; +} + +interface ChatMenuMcpServer { + isConnected: boolean; + isAuthExpired: boolean; +} + +export function splitMcpServersForChatMenu(servers: T[]) { + return { + connectedServers: servers.filter((server) => server.isConnected || server.isAuthExpired), + connectableServers: servers.filter((server) => !server.isConnected && !server.isAuthExpired), + }; +} + +function restoreEditorChildren(editor: CustomEditor, children: CustomEditor['children']) { + editor.children = children; + editor.selection = { + anchor: Editor.end(editor, []), + focus: Editor.end(editor, []), + }; + clearEditorHistory(editor); + editor.onChange(); +} + +export const ChatBoxPlusButton = ({ + selectedSearchScopes, + onSelectedSearchScopesChange, + disabledMcpServerIds, + onDisabledMcpServerIdsChange, +}: ChatBoxPlusButtonProps) => { + const [connectingServerId, setConnectingServerId] = useState(null); + const editor = useSlate(); + const hasRestoredMcpOAuthDraft = useRef(false); + const isMountedRef = useRef(false); + const queryClient = useQueryClient(); + const router = useRouter(); + const { toast } = useToast(); + + const { data: servers = [], isError, isLoading, refetch } = useQuery({ + queryKey: mcpQueryKeys.serversWithStatus, + queryFn: async () => { + const result = await getMcpServersWithStatus(); + if (isServiceError(result)) { + throw new Error("Failed to load MCP servers"); + } + return result; + }, + }); + + useEffect(() => { + isMountedRef.current = true; + + return () => { + isMountedRef.current = false; + }; + }, []); + + useEffect(() => { + if (hasRestoredMcpOAuthDraft.current) { + return; + } + + const currentPath = createMcpOAuthDraftPath(window.location.pathname, window.location.search); + if (!currentPath) { + return; + } + + const draft = consumeMcpOAuthDraftForPath(currentPath); + if (!draft) { + return; + } + + hasRestoredMcpOAuthDraft.current = true; + + try { + restoreEditorChildren(editor, draft.children); + onSelectedSearchScopesChange(draft.selectedSearchScopes); + onDisabledMcpServerIdsChange(draft.disabledMcpServerIds); + } catch (error) { + resetEditor(editor); + editor.onChange(); + console.error('Failed to restore MCP OAuth draft:', error); + } + }, [editor, onDisabledMcpServerIdsChange, onSelectedSearchScopesChange]); + + const onToggle = (serverId: string, checked: boolean) => { + if (checked) { + onDisabledMcpServerIdsChange(disabledMcpServerIds.filter((id) => id !== serverId)); + } else { + onDisabledMcpServerIdsChange([...disabledMcpServerIds, serverId]); + } + }; + + const handleConnect = async (serverId: string) => { + setConnectingServerId(serverId); + const returnTo = createMcpOAuthDraftPath(window.location.pathname, window.location.search) ?? '/chat'; + + saveMcpOAuthDraft({ + returnTo, + children: editor.children, + selectedSearchScopes, + disabledMcpServerIds, + }); + + try { + const result = await connectMcpToAsk({ + serverId, + returnTo, + }); + + if (!isMountedRef.current) { + return; + } + + if (isServiceError(result)) { + clearMcpOAuthDraft(); + toast({ + description: `Failed to connect MCP server. ${result.message}`, + variant: "destructive", + }); + setConnectingServerId(null); + return; + } + + if (result.authorizationUrl) { + window.location.href = result.authorizationUrl; + return; + } + + clearMcpOAuthDraft(); + toast({ description: 'MCP server is already connected.' }); + await queryClient.invalidateQueries({ queryKey: mcpQueryKeys.serversWithStatus }); + if (!isMountedRef.current) { + return; + } + setConnectingServerId(null); + } catch { + if (!isMountedRef.current) { + return; + } + + clearMcpOAuthDraft(); + toast({ + description: "Failed to connect MCP server.", + variant: "destructive", + }); + setConnectingServerId(null); + return; + } + }; + + const { connectedServers, connectableServers } = splitMcpServersForChatMenu(servers); + const hasServers = connectedServers.length > 0 || connectableServers.length > 0; + + return ( + + + + + + + + + + + + e.preventDefault()}> + + + + MCP Servers + + + {isError && !hasServers ? ( + { + e.preventDefault(); + refetch(); + }} + className="gap-2 text-destructive" + > + + Failed to load. Retry? + + ) : isLoading ? ( + + Loading MCP servers... + + ) : !hasServers ? ( + + No MCP servers available + + ) : ( + <> + {connectedServers.map((server) => { + const isEnabled = !server.isAuthExpired && !disabledMcpServerIds.includes(server.id); + return ( + e.preventDefault()} + disabled={server.isAuthExpired} + className="flex items-center justify-between gap-2" + > +
+ {server.isAuthExpired ? ( + + ) : ( + + )} + {server.name} +
+ onToggle(server.id, checked)} + disabled={server.isAuthExpired} + className="scale-75" + /> +
+ ); + })} + {connectedServers.length > 0 && connectableServers.length > 0 && } + {connectableServers.map((server) => ( + { + e.preventDefault(); + void handleConnect(server.id); + }} + disabled={connectingServerId !== null} + className="group flex cursor-pointer items-center justify-between gap-2" + > +
+ + {server.name} +
+ {connectingServerId === server.id ? ( + + ) : ( + + )} +
+ ))} + + )} + + router.push(`/settings/mcpServers`)} + > + + Manage MCP servers + +
+
+
+
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx b/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx index a0aae38cf..9edf84cb1 100644 --- a/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx +++ b/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx @@ -5,6 +5,7 @@ import { LanguageModelInfo, SearchScope } from "@/features/chat/types"; import { RepositoryQuery, SearchContextQuery } from "@/lib/types"; import { useSelectedLanguageModel } from "../../useSelectedLanguageModel"; import { AtMentionButton } from "./atMentionButton"; +import { ChatBoxPlusButton } from "./chatBoxPlusButton"; import { LanguageModelSelector } from "./languageModelSelector"; import { SearchScopeSelector } from "./searchScopeSelector"; @@ -16,6 +17,10 @@ export interface ChatBoxToolbarProps { onSelectedSearchScopesChange: (items: SearchScope[]) => void; isContextSelectorOpen: boolean; onContextSelectorOpenChanged: (isOpen: boolean) => void; + // TODO_Jack_MakeLinearTask: Make the plus button available on simplified toolbar usages (e.g. askgh) + // once additional features (beyond MCP server toggling) are added to it. + disabledMcpServerIds?: string[]; + onDisabledMcpServerIdsChange?: (ids: string[]) => void; } export const ChatBoxToolbar = ({ @@ -26,6 +31,8 @@ export const ChatBoxToolbar = ({ onSelectedSearchScopesChange, isContextSelectorOpen, onContextSelectorOpenChanged, + disabledMcpServerIds, + onDisabledMcpServerIdsChange, }: ChatBoxToolbarProps) => { const { selectedLanguageModel, setSelectedLanguageModel } = useSelectedLanguageModel({ languageModels, @@ -33,6 +40,17 @@ export const ChatBoxToolbar = ({ return ( <> + {disabledMcpServerIds !== undefined && onDisabledMcpServerIdsChange !== undefined && ( + <> + + + + )} { + return ( +
+
+ +

Extra Features

+
+
+ Add MCP servers, include files and more. +
+
+ ); +}; \ No newline at end of file diff --git a/packages/web/src/features/chat/components/chatThread/chatThread.tsx b/packages/web/src/features/chat/components/chatThread/chatThread.tsx index f60d281b7..af0aee3cc 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThread.tsx +++ b/packages/web/src/features/chat/components/chatThread/chatThread.tsx @@ -7,10 +7,10 @@ import { CustomSlateEditor } from '@/features/chat/customSlateEditor'; import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types'; import { createUIMessage, getAllMentionElements, resetEditor, slateContentToString } from '@/features/chat/utils'; import { useChat } from '@ai-sdk/react'; -import { CreateUIMessage, DefaultChatTransport } from 'ai'; +import { CreateUIMessage, DefaultChatTransport, lastAssistantMessageIsCompleteWithApprovalResponses } from 'ai'; import { ArrowDownIcon, CopyIcon } from 'lucide-react'; import { useNavigationGuard } from 'next-navigation-guard'; -import { Fragment, useCallback, useEffect, useRef, useState } from 'react'; +import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useStickToBottom } from 'use-stick-to-bottom'; import { Descendant } from 'slate'; import { useMessagePairs } from '../../useMessagePairs'; @@ -19,12 +19,15 @@ import { ChatBox } from '../chatBox'; import { ChatBoxToolbar } from '../chatBox/chatBoxToolbar'; import { ChatThreadListItem } from './chatThreadListItem'; import { ErrorBanner } from './errorBanner'; +import { McpFailedServersBanner } from './mcpFailedServersBanner'; import { useRouter } from 'next/navigation'; import { usePrevious } from '@uidotdev/usehooks'; import { RepositoryQuery, SearchContextQuery } from '@/lib/types'; import { duplicateChat, generateAndUpdateChatNameFromMessage } from '../../actions'; import { isServiceError } from '@/lib/utils'; import { NotConfiguredErrorBanner } from '../notConfiguredErrorBanner'; +import { McpServerIconContext, McpServerIconMap } from '../../mcpServerIconContext'; +import { ToolApprovalProvider } from '../../toolApprovalContext'; import useCaptureEvent from '@/hooks/useCaptureEvent'; import { SignInPromptBanner } from './signInPromptBanner'; import { DuplicateChatDialog } from '@/app/(app)/chat/components/duplicateChatDialog'; @@ -47,6 +50,8 @@ interface ChatThreadProps { searchContexts: SearchContextQuery[]; selectedSearchScopes: SearchScope[]; onSelectedSearchScopesChange: (items: SearchScope[]) => void; + disabledMcpServerIds: string[]; + onDisabledMcpServerIdsChange: (ids: string[]) => void; isOwner?: boolean; isAuthenticated?: boolean; chatName?: string; @@ -61,6 +66,8 @@ export const ChatThread = ({ searchContexts, selectedSearchScopes, onSelectedSearchScopesChange, + disabledMcpServerIds, + onDisabledMcpServerIdsChange, isOwner = true, isAuthenticated = false, chatName, @@ -86,13 +93,66 @@ export const ChatThread = ({ ) ?? [] ); + const [mcpServerIconMap, setMcpServerIconMap] = useState(() => { + const map: McpServerIconMap = {}; + initialMessages?.forEach((message) => { + message.parts + .filter((part) => part.type === 'data-mcp-server') + .forEach((part) => { + map[part.data.sanitizedName] = part.data.faviconUrl; + }); + }); + return map; + }); + + const [failedMcpServers, setFailedMcpServers] = useState(() => { + const names: string[] = []; + initialMessages?.forEach((message) => { + message.parts + .filter((part) => part.type === 'data-mcp-failed-server') + .forEach((part) => { + if (!names.includes(part.data.serverName)) { + names.push(part.data.serverName); + } + }); + }); + return names; + }); + const [isFailedMcpBannerVisible, setIsFailedMcpBannerVisible] = useState(false); + const { selectedLanguageModel } = useSelectedLanguageModel({ languageModels, }); + // Refs to capture the latest request params for the transport body. + // The transport is created once (useMemo) but params change over time, + // so refs ensure the dynamic body function always reads current values. + const searchScopesRef = useRef(selectedSearchScopes); + const modelRef = useRef(selectedLanguageModel); + const disabledMcpRef = useRef(disabledMcpServerIds); + + useEffect(() => { searchScopesRef.current = selectedSearchScopes; }, [selectedSearchScopes]); + useEffect(() => { modelRef.current = selectedLanguageModel; }, [selectedLanguageModel]); + useEffect(() => { disabledMcpRef.current = disabledMcpServerIds; }, [disabledMcpServerIds]); + + // Transport with dynamic body — resolved on every request (including auto-resends + // triggered by sendAutomaticallyWhen after tool approval). + const transport = useMemo(() => new DefaultChatTransport({ + api: '/api/chat', + headers: { + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + body: () => ({ + selectedSearchScopes: searchScopesRef.current, + languageModel: modelRef.current, + disabledMcpServerIds: disabledMcpRef.current, + }), + }), []); + const { messages, sendMessage: _sendMessage, + addToolApprovalResponse, error, status, stop, @@ -100,17 +160,28 @@ export const ChatThread = ({ } = useChat({ id: defaultChatId, messages: initialMessages, - transport: new DefaultChatTransport({ - api: '/api/chat', - headers: { - 'X-Sourcebot-Client-Source': 'sourcebot-web-client', - }, - }), + transport, + sendAutomaticallyWhen: lastAssistantMessageIsCompleteWithApprovalResponses, onData: (dataPart) => { // Keeps sources added by the assistant in sync. if (dataPart.type === 'data-source') { setSources((prev) => [...prev, dataPart.data]); } + if (dataPart.type === 'data-mcp-server') { + setMcpServerIconMap((prev) => ({ + ...prev, + [dataPart.data.sanitizedName]: dataPart.data.faviconUrl, + })); + } + if (dataPart.type === 'data-mcp-failed-server') { + setFailedMcpServers((prev) => { + if (prev.includes(dataPart.data.serverName)) { + return prev; + } + return [...prev, dataPart.data.serverName]; + }); + setIsFailedMcpBannerVisible(true); + } } }); @@ -133,6 +204,7 @@ export const ChatThread = ({ body: { selectedSearchScopes, languageModel: selectedLanguageModel, + disabledMcpServerIds, } satisfies AdditionalChatRequestParams, }); @@ -162,6 +234,7 @@ export const ChatThread = ({ selectedLanguageModel, _sendMessage, selectedSearchScopes, + disabledMcpServerIds, messages.length, toast, chatId, @@ -231,13 +304,13 @@ export const ChatThread = ({ const text = slateContentToString(children); const mentions = getAllMentionElements(children); - const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes); + const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes, disabledMcpServerIds); sendMessage(message); scrollToBottom(); } catch (error) { console.error('Failed to restore pending message:', error); } - }, [isAuthenticated, isOwner, chatId, sendMessage, selectedSearchScopes, scrollToBottom]); + }, [isAuthenticated, isOwner, chatId, sendMessage, selectedSearchScopes, disabledMcpServerIds, scrollToBottom]); // Track scroll position for history state restoration. useEffect(() => { @@ -319,13 +392,13 @@ export const ChatThread = ({ const text = slateContentToString(children); const mentions = getAllMentionElements(children); - const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes); + const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes, disabledMcpServerIds); sendMessage(message); scrollToBottom(); resetEditor(editor); - }, [sendMessage, selectedSearchScopes, isAuthenticated, captureEvent, chatId, scrollToBottom]); + }, [sendMessage, selectedSearchScopes, disabledMcpServerIds, isAuthenticated, captureEvent, chatId, scrollToBottom]); const onDuplicate = useCallback(async (newName: string): Promise => { if (!defaultChatId) { @@ -347,7 +420,8 @@ export const ChatThread = ({ }, [defaultChatId, toast, router, captureEvent]); return ( - <> + + {error && ( setIsErrorBannerVisible(false)} /> )} + setIsFailedMcpBannerVisible(false)} + />
@@ -480,6 +561,7 @@ export const ChatThread = ({ providers={loginWallProviders} callbackUrl={typeof window !== 'undefined' ? window.location.href : ''} /> - + + ); } diff --git a/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx b/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx index 0cbd4b264..f56bd8f8b 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx +++ b/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx @@ -6,11 +6,13 @@ import { Skeleton } from '@/components/ui/skeleton'; import { CheckCircle, Loader2 } from 'lucide-react'; import { CSSProperties, forwardRef, memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import scrollIntoView from 'scroll-into-view-if-needed'; +import { DynamicToolUIPart } from "ai"; import { Reference, referenceSchema, SBChatMessage, Source } from "../../types"; import { useExtractReferences } from '../../useExtractReferences'; import { getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences, tryResolveFileReference } from '../../utils'; import { AnswerCard } from './answerCard'; import { DetailsCard } from './detailsCard'; +import { ToolApprovalBanner } from './toolApprovalBanner'; import { MarkdownRenderer, REFERENCE_PAYLOAD_ATTRIBUTE } from './markdownRenderer'; import { ReferencedSourcesListView } from './referencedSourcesListView'; import isEqual from "fast-deep-equal/react"; @@ -106,7 +108,8 @@ const ChatThreadListItemComponent = forwardRef { + if (!assistantMessage) { + return []; + } + return assistantMessage.parts.filter( + (part): part is DynamicToolUIPart => part.type === 'dynamic-tool' && part.state === 'approval-requested' + ); + }, [assistantMessage]); + // Auto-collapse when answer first appears, but only once and respect user preference useEffect(() => { @@ -364,6 +377,10 @@ const ChatThreadListItemComponent = forwardRef + {approvalRequestedParts.length > 0 && ( + + )} + {(answerPart && assistantMessage) ? ( - ) : !isStreaming && ( + ) : !isStreaming && approvalRequestedParts.length === 0 && (

Error: No answer response was provided

)}
diff --git a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx b/packages/web/src/features/chat/components/chatThread/detailsCard.tsx index 0e2365ea6..5997df6e7 100644 --- a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx +++ b/packages/web/src/features/chat/components/chatThread/detailsCard.tsx @@ -25,6 +25,8 @@ import { ListReposToolComponent } from './tools/listReposToolComponent'; import { ListTreeToolComponent } from './tools/listTreeToolComponent'; import { ReadFileToolComponent } from './tools/readFileToolComponent'; import { ToolOutputGuard } from './tools/toolOutputGuard'; +import { McpToolComponent } from './tools/mcpToolComponent'; +import { ToolSearchToolComponent } from './tools/toolSearchToolComponent'; interface DetailsCardProps { @@ -48,7 +50,10 @@ const DetailsCardComponent = ({ }: DetailsCardProps) => { const captureEvent = useCaptureEvent(); - const toolCallCount = useMemo(() => thinkingSteps.flat().filter(part => part.type.startsWith('tool-')).length, [thinkingSteps]); + const toolCallCount = useMemo(() => thinkingSteps.flat().filter(part => + part.type.startsWith('tool-') || + (part.type === 'dynamic-tool' && part.toolName.startsWith('mcp_')) + ).length, [thinkingSteps]); const handleExpandedChanged = useCallback((next: boolean) => { captureEvent('wa_chat_details_card_toggled', { chatId, isExpanded: next }); @@ -308,8 +313,19 @@ export const StepPartRenderer = ({ part }: { part: SBChatMessagePart }) => { {(output) => } ) - case 'data-source': + case 'tool-tool_request_activation': + if (part.state !== 'output-available') { + return Activating tool...; + } + return ; case 'dynamic-tool': + if (part.toolName.startsWith('mcp_')) { + return ; + } + return null; + case 'data-source': + case 'data-mcp-server': + case 'data-mcp-failed-server': case 'file': case 'source-document': case 'source-url': diff --git a/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx b/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx new file mode 100644 index 000000000..0c74fe72f --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx @@ -0,0 +1,43 @@ +'use client'; + +import { Button } from '@/components/ui/button'; +import { AlertTriangle, X } from 'lucide-react'; + +interface McpFailedServersBannerProps { + serverNames: string[]; + isVisible: boolean; + onClose: () => void; +} + +export const McpFailedServersBanner = ({ serverNames, isVisible, onClose }: McpFailedServersBannerProps) => { + if (!isVisible || serverNames.length === 0) { + return null; + } + + const message = serverNames.length === 1 + ? `MCP server "${serverNames[0]}" failed to load tools` + : `${serverNames.length} MCP servers failed to load tools`; + + return ( +
+
+
+
+ + + {message} + +
+ +
+
+
+ ); +}; \ No newline at end of file diff --git a/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx b/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx new file mode 100644 index 000000000..0724c93b7 --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx @@ -0,0 +1,101 @@ +'use client'; + +import { Button } from "@/components/ui/button"; +import { McpFavicon } from "@/ee/features/mcp/components/mcpFavicon"; +import { useMcpServerIconMap } from "@/features/chat/mcpServerIconContext"; +import { useToolApproval } from "@/features/chat/toolApprovalContext"; +import { cn } from "@/lib/utils"; +import { DynamicToolUIPart } from "ai"; +import { ChevronRight } from "lucide-react"; +import { useCallback, useState } from "react"; +import { parseMcpToolName } from "./tools/mcpToolComponent"; +import { JsonHighlighter } from "./tools/jsonHighlighter"; + +interface ToolApprovalBannerProps { + parts: DynamicToolUIPart[]; +} + +export const ToolApprovalBanner = ({ parts }: ToolApprovalBannerProps) => { + const addToolApprovalResponse = useToolApproval(); + const iconMap = useMcpServerIconMap(); + + if (parts.length === 0) { + return null; + } + + return ( +
+ {parts.map((part) => ( + + ))} +
+ ); +}; + +const ToolApprovalItem = ({ + part, + addToolApprovalResponse, + iconMap, +}: { + part: DynamicToolUIPart; + addToolApprovalResponse: ReturnType; + iconMap: Record; +}) => { + const [isExpanded, setIsExpanded] = useState(false); + const parsed = parseMcpToolName(part.toolName); + const serverName = parsed?.serverName ?? part.toolName; + const toolName = parsed?.toolName ?? part.toolName; + const faviconUrl = parsed ? iconMap[parsed.serverName] : undefined; + + const hasInput = part.state !== 'input-streaming'; + const requestText = hasInput ? JSON.stringify(part.input, null, 2) : ''; + + const onToggle = useCallback(() => setIsExpanded(v => !v), []); + + const onApprove = useCallback(() => { + if (part.state === 'approval-requested' && addToolApprovalResponse) { + addToolApprovalResponse({ id: part.approval.id, approved: true }); + } + }, [part, addToolApprovalResponse]); + + const onDeny = useCallback(() => { + if (part.state === 'approval-requested' && addToolApprovalResponse) { + addToolApprovalResponse({ id: part.approval.id, approved: false, reason: 'User denied' }); + } + }, [part, addToolApprovalResponse]); + + return ( +
+
+ +
+ + +
+
+ {hasInput && isExpanded && ( +
+ +
+ )} +
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx b/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx new file mode 100644 index 000000000..18203a9de --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx @@ -0,0 +1,151 @@ +'use client'; + +export function unescapeJsonStrings(value: unknown): unknown { + if (typeof value === 'string') { + try { + const parsed: unknown = JSON.parse(value); + if (typeof parsed === 'object' && parsed !== null) { + return unescapeJsonStrings(parsed); + } + } catch { + // not JSON — leave as-is + } + return value; + } + if (Array.isArray(value)) { + return value.map(unescapeJsonStrings); + } + if (typeof value === 'object' && value !== null) { + return Object.fromEntries( + Object.entries(value).map(([k, v]) => [k, unescapeJsonStrings(v)]) + ); + } + return value; +} + +type TokenType = 'key' | 'string' | 'number' | 'boolean' | 'null' | 'structural' | 'whitespace' | 'other'; + +interface Token { + type: TokenType; + value: string; +} + +function tokenizeJson(text: string): Token[] { + const tokens: Token[] = []; + let i = 0; + + while (i < text.length) { + const ch = text[i]; + + // Whitespace + if (/\s/.test(ch)) { + let j = i + 1; + while (j < text.length && /\s/.test(text[j])) { + j++; + } + tokens.push({ type: 'whitespace', value: text.slice(i, j) }); + i = j; + continue; + } + + // String + if (ch === '"') { + let j = i + 1; + while (j < text.length) { + if (text[j] === '\\') { + j += 2; + } else if (text[j] === '"') { + j++; + break; + } else { + j++; + } + } + const str = text.slice(i, j); + + // Lookahead past whitespace for a colon → this is a key + let k = j; + while (k < text.length && /\s/.test(text[k])) { + k++; + } + const isKey = text[k] === ':'; + + tokens.push({ type: isKey ? 'key' : 'string', value: str }); + i = j; + continue; + } + + // Number + if (ch === '-' || /\d/.test(ch)) { + const match = text.slice(i).match(/^-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?/); + if (match) { + tokens.push({ type: 'number', value: match[0] }); + i += match[0].length; + continue; + } + } + + // Boolean / null keywords + if (text.slice(i, i + 4) === 'true') { + tokens.push({ type: 'boolean', value: 'true' }); + i += 4; + continue; + } + if (text.slice(i, i + 5) === 'false') { + tokens.push({ type: 'boolean', value: 'false' }); + i += 5; + continue; + } + if (text.slice(i, i + 4) === 'null') { + tokens.push({ type: 'null', value: 'null' }); + i += 4; + continue; + } + + // Structural characters + if ('{}[]:,'.includes(ch)) { + tokens.push({ type: 'structural', value: ch }); + i++; + continue; + } + + // Fallback + tokens.push({ type: 'other', value: ch }); + i++; + } + + return tokens; +} + +const TOKEN_CLASSES: Record = { + key: 'text-editor-tag-name', + string: 'text-editor-tag-string', + number: 'text-editor-tag-number', + boolean: 'text-editor-tag-atom', + null: 'text-editor-tag-atom', + structural: 'text-muted-foreground', + whitespace: '', + other: '', +}; + +import { useMemo } from "react"; + +export const JsonHighlighter = ({ text }: { text: string }) => { + const tokens = useMemo(() => tokenizeJson(text), [text]); + + return ( +
+            {tokens.map((token, i) => {
+                const cls = TOKEN_CLASSES[token.type];
+                if (!cls) {
+                    return token.value;
+                }
+                return (
+                    
+                        {token.value}
+                    
+                );
+            })}
+        
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx b/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx new file mode 100644 index 000000000..3e679a21b --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx @@ -0,0 +1,173 @@ +'use client'; + +import { CopyIconButton } from "@/app/(app)/components/copyIconButton"; +import { McpFavicon } from "@/ee/features/mcp/components/mcpFavicon"; +import { useMcpServerIconMap } from "@/features/chat/mcpServerIconContext"; +import { cn } from "@/lib/utils"; +import { DynamicToolUIPart } from "ai"; +import { CheckCircle, ChevronDown, XCircle } from "lucide-react"; +import { useCallback, useMemo, useState } from "react"; +import { JsonHighlighter, unescapeJsonStrings } from "./jsonHighlighter"; + +export function parseMcpToolName(toolName: string): { serverName: string; toolName: string } | null { + if (!toolName.startsWith('mcp_')) { + return null; + } + const withoutPrefix = toolName.slice(4); + const doubleUnderscoreIdx = withoutPrefix.indexOf('__'); + if (doubleUnderscoreIdx === -1) { + return null; + } + return { + serverName: withoutPrefix.slice(0, doubleUnderscoreIdx), + toolName: withoutPrefix.slice(doubleUnderscoreIdx + 2), + }; +} + +export const McpToolComponent = ({ part }: { part: DynamicToolUIPart }) => { + const needsApproval = part.state === 'approval-requested'; + const [isExpanded, setIsExpanded] = useState(needsApproval); + const onToggle = useCallback(() => setIsExpanded(v => !v), []); + + const iconMap = useMcpServerIconMap(); + const parsed = parseMcpToolName(part.toolName); + const displayName = parsed + ? `${parsed.serverName}: ${parsed.toolName}` + : part.toolName; + const faviconUrl = parsed ? iconMap[parsed.serverName] : undefined; + + const hasInput = part.state !== 'input-streaming'; + + const requestText = useMemo( + () => hasInput ? JSON.stringify(part.input, null, 2) : '', + [hasInput, part.input] + ); + const responseText = useMemo(() => { + if (part.state === 'output-available') { + try { + return JSON.stringify(unescapeJsonStrings(part.output), null, 2); + } catch { + return String(part.output); + } + } + if (part.state === 'output-error') { + return part.errorText ?? ''; + } + return undefined; + }, [part.state, part.output, part.errorText]); + + const onCopyRequest = useCallback(() => { + navigator.clipboard.writeText(requestText); + return true; + }, [requestText]); + + const onCopyResponse = useCallback(() => { + if (!responseText) { + return false; + } + navigator.clipboard.writeText(responseText); + return true; + }, [responseText]); + + const renderStatus = () => { + if (part.state === 'output-error') { + return ( + + + {displayName} failed: {part.errorText} + + ); + } + if (part.state === 'output-denied') { + return ( + + + + {displayName} — denied + + ); + } + if (part.state === 'approval-requested') { + return ( + + + {displayName} + + ); + } + if (part.state === 'approval-responded') { + const approved = part.approval.approved; + return ( + + + {approved ? : } + {displayName}{approved ? '...' : ' — denied'} + + ); + } + if (part.state === 'output-available') { + return ( + + + {displayName} + + ); + } + // input-streaming, input-available, or other in-progress states + return ( + + + {displayName}... + + ); + }; + + return ( +
+
+
+ {renderStatus()} +
+ {hasInput && ( + + )} +
+ {hasInput && isExpanded && ( +
+ + + + {responseText !== undefined && ( + <> +
+ +
+ +
+
+ + )} +
+ )} +
+ ); +}; + + +const ResultSection = ({ label, onCopy, children }: { label: string; onCopy: () => boolean; children: React.ReactNode }) => ( +
+
+ {label} + +
+
+ {children} +
+
+); diff --git a/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx b/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx index aac756f4a..43ce2021d 100644 --- a/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx +++ b/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx @@ -6,6 +6,7 @@ import { ToolUIPart } from "ai"; import { ChevronDown } from "lucide-react"; import { cn } from "@/lib/utils"; import { useCallback, useState } from "react"; +import { JsonHighlighter, unescapeJsonStrings } from "./jsonHighlighter"; export const ToolOutputGuard = >({ part, @@ -27,7 +28,7 @@ export const ToolOutputGuard = { const raw = (part.output as { output: string }).output; try { - return JSON.stringify(JSON.parse(raw), null, 2); + return JSON.stringify(unescapeJsonStrings(JSON.parse(raw)), null, 2); } catch { return raw; } @@ -70,17 +71,15 @@ export const ToolOutputGuard = -
-                            {requestText}
-                        
+
{responseText !== undefined && ( <>
-
-                                    {responseText}
-                                
+
+ +
)} diff --git a/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx b/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx new file mode 100644 index 000000000..3711e22bd --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx @@ -0,0 +1,53 @@ +'use client'; + +import { Separator } from "@/components/ui/separator"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; +import { ChevronRight } from "lucide-react"; +import { useState } from "react"; +import { cn } from "@/lib/utils"; + +interface ToolSearchResult { + name: string; + description: string; +} + +interface ToolSearchToolComponentProps { + query: string; + results: ToolSearchResult[]; +} + +export const ToolSearchToolComponent = ({ query, results }: ToolSearchToolComponentProps) => { + const [isOpen, setIsOpen] = useState(false); + + return ( + + +
+ + Searched MCP tools: {query} + + {results.length} result{results.length === 1 ? '' : 's'} + +
+
+ +
+ {results.map((result) => ( +
+ {result.name} + {result.description && ( + <> + - + {result.description} + + )} +
+ ))} + {results.length === 0 && ( + No tools found + )} +
+
+
+ ); +}; diff --git a/packages/web/src/features/chat/constants.ts b/packages/web/src/features/chat/constants.ts index b84e9d922..c258e9951 100644 --- a/packages/web/src/features/chat/constants.ts +++ b/packages/web/src/features/chat/constants.ts @@ -9,3 +9,5 @@ export const ANSWER_TAG = ''; export const SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY = 'selectedSearchScopes'; export const SET_CHAT_STATE_SESSION_STORAGE_KEY = 'setChatState'; +export const DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY = 'disabledMcpServerIds'; +export const MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY = 'mcpOAuthDraft'; diff --git a/packages/web/src/features/chat/mcpOAuthDraft.test.ts b/packages/web/src/features/chat/mcpOAuthDraft.test.ts new file mode 100644 index 000000000..93c9281c3 --- /dev/null +++ b/packages/web/src/features/chat/mcpOAuthDraft.test.ts @@ -0,0 +1,84 @@ +import { beforeEach, describe, expect, test } from 'vitest'; +import { MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY } from './constants'; +import { + consumeMcpOAuthDraftForPath, + normalizeMcpOAuthDraftPath, + resolveMcpOAuthDraftForPath, + saveMcpOAuthDraft, +} from './mcpOAuthDraft'; +import type { Descendant } from 'slate'; +import type { SearchScope } from './types'; + +const children = [{ + type: 'paragraph', + children: [{ text: 'check the Linear ticket' }], +}] satisfies Descendant[]; + +const selectedSearchScopes = [{ + type: 'repo', + value: 'sourcebot/sourcebot', + name: 'sourcebot/sourcebot', + codeHostType: 'github', +}] satisfies SearchScope[]; + +const draft = { + returnTo: '/chat/thread-1?scope=sourcebot', + children, + selectedSearchScopes, + disabledMcpServerIds: ['server-disabled'], + createdAt: 100, +}; + +describe('MCP OAuth draft persistence', () => { + beforeEach(() => { + sessionStorage.clear(); + }); + + test('normalizes chat paths and strips OAuth status params', () => { + expect(normalizeMcpOAuthDraftPath('/chat/thread-1?scope=sourcebot&status=connected&server=Linear')).toBe('/chat/thread-1?scope=sourcebot'); + expect(normalizeMcpOAuthDraftPath('/settings/mcpServers')).toBeUndefined(); + expect(normalizeMcpOAuthDraftPath('https://evil.example.com/chat')).toBeUndefined(); + expect(normalizeMcpOAuthDraftPath('//evil.example.com/chat')).toBeUndefined(); + }); + + test('resolves a draft for the same chat path after the OAuth callback adds status params', () => { + const result = resolveMcpOAuthDraftForPath( + JSON.stringify(draft), + '/chat/thread-1?scope=sourcebot&status=connected&server=Linear', + 200, + ); + + expect(result.shouldClear).toBe(true); + expect(result.draft).toEqual(draft); + }); + + test('keeps a draft when the current chat path does not match', () => { + const result = resolveMcpOAuthDraftForPath(JSON.stringify(draft), '/chat/thread-2', 200); + + expect(result.shouldClear).toBe(false); + expect(result.draft).toBeUndefined(); + }); + + test('clears invalid and stale drafts', () => { + expect(resolveMcpOAuthDraftForPath('{', '/chat/thread-1').shouldClear).toBe(true); + expect(resolveMcpOAuthDraftForPath(JSON.stringify({ ...draft, children: [1] }), '/chat/thread-1?scope=sourcebot', 200).shouldClear).toBe(true); + expect(resolveMcpOAuthDraftForPath(JSON.stringify(draft), '/chat/thread-1?scope=sourcebot', 30 * 60 * 1000 + 101).shouldClear).toBe(true); + }); + + test('saves and consumes the composer draft from sessionStorage', () => { + saveMcpOAuthDraft({ + returnTo: '/chat/thread-1?scope=sourcebot&status=error', + children, + selectedSearchScopes, + disabledMcpServerIds: ['server-disabled'], + }); + + const restoredDraft = consumeMcpOAuthDraftForPath('/chat/thread-1?scope=sourcebot&status=connected&server=Linear'); + + expect(restoredDraft?.returnTo).toBe('/chat/thread-1?scope=sourcebot'); + expect(restoredDraft?.children).toEqual(children); + expect(restoredDraft?.selectedSearchScopes).toEqual(selectedSearchScopes); + expect(restoredDraft?.disabledMcpServerIds).toEqual(['server-disabled']); + expect(sessionStorage.getItem(MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY)).toBeNull(); + }); +}); diff --git a/packages/web/src/features/chat/mcpOAuthDraft.ts b/packages/web/src/features/chat/mcpOAuthDraft.ts new file mode 100644 index 000000000..19f00f84f --- /dev/null +++ b/packages/web/src/features/chat/mcpOAuthDraft.ts @@ -0,0 +1,217 @@ +import type { Descendant } from "slate"; +import { MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY } from "./constants"; +import type { CustomText, MentionElement, ParagraphElement, SearchScope } from "./types"; + +const MCP_OAUTH_DRAFT_BASE_URL = 'https://sourcebot.local'; +const MCP_OAUTH_DRAFT_MAX_AGE_MS = 30 * 60 * 1000; +const MCP_OAUTH_STATUS_PARAMS = ['status', 'server', 'message']; + +export interface McpOAuthDraft { + returnTo: string; + children: Descendant[]; + selectedSearchScopes: SearchScope[]; + disabledMcpServerIds: string[]; + createdAt: number; +} + +type McpOAuthDraftInput = Omit; + +interface ResolveMcpOAuthDraftResult { + draft?: McpOAuthDraft; + shouldClear: boolean; +} + +function isAllowedMcpOAuthDraftPath(pathname: string): boolean { + return pathname === '/chat' || pathname.startsWith('/chat/'); +} + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} + +function isCustomText(value: unknown): value is CustomText { + return isRecord(value) && typeof value.text === 'string'; +} + +function isMentionElement(value: unknown): value is MentionElement { + return ( + isRecord(value) && + value.type === 'mention' && + isRecord(value.data) && + value.data.type === 'file' && + typeof value.data.repo === 'string' && + typeof value.data.path === 'string' && + typeof value.data.name === 'string' && + typeof value.data.language === 'string' && + typeof value.data.revision === 'string' && + Array.isArray(value.children) && + value.children.every(isCustomText) + ); +} + +function isParagraphElement(value: unknown): value is ParagraphElement { + return ( + isRecord(value) && + value.type === 'paragraph' && + (value.align === undefined || typeof value.align === 'string') && + Array.isArray(value.children) && + value.children.length > 0 && + value.children.every((child) => isCustomText(child) || isMentionElement(child)) + ); +} + +function isMcpOAuthDraftChildren(value: unknown): value is Descendant[] { + return Array.isArray(value) && value.length > 0 && value.every(isParagraphElement); +} + +export function normalizeMcpOAuthDraftPath(path: string): string | undefined { + const trimmedPath = path.trim(); + if (!trimmedPath || !trimmedPath.startsWith('/') || trimmedPath.startsWith('//') || trimmedPath.includes('\\')) { + return undefined; + } + + try { + const url = new URL(trimmedPath, MCP_OAUTH_DRAFT_BASE_URL); + if (url.origin !== MCP_OAUTH_DRAFT_BASE_URL || !isAllowedMcpOAuthDraftPath(url.pathname)) { + return undefined; + } + + for (const param of MCP_OAUTH_STATUS_PARAMS) { + url.searchParams.delete(param); + } + + const query = url.searchParams.toString(); + return `${url.pathname}${query ? `?${query}` : ''}`; + } catch { + return undefined; + } +} + +export function createMcpOAuthDraftPath(pathname: string, search: string): string | undefined { + return normalizeMcpOAuthDraftPath(`${pathname}${search}`); +} + +function isMcpOAuthDraft(value: unknown): value is McpOAuthDraft { + return ( + isRecord(value) && + 'returnTo' in value && + typeof value.returnTo === 'string' && + 'children' in value && + isMcpOAuthDraftChildren(value.children) && + 'selectedSearchScopes' in value && + Array.isArray(value.selectedSearchScopes) && + 'disabledMcpServerIds' in value && + Array.isArray(value.disabledMcpServerIds) && + value.disabledMcpServerIds.every((id) => typeof id === 'string') && + 'createdAt' in value && + typeof value.createdAt === 'number' + ); +} + +export function resolveMcpOAuthDraftForPath( + storedDraft: string | null, + currentPath: string, + now = Date.now(), +): ResolveMcpOAuthDraftResult { + if (!storedDraft) { + return { shouldClear: false }; + } + + let parsedDraft: unknown; + try { + parsedDraft = JSON.parse(storedDraft); + } catch { + return { shouldClear: true }; + } + + if (!isMcpOAuthDraft(parsedDraft)) { + return { shouldClear: true }; + } + + if (now - parsedDraft.createdAt > MCP_OAUTH_DRAFT_MAX_AGE_MS) { + return { shouldClear: true }; + } + + const storedPath = normalizeMcpOAuthDraftPath(parsedDraft.returnTo); + if (!storedPath) { + return { shouldClear: true }; + } + + const normalizedCurrentPath = normalizeMcpOAuthDraftPath(currentPath); + if (!normalizedCurrentPath) { + return { shouldClear: false }; + } + + if (storedPath !== normalizedCurrentPath) { + return { shouldClear: false }; + } + + return { + draft: { + ...parsedDraft, + returnTo: storedPath, + }, + shouldClear: true, + }; +} + +function getSessionStorage(): Storage | undefined { + if (typeof window === 'undefined') { + return undefined; + } + + try { + return window.sessionStorage; + } catch { + return undefined; + } +} + +export function saveMcpOAuthDraft(draft: McpOAuthDraftInput): void { + const storage = getSessionStorage(); + const returnTo = normalizeMcpOAuthDraftPath(draft.returnTo); + if (!storage || !returnTo) { + return; + } + + try { + storage.setItem(MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY, JSON.stringify({ + ...draft, + returnTo, + createdAt: Date.now(), + } satisfies McpOAuthDraft)); + } catch { + // If sessionStorage is unavailable or full, OAuth should still proceed. + } +} + +export function clearMcpOAuthDraft(): void { + const storage = getSessionStorage(); + if (!storage) { + return; + } + + try { + storage.removeItem(MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY); + } catch { + // Ignore storage cleanup failures. + } +} + +export function consumeMcpOAuthDraftForPath(currentPath: string): McpOAuthDraft | undefined { + const storage = getSessionStorage(); + if (!storage) { + return undefined; + } + + const result = resolveMcpOAuthDraftForPath( + storage.getItem(MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY), + currentPath, + ); + + if (result.shouldClear) { + clearMcpOAuthDraft(); + } + + return result.draft; +} diff --git a/packages/web/src/features/chat/mcpServerIconContext.tsx b/packages/web/src/features/chat/mcpServerIconContext.tsx new file mode 100644 index 000000000..94628f4a5 --- /dev/null +++ b/packages/web/src/features/chat/mcpServerIconContext.tsx @@ -0,0 +1,10 @@ +'use client'; + +import { createContext, useContext } from 'react'; + +// Maps sanitized server name (e.g. "linear") to a favicon URL. +export type McpServerIconMap = Record; + +export const McpServerIconContext = createContext({}); + +export const useMcpServerIconMap = () => useContext(McpServerIconContext); diff --git a/packages/web/src/features/chat/toolApprovalContext.tsx b/packages/web/src/features/chat/toolApprovalContext.tsx new file mode 100644 index 000000000..d4379c394 --- /dev/null +++ b/packages/web/src/features/chat/toolApprovalContext.tsx @@ -0,0 +1,9 @@ +'use client'; + +import { createContext, useContext } from 'react'; +import type { ChatAddToolApproveResponseFunction } from 'ai'; + +const ToolApprovalContext = createContext(null); + +export const ToolApprovalProvider = ToolApprovalContext.Provider; +export const useToolApproval = () => useContext(ToolApprovalContext); \ No newline at end of file diff --git a/packages/web/src/features/chat/types.test.ts b/packages/web/src/features/chat/types.test.ts new file mode 100644 index 000000000..a9f41df7c --- /dev/null +++ b/packages/web/src/features/chat/types.test.ts @@ -0,0 +1,72 @@ +import { expect, test, describe } from 'vitest'; +import { sbChatMessageMetadataSchema, additionalChatRequestParamsSchema } from './types'; + +describe('sbChatMessageMetadataSchema', () => { + test('accepts disabledMcpServerIds as array of strings', () => { + const result = sbChatMessageMetadataSchema.safeParse({ + disabledMcpServerIds: ['id1', 'id2'], + }); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toEqual(['id1', 'id2']); + } + }); + + test('accepts missing disabledMcpServerIds (optional)', () => { + const result = sbChatMessageMetadataSchema.safeParse({}); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toBeUndefined(); + } + }); + + test('rejects non-string array values', () => { + const result = sbChatMessageMetadataSchema.safeParse({ + disabledMcpServerIds: [123, 456], + }); + + expect(result.success).toBe(false); + }); +}); + +describe('additionalChatRequestParamsSchema', () => { + const validBase = { + languageModel: { + provider: 'anthropic', + model: 'claude-sonnet-4-20250514', + }, + selectedSearchScopes: [], + }; + + test('defaults disabledMcpServerIds to empty array', () => { + const result = additionalChatRequestParamsSchema.safeParse(validBase); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toEqual([]); + } + }); + + test('accepts explicit disabledMcpServerIds array', () => { + const result = additionalChatRequestParamsSchema.safeParse({ + ...validBase, + disabledMcpServerIds: ['abc'], + }); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toEqual(['abc']); + } + }); + + test('rejects non-array value for disabledMcpServerIds', () => { + const result = additionalChatRequestParamsSchema.safeParse({ + ...validBase, + disabledMcpServerIds: 'not-an-array', + }); + + expect(result.success).toBe(false); + }); +}); diff --git a/packages/web/src/features/chat/types.ts b/packages/web/src/features/chat/types.ts index 6e990f5c2..3c2619f14 100644 --- a/packages/web/src/features/chat/types.ts +++ b/packages/web/src/features/chat/types.ts @@ -60,6 +60,7 @@ export const sbChatMessageMetadataSchema = z.object({ userId: z.string().optional(), })).optional(), selectedSearchScopes: z.array(searchScopeSchema).optional(), + disabledMcpServerIds: z.array(z.string()).optional(), traceId: z.string().optional(), }); @@ -67,12 +68,22 @@ export type SBChatMessageMetadata = z.infer; export type SBChatMessageToolTypes = { [K in keyof ReturnType]: InferUITool[K]>; +} & { + tool_request_activation: { + input: { tool_to_activate_name: string }; + output: { results: Array<{ name: string; description: string }> }; + }; }; export type SBChatMessageDataParts = { // The `source` data type allows us to know what sources the LLM saw // during retrieval. "source": Source, + // The `mcp-server` data type carries favicon metadata for connected MCP servers, + // keyed by sanitized server name (e.g. "linear"). + "mcp-server": { sanitizedName: string; faviconUrl: string }, + // The `mcp-failed-server` data type surfaces MCP servers that failed to load their tools. + "mcp-failed-server": { serverName: string }, } export type SBChatMessage = UIMessage< @@ -143,6 +154,7 @@ declare module 'slate' { export type SetChatStatePayload = { inputMessage: CreateUIMessage; selectedSearchScopes: SearchScope[]; + disabledMcpServerIds: string[]; } @@ -188,5 +200,6 @@ export type LanguageModelInfo = { export const additionalChatRequestParamsSchema = z.object({ languageModel: languageModelInfoSchema, selectedSearchScopes: z.array(searchScopeSchema), + disabledMcpServerIds: z.array(z.string()).default([]), }); -export type AdditionalChatRequestParams = z.infer; \ No newline at end of file +export type AdditionalChatRequestParams = z.infer; diff --git a/packages/web/src/features/chat/useCreateNewChatThread.ts b/packages/web/src/features/chat/useCreateNewChatThread.ts index 63ead0249..7cf72a0ce 100644 --- a/packages/web/src/features/chat/useCreateNewChatThread.ts +++ b/packages/web/src/features/chat/useCreateNewChatThread.ts @@ -30,11 +30,11 @@ export const useCreateNewChatThread = ({ isAuthenticated = false }: UseCreateNew const hasRestoredPendingMessage = useRef(false); const captureEvent = useCaptureEvent(); - const doCreateChat = useCallback(async (children: Descendant[], selectedSearchScopes: SearchScope[]) => { + const doCreateChat = useCallback(async (children: Descendant[], selectedSearchScopes: SearchScope[], disabledMcpServerIds: string[]) => { const text = slateContentToString(children); const mentions = getAllMentionElements(children); - const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes); + const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes, disabledMcpServerIds); setIsLoading(true); const response = await createChat({ source: 'sourcebot-web-client' }); @@ -49,6 +49,7 @@ export const useCreateNewChatThread = ({ isAuthenticated = false }: UseCreateNew setChatState({ inputMessage, selectedSearchScopes, + disabledMcpServerIds, }); const url = createPathWithQueryParams(`/chat/${response.id}`); @@ -56,18 +57,18 @@ export const useCreateNewChatThread = ({ isAuthenticated = false }: UseCreateNew router.push(url); }, [router, toast, setChatState]); - const createNewChatThread = useCallback(async (children: Descendant[], selectedSearchScopes: SearchScope[]) => { + const createNewChatThread = useCallback(async (children: Descendant[], selectedSearchScopes: SearchScope[], disabledMcpServerIds: string[]) => { if (!isAuthenticated) { const result = await getAskGhLoginWallData(); if (!isServiceError(result) && result.isEnabled) { captureEvent('wa_askgh_login_wall_prompted', {}); - sessionStorage.setItem(PENDING_NEW_CHAT_KEY, JSON.stringify({ children, selectedSearchScopes })); + sessionStorage.setItem(PENDING_NEW_CHAT_KEY, JSON.stringify({ children, selectedSearchScopes, disabledMcpServerIds })); setLoginWallState({ isOpen: true, providers: result.providers }); return; } } - doCreateChat(children, selectedSearchScopes); + doCreateChat(children, selectedSearchScopes, disabledMcpServerIds); }, [isAuthenticated, captureEvent, doCreateChat]); // Restore pending message after OAuth redirect @@ -85,11 +86,12 @@ export const useCreateNewChatThread = ({ isAuthenticated = false }: UseCreateNew sessionStorage.removeItem(PENDING_NEW_CHAT_KEY); try { - const { children, selectedSearchScopes } = JSON.parse(stored) as { + const { children, selectedSearchScopes, disabledMcpServerIds } = JSON.parse(stored) as { children: Descendant[]; selectedSearchScopes: SearchScope[]; + disabledMcpServerIds: string[]; }; - doCreateChat(children, selectedSearchScopes); + doCreateChat(children, selectedSearchScopes, disabledMcpServerIds ?? []); } catch (error) { console.error('Failed to restore pending message:', error); } diff --git a/packages/web/src/features/chat/utils.test.ts b/packages/web/src/features/chat/utils.test.ts index 26359d2a9..e5a89c0bb 100644 --- a/packages/web/src/features/chat/utils.test.ts +++ b/packages/web/src/features/chat/utils.test.ts @@ -1,5 +1,5 @@ -import { expect, test, vi } from 'vitest' -import { fileReferenceToString, getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences } from './utils' +import { expect, test, describe, vi } from 'vitest' +import { createUIMessage, fileReferenceToString, getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences } from './utils' import { FILE_REFERENCE_REGEX, ANSWER_TAG } from './constants'; import { SBChatMessage, SBChatMessagePart } from './types'; @@ -351,3 +351,31 @@ test('repairReferences handles malformed inline code blocks', () => { const expected = 'See @file:{github.com/sourcebot-dev/sourcebot::packages/web/src/auth.ts} for details.'; expect(repairReferences(input)).toBe(expected); }); + +describe('createUIMessage', () => { + test('includes disabledMcpServerIds in metadata when provided', () => { + const result = createUIMessage('hello', [], [], ['server1', 'server2']); + + expect(result.metadata?.disabledMcpServerIds).toEqual(['server1', 'server2']); + }); + + test('defaults disabledMcpServerIds to empty array when omitted', () => { + const result = createUIMessage('hello', [], []); + + expect(result.metadata?.disabledMcpServerIds).toEqual([]); + }); + + test('passes through empty array', () => { + const result = createUIMessage('hello', [], [], []); + + expect(result.metadata?.disabledMcpServerIds).toEqual([]); + }); + + test('includes both selectedSearchScopes and disabledMcpServerIds in metadata', () => { + const scopes = [{ type: 'repo' as const, value: 'org/repo', name: 'repo', codeHostType: 'github' }]; + const result = createUIMessage('hello', [], scopes, ['disabled1']); + + expect(result.metadata?.selectedSearchScopes).toEqual(scopes); + expect(result.metadata?.disabledMcpServerIds).toEqual(['disabled1']); + }); +}); diff --git a/packages/web/src/features/chat/utils.ts b/packages/web/src/features/chat/utils.ts index 38dd784fd..cdcd1c0e0 100644 --- a/packages/web/src/features/chat/utils.ts +++ b/packages/web/src/features/chat/utils.ts @@ -161,11 +161,16 @@ export const getAllMentionElements = (children: Descendant[]): MentionElement[] }); } +export const clearEditorHistory = (editor: CustomEditor) => { + // slate-history exposes `history` publicly, but does not provide a clear API. + editor.history = { redos: [], undos: [] }; +} + // @see: https://stackoverflow.com/a/74102147 export const resetEditor = (editor: CustomEditor) => { const point = { path: [0, 0], offset: 0 } editor.selection = { anchor: point, focus: point }; - editor.history = { redos: [], undos: [] }; + clearEditorHistory(editor); editor.children = [{ type: "paragraph", children: [{ text: "" }] @@ -176,7 +181,7 @@ export const addLineNumbers = (source: string, lineOffset = 1) => { return source.split('\n').map((line, index) => `${index + lineOffset}: ${line}`).join('\n'); } -export const createUIMessage = (text: string, mentions: MentionData[], selectedSearchScopes: SearchScope[]): CreateUIMessage => { +export const createUIMessage = (text: string, mentions: MentionData[], selectedSearchScopes: SearchScope[], disabledMcpServerIds: string[] = []): CreateUIMessage => { // Converts applicable mentions into sources. const sources: Source[] = mentions .map((mention) => { @@ -209,6 +214,7 @@ export const createUIMessage = (text: string, mentions: MentionData[], selectedS ], metadata: { selectedSearchScopes, + disabledMcpServerIds, }, } } diff --git a/packages/web/src/features/mcp/askCodebase.ts b/packages/web/src/features/mcp/askCodebase.ts index bc3a030c2..2c8186b96 100644 --- a/packages/web/src/features/mcp/askCodebase.ts +++ b/packages/web/src/features/mcp/askCodebase.ts @@ -155,6 +155,7 @@ export const askCodebase = (params: AskCodebaseParams): Promise r.value), + prisma, model, modelName, modelProviderOptions: providerOptions, diff --git a/packages/web/src/features/mcp/mcpOAuthReturnTo.test.ts b/packages/web/src/features/mcp/mcpOAuthReturnTo.test.ts new file mode 100644 index 000000000..321d9ee3d --- /dev/null +++ b/packages/web/src/features/mcp/mcpOAuthReturnTo.test.ts @@ -0,0 +1,32 @@ +import { describe, expect, test } from 'vitest'; +import { + createMcpOAuthState, + getMcpOAuthReturnToFromState, + normalizeMcpOAuthReturnTo, +} from './mcpOAuthReturnTo'; + +describe('MCP OAuth return paths', () => { + test('allows chat return paths', () => { + expect(normalizeMcpOAuthReturnTo('/chat')).toBe('/chat'); + expect(normalizeMcpOAuthReturnTo('/chat/thread-1?foo=bar')).toBe('/chat/thread-1?foo=bar'); + }); + + test('rejects external and unrelated return paths', () => { + expect(normalizeMcpOAuthReturnTo('https://evil.example.com/chat')).toBeUndefined(); + expect(normalizeMcpOAuthReturnTo('//evil.example.com/chat')).toBeUndefined(); + expect(normalizeMcpOAuthReturnTo('/settings')).toBeUndefined(); + }); + + test('encodes and decodes return paths inside OAuth state', () => { + const state = createMcpOAuthState('nonce-1', '/chat'); + + expect(state).not.toBe('nonce-1'); + expect(getMcpOAuthReturnToFromState(state)).toBe('/chat'); + }); + + test('leaves state unchanged when no valid return path exists', () => { + expect(createMcpOAuthState('nonce-1')).toBe('nonce-1'); + expect(createMcpOAuthState('nonce-1', '/settings')).toBe('nonce-1'); + expect(getMcpOAuthReturnToFromState('nonce-1')).toBeUndefined(); + }); +}); diff --git a/packages/web/src/features/mcp/mcpOAuthReturnTo.ts b/packages/web/src/features/mcp/mcpOAuthReturnTo.ts new file mode 100644 index 000000000..e46b5805e --- /dev/null +++ b/packages/web/src/features/mcp/mcpOAuthReturnTo.ts @@ -0,0 +1,63 @@ +const MCP_OAUTH_STATE_PREFIX = 'sourcebot_mcp.'; +const MCP_OAUTH_STATE_BASE_URL = 'https://sourcebot.local'; + +function isAllowedMcpOAuthReturnPath(pathname: string): boolean { + return pathname === '/chat' || pathname.startsWith('/chat/') || pathname === '/settings/mcpServers'; +} + +export function normalizeMcpOAuthReturnTo(returnTo: unknown): string | undefined { + if (typeof returnTo !== 'string') { + return undefined; + } + + const trimmedReturnTo = returnTo.trim(); + if (!trimmedReturnTo || !trimmedReturnTo.startsWith('/') || trimmedReturnTo.startsWith('//') || trimmedReturnTo.includes('\\')) { + return undefined; + } + + try { + const url = new URL(trimmedReturnTo, MCP_OAUTH_STATE_BASE_URL); + if (url.origin !== MCP_OAUTH_STATE_BASE_URL || !isAllowedMcpOAuthReturnPath(url.pathname)) { + return undefined; + } + + return `${url.pathname}${url.search}`; + } catch { + return undefined; + } +} + +export function createMcpOAuthState(nonce: string, returnTo?: string): string { + const normalizedReturnTo = normalizeMcpOAuthReturnTo(returnTo); + if (!normalizedReturnTo) { + return nonce; + } + + const encoded = Buffer.from(JSON.stringify({ + nonce, + returnTo: normalizedReturnTo, + })).toString('base64url'); + return `${MCP_OAUTH_STATE_PREFIX}${encoded}`; +} + +export function getMcpOAuthReturnToFromState(state: string | null | undefined): string | undefined { + if (!state?.startsWith(MCP_OAUTH_STATE_PREFIX)) { + return undefined; + } + + try { + const encoded = state.slice(MCP_OAUTH_STATE_PREFIX.length); + const payload = JSON.parse(Buffer.from(encoded, 'base64url').toString('utf8')) as unknown; + if ( + typeof payload === 'object' && + payload !== null && + 'returnTo' in payload + ) { + return normalizeMcpOAuthReturnTo(payload.returnTo); + } + } catch { + return undefined; + } + + return undefined; +} diff --git a/packages/web/src/features/mcp/prismaOAuthClientProvider.test.ts b/packages/web/src/features/mcp/prismaOAuthClientProvider.test.ts new file mode 100644 index 000000000..5e8d77084 --- /dev/null +++ b/packages/web/src/features/mcp/prismaOAuthClientProvider.test.ts @@ -0,0 +1,179 @@ +import { describe, expect, test, vi, beforeEach } from 'vitest'; +import { McpServerClientInfoSource } from '@sourcebot/db'; + +vi.mock('server-only', () => ({})); +vi.mock('@/prisma', () => ({ + __unsafePrisma: { + mcpServer: {}, + userMcpServer: {}, + }, +})); +vi.mock('@sourcebot/shared', () => ({ + encryptOAuthToken: vi.fn((text: string | null | undefined) => text ? `encrypted:${text}` : undefined), + decryptOAuthToken: vi.fn((text: string | null | undefined) => text?.startsWith('encrypted:') ? text.slice('encrypted:'.length) : text), +})); + +const { + PrismaOAuthClientProvider, + clearMcpServerClientCredentialsForObservedClient, +} = await import('./prismaOAuthClientProvider'); + +function createPrismaMock() { + return { + mcpServer: { + findFirst: vi.fn(), + updateMany: vi.fn(), + }, + userMcpServer: { + findUnique: vi.fn(), + update: vi.fn(), + updateMany: vi.fn(), + }, + }; +} + +function createProvider(prisma = createPrismaMock(), allowClientRegistration = false) { + return new PrismaOAuthClientProvider({ + prisma: prisma as never, + clientInvalidationPrisma: prisma as never, + serverId: 'server-1', + orgId: 1, + userId: 'user-1', + callbackUrl: 'https://sourcebot.example.com/api/ee/askmcp/callback', + allowClientRegistration, + }); +} + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe('PrismaOAuthClientProvider modes', () => { + test('connect-mode provider exposes saveClientInformation', () => { + const provider = createProvider(createPrismaMock(), true); + + expect('saveClientInformation' in provider).toBe(true); + expect(provider.saveClientInformation).toEqual(expect.any(Function)); + }); + + test('runtime and callback providers omit saveClientInformation', () => { + const provider = createProvider(); + + expect('saveClientInformation' in provider).toBe(false); + expect(provider.saveClientInformation).toBeUndefined(); + }); +}); + +describe('clearMcpServerClientCredentialsForObservedClient', () => { + test('matching observed clientInfo clears org clientInfo and all server tokens', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.updateMany.mockResolvedValue({ count: 1 }); + prisma.userMcpServer.updateMany.mockResolvedValue({ count: 2 }); + + const didClear = await clearMcpServerClientCredentialsForObservedClient({ + prisma: prisma as never, + serverId: 'server-1', + orgId: 1, + observedClientInfo: 'encrypted-client-info', + }); + + expect(didClear).toBe(true); + expect(prisma.mcpServer.updateMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + clientInfo: 'encrypted-client-info', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + data: { clientInfo: null }, + }); + expect(prisma.userMcpServer.updateMany).toHaveBeenCalledWith({ + where: { + serverId: 'server-1', + server: { orgId: 1 }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + }); + + test('stale observed clientInfo clears neither org clientInfo nor tokens', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.updateMany.mockResolvedValue({ count: 0 }); + + const didClear = await clearMcpServerClientCredentialsForObservedClient({ + prisma: prisma as never, + serverId: 'server-1', + orgId: 1, + observedClientInfo: 'stale-client-info', + }); + + expect(didClear).toBe(false); + expect(prisma.mcpServer.updateMany).toHaveBeenCalledOnce(); + expect(prisma.userMcpServer.updateMany).not.toHaveBeenCalled(); + }); +}); + +describe('PrismaOAuthClientProvider static client information', () => { + test('clientInformation returns static OAuth client credentials', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.findFirst.mockResolvedValue({ + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.STATIC, + }); + const provider = createProvider(prisma); + + await expect(provider.clientInformation()).resolves.toEqual({ + client_id: 'client-id', + client_secret: 'client-secret', + }); + }); + + test('invalidate all preserves static client information and clears only the current user tokens and verifier', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.findFirst.mockResolvedValue({ + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.STATIC, + }); + prisma.mcpServer.updateMany.mockResolvedValue({ count: 0 }); + prisma.userMcpServer.update.mockResolvedValue({ + userId: 'user-1', + serverId: 'server-1', + }); + const provider = createProvider(prisma); + + await provider.clientInformation(); + await provider.invalidateCredentials('all'); + + expect(prisma.mcpServer.updateMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + data: { clientInfo: null }, + }); + expect(prisma.userMcpServer.updateMany).not.toHaveBeenCalled(); + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + codeVerifier: null, + state: null, + }, + }); + }); +}); diff --git a/packages/web/src/features/mcp/prismaOAuthClientProvider.ts b/packages/web/src/features/mcp/prismaOAuthClientProvider.ts new file mode 100644 index 000000000..3f5446b40 --- /dev/null +++ b/packages/web/src/features/mcp/prismaOAuthClientProvider.ts @@ -0,0 +1,291 @@ +import 'server-only'; +import type { + OAuthClientProvider, + OAuthClientInformation, + OAuthClientMetadata, + OAuthTokens, +} from '@ai-sdk/mcp'; +import { McpServerClientInfoSource, type PrismaClient } from '@sourcebot/db'; +import { encryptOAuthToken, decryptOAuthToken } from '@sourcebot/shared'; +import { __unsafePrisma } from '@/prisma'; +import { createMcpOAuthState } from './mcpOAuthReturnTo'; + +type McpOAuthPrismaClient = Pick; + +interface PrismaOAuthClientProviderOptions { + prisma: McpOAuthPrismaClient; + serverId: string; + orgId: number; + userId: string; + callbackUrl: string; + callbackReturnTo?: string; + allowClientRegistration?: boolean; + clientInvalidationPrisma?: McpOAuthPrismaClient; +} + +export interface ClearMcpServerClientCredentialsOptions { + prisma?: McpOAuthPrismaClient; + serverId: string; + orgId: number; + observedClientInfo: string | undefined; +} + +export async function clearMcpServerClientCredentialsForObservedClient({ + prisma = __unsafePrisma, + serverId, + orgId, + observedClientInfo, +}: ClearMcpServerClientCredentialsOptions): Promise { + if (!observedClientInfo) { + return false; + } + + const result = await prisma.mcpServer.updateMany({ + where: { + id: serverId, + orgId, + clientInfo: observedClientInfo, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + data: { clientInfo: null }, + }); + + if (result.count === 0) { + return false; + } + + await prisma.userMcpServer.updateMany({ + where: { + serverId, + server: { orgId }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + + return true; +} + +/** + * Prisma-backed OAuthClientProvider for connecting to external MCP servers. + * + * Stores dynamic client registration on McpServer (per-org), and per-user + * tokens + ephemeral PKCE state on UserMcpServer. + */ +export class PrismaOAuthClientProvider implements OAuthClientProvider { + private readonly prisma: McpOAuthPrismaClient; + private readonly clientInvalidationPrisma: McpOAuthPrismaClient; + private readonly serverId: string; + private readonly orgId: number; + private readonly userId: string; + private readonly callbackUrl: string; + private readonly callbackReturnTo: string | undefined; + private observedClientInfo: string | undefined; + private observedClientInfoSource: McpServerClientInfoSource | undefined; + + /** Populated by redirectToAuthorization — read after auth() returns 'REDIRECT'. */ + public authorizationUrl: string | undefined; + + /** Only present in connect mode. If absent, the SDK cannot perform DCR. */ + declare saveClientInformation?: (info: OAuthClientInformation) => Promise; + + constructor({ + prisma, + serverId, + orgId, + userId, + callbackUrl, + callbackReturnTo, + allowClientRegistration = false, + clientInvalidationPrisma = __unsafePrisma, + }: PrismaOAuthClientProviderOptions) { + this.prisma = prisma; + this.clientInvalidationPrisma = clientInvalidationPrisma; + this.serverId = serverId; + this.orgId = orgId; + this.userId = userId; + this.callbackUrl = callbackUrl; + this.callbackReturnTo = callbackReturnTo; + + if (allowClientRegistration) { + this.saveClientInformation = async (info: OAuthClientInformation) => { + const encrypted = encryptOAuthToken(JSON.stringify(info)); + if (!encrypted) { + throw new Error('Failed to encrypt OAuth client information'); + } + + const result = await this.prisma.mcpServer.updateMany({ + where: { id: this.serverId, orgId: this.orgId }, + data: { + clientInfo: encrypted, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + }); + if (result.count === 0) { + throw new Error('MCP server not found'); + } + + this.observedClientInfo = encrypted; + this.observedClientInfoSource = McpServerClientInfoSource.DYNAMIC; + }; + } + } + + get redirectUrl(): string | URL { + return this.callbackUrl; + } + + get clientMetadata(): OAuthClientMetadata { + return { + redirect_uris: [this.callbackUrl], + client_name: 'Sourcebot', + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: 'none', + }; + } + + async clientInformation(): Promise { + const server = await this.prisma.mcpServer.findFirst({ + where: { id: this.serverId, orgId: this.orgId }, + select: { + clientInfo: true, + clientInfoSource: true, + }, + }); + if (!server?.clientInfo) { + this.observedClientInfo = undefined; + this.observedClientInfoSource = undefined; + return undefined; + } + + this.observedClientInfo = server.clientInfo; + this.observedClientInfoSource = server.clientInfoSource; + const decrypted = decryptOAuthToken(server.clientInfo); + return decrypted ? JSON.parse(decrypted) : undefined; + } + + async tokens(): Promise { + const userServer = await this.getUserServer(); + if (!userServer?.tokens) { + return undefined; + } + + const decrypted = decryptOAuthToken(userServer.tokens); + return decrypted ? JSON.parse(decrypted) : undefined; + } + + async saveTokens(tokens: OAuthTokens): Promise { + const encrypted = encryptOAuthToken(JSON.stringify(tokens)); + if (!encrypted) { + throw new Error('Failed to encrypt OAuth tokens'); + } + + const tokensExpiresAt = tokens.expires_in + ? new Date(Date.now() + tokens.expires_in * 1000) + : null; + await this.updateUserServer({ tokens: encrypted, tokensExpiresAt }); + } + + async codeVerifier(): Promise { + const userServer = await this.getUserServer(); + if (!userServer?.codeVerifier) { + throw new Error('No code verifier found'); + } + return userServer.codeVerifier; + } + + async saveCodeVerifier(codeVerifier: string): Promise { + await this.updateUserServer({ codeVerifier }); + } + + async state(): Promise { + return createMcpOAuthState(crypto.randomUUID(), this.callbackReturnTo); + } + + async saveState(state: string): Promise { + await this.updateUserServer({ state }); + } + + async storedState(): Promise { + const userServer = await this.getUserServer(); + return userServer?.state ?? undefined; + } + + async redirectToAuthorization(url: URL): Promise { + // Force the OAuth provider to show a consent/login screen on every authorization. + // This prevents a stolen-session attack where an attacker signs into Sourcebot on + // a victim's machine and silently obtains the victim's provider tokens via an + // existing browser session. + if (!url.searchParams.has('prompt')) { + url.searchParams.set('prompt', 'consent'); + } + + // Clear stale tokens before starting a new authorization flow so the UI reflects + // that the user needs to complete OAuth again. + await this.invalidateCredentials('tokens'); + + this.authorizationUrl = url.toString(); + } + + async invalidateCredentials( + scope: 'all' | 'client' | 'tokens' | 'verifier' | 'discovery', + ): Promise { + if (scope === 'discovery') { + return; + } + + if (scope === 'all' || scope === 'client') { + const didClearDynamicClient = await clearMcpServerClientCredentialsForObservedClient({ + prisma: this.clientInvalidationPrisma, + serverId: this.serverId, + orgId: this.orgId, + observedClientInfo: this.observedClientInfo, + }); + if ( + scope === 'all' && + !didClearDynamicClient && + this.observedClientInfoSource === McpServerClientInfoSource.STATIC + ) { + await this.updateUserServer({ tokens: null, tokensExpiresAt: null }); + } + } + + if (scope === 'tokens') { + await this.updateUserServer({ tokens: null, tokensExpiresAt: null }); + } + + if (scope === 'all' || scope === 'verifier') { + await this.updateUserServer({ codeVerifier: null, state: null }); + } + } + + private async getUserServer() { + return this.prisma.userMcpServer.findUnique({ + where: { + userId_serverId: { userId: this.userId, serverId: this.serverId }, + }, + select: { + tokens: true, + codeVerifier: true, + state: true, + }, + }); + } + + private async updateUserServer(data: { + tokens?: string | null; + tokensExpiresAt?: Date | null; + codeVerifier?: string | null; + state?: string | null; + }) { + await this.prisma.userMcpServer.update({ + where: { + userId_serverId: { userId: this.userId, serverId: this.serverId }, + }, + data, + }); + } +} diff --git a/packages/web/src/features/mcp/prismaScope.test.ts b/packages/web/src/features/mcp/prismaScope.test.ts new file mode 100644 index 000000000..4b86264db --- /dev/null +++ b/packages/web/src/features/mcp/prismaScope.test.ts @@ -0,0 +1,443 @@ +import { describe, expect, test, vi } from 'vitest'; +import type { UserWithAccounts } from '@sourcebot/db'; +import { getMcpPrismaQueryExtension, scopeUserMcpServerWhere } from './prismaScope'; + +const user = { + id: 'user-1', + name: 'Test User', + email: 'test@example.com', + hashedPassword: null, + emailVerified: null, + image: null, + sessionVersion: 0, + createdAt: new Date('2026-01-01T00:00:00Z'), + updatedAt: new Date('2026-01-01T00:00:00Z'), + accounts: [], +} satisfies UserWithAccounts; + +const callQuery = vi.fn(async (args: unknown) => args); + +const resetQuery = () => { + callQuery.mockClear(); + return callQuery; +}; + +const callAllOperations = ( + model: { + $allOperations: (params: { + operation: string; + args: unknown; + query: (args: unknown) => Promise; + }) => Promise; + }, + operation: string, + args: unknown, + query = resetQuery(), +) => model.$allOperations({ operation, args, query }); + +describe('scopeUserMcpServerWhere', () => { + test('merges existing filters with the authenticated user id', () => { + expect(scopeUserMcpServerWhere({ tokens: { not: null } }, user)).toEqual({ + AND: [ + { tokens: { not: null } }, + { userId: 'user-1' }, + ], + }); + }); + + test('fails closed for anonymous users', () => { + expect(scopeUserMcpServerWhere(undefined, undefined)).toEqual({ + AND: [ + { userId: '__sourcebot_anonymous_user__' }, + { userId: '__sourcebot_no_authenticated_user__' }, + ], + }); + }); +}); + +describe('getMcpPrismaQueryExtension', () => { + test('scopes list-style UserMcpServer reads', async () => { + const extension = getMcpPrismaQueryExtension(user); + const result = await extension.userMcpServer.findMany({ + args: { where: { tokens: { not: null } } }, + query: resetQuery(), + }); + + expect(result).toEqual({ + where: { + AND: [ + { tokens: { not: null } }, + { userId: 'user-1' }, + ], + }, + }); + }); + + test('returns null for anonymous or mismatched findUnique queries', async () => { + const anonymousExtension = getMcpPrismaQueryExtension(); + const mismatchedExtension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(anonymousExtension.userMcpServer.findUnique({ + args: { where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } } }, + query, + })).resolves.toBeNull(); + await expect(mismatchedExtension.userMcpServer.findUnique({ + args: { where: { userId_serverId: { userId: 'user-2', serverId: 'server-1' } } }, + query, + })).resolves.toBeNull(); + + expect(query).not.toHaveBeenCalled(); + }); + + test('allows matching findUnique queries through', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } } }; + + await expect(extension.userMcpServer.findUnique({ + args, + query: resetQuery(), + })).resolves.toBe(args); + }); + + test('rejects creates for anonymous or mismatched users', async () => { + const anonymousExtension = getMcpPrismaQueryExtension(); + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(anonymousExtension.userMcpServer.create({ + args: { data: { userId: 'user-1', serverId: 'server-1' } }, + query, + })).rejects.toThrow('requires an authenticated user'); + await expect(extension.userMcpServer.create({ + args: { data: { userId: 'user-2', serverId: 'server-1' } }, + query, + })).rejects.toThrow('must create UserMcpServer rows for the authenticated user'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('allows checked creates that connect the authenticated user', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { + data: { + user: { connect: { id: 'user-1' } }, + server: { connect: { id: 'server-1' } }, + }, + }; + + await expect(extension.userMcpServer.create({ + args, + query: resetQuery(), + })).resolves.toBe(args); + }); + + test('rejects checked creates that do not connect the authenticated user', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(extension.userMcpServer.create({ + args: { + data: { + user: { connect: { id: 'user-2' } }, + server: { connect: { id: 'server-1' } }, + }, + }, + query, + })).rejects.toThrow('must create UserMcpServer rows for the authenticated user'); + await expect(extension.userMcpServer.create({ + args: { + data: { + user: { create: { id: 'user-1', email: 'test@example.com' } }, + server: { connect: { id: 'server-1' } }, + }, + }, + query, + })).rejects.toThrow('must create UserMcpServer rows for the authenticated user'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects mismatched update/delete composite keys', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(extension.userMcpServer.update({ + args: { + where: { userId_serverId: { userId: 'user-2', serverId: 'server-1' } }, + data: { state: null }, + }, + query, + })).rejects.toThrow('cannot access UserMcpServer rows for another user'); + await expect(extension.userMcpServer.delete({ + args: { where: { userId_serverId: { userId: 'user-2', serverId: 'server-1' } } }, + query, + })).rejects.toThrow('cannot access UserMcpServer rows for another user'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects attempts to mutate UserMcpServer ownership', async () => { + const extension = getMcpPrismaQueryExtension(user); + + await expect(extension.userMcpServer.update({ + args: { + where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } }, + data: { userId: 'user-2' }, + }, + query: resetQuery(), + })).rejects.toThrow('cannot change UserMcpServer identity'); + await expect(extension.userMcpServer.update({ + args: { + where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } }, + data: { server: { connect: { id: 'server-2' } } }, + }, + query: resetQuery(), + })).rejects.toThrow('cannot change UserMcpServer identity'); + await expect(extension.userMcpServer.upsert({ + args: { + where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } }, + create: { userId: 'user-1', serverId: 'server-1' }, + update: { user: { connect: { id: 'user-2' } } }, + }, + query: resetQuery(), + })).rejects.toThrow('cannot change UserMcpServer identity'); + }); + + test('scopes updateMany and deleteMany', async () => { + const extension = getMcpPrismaQueryExtension(user); + + await expect(extension.userMcpServer.updateMany({ + args: { where: { tokens: { not: null } }, data: { state: null } }, + query: resetQuery(), + })).resolves.toEqual({ + where: { + AND: [ + { tokens: { not: null } }, + { userId: 'user-1' }, + ], + }, + data: { state: null }, + }); + await expect(extension.userMcpServer.deleteMany({ + args: { where: { serverId: 'server-1' } }, + query: resetQuery(), + })).resolves.toEqual({ + where: { + AND: [ + { serverId: 'server-1' }, + { userId: 'user-1' }, + ], + }, + }); + }); + + test('scopes returning bulk UserMcpServer operations', async () => { + const extension = getMcpPrismaQueryExtension(user); + + await expect(extension.userMcpServer.createManyAndReturn({ + args: { data: { userId: 'user-2', serverId: 'server-1' } }, + query: resetQuery(), + })).rejects.toThrow('must create UserMcpServer rows for the authenticated user'); + await expect(extension.userMcpServer.updateManyAndReturn({ + args: { where: { serverId: 'server-1' }, data: { state: null } }, + query: resetQuery(), + })).resolves.toEqual({ + where: { + AND: [ + { serverId: 'server-1' }, + { userId: 'user-1' }, + ], + }, + data: { state: null }, + }); + }); + + test('rejects nested UserMcpServer relation access through direct UserMcpServer queries', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(extension.userMcpServer.findMany({ + args: { + include: { + server: { + include: { + userMcpServers: true, + }, + }, + }, + }, + query, + })).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects nested UserMcpServer writes through McpServer operations', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(callAllOperations( + extension.mcpServer, + 'update', + { + where: { id: 'server-1' }, + data: { userMcpServers: { create: { userId: 'user-1' } } }, + }, + query, + )).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects nested UserMcpServer reads and writes through parent models', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(callAllOperations( + extension.mcpServer, + 'findUnique', + { + where: { id: 'server-1' }, + include: { userMcpServers: true }, + }, + query, + )).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + await expect(callAllOperations( + extension.user, + 'findMany', + { + where: { userMcpServers: { some: { serverId: 'server-1' } } }, + }, + query, + )).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + await expect(callAllOperations( + extension.user, + 'update', + { + where: { id: 'user-1' }, + data: { userMcpServers: { create: { serverId: 'server-1' } } }, + }, + query, + )).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects transitive MCP relation access through Org and UserToOrg operations', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(callAllOperations( + extension.org, + 'findUnique', + { + where: { id: 1 }, + include: { + mcpServers: { + include: { + userMcpServers: true, + }, + }, + }, + }, + query, + )).rejects.toThrow('cannot access MCP server relations through a parent relation'); + await expect(callAllOperations( + extension.org, + 'update', + { + where: { id: 1 }, + data: { + mcpServers: { + create: { + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + userMcpServers: { + create: { userId: 'user-1' }, + }, + }, + }, + }, + }, + query, + )).rejects.toThrow('cannot access MCP server relations through a parent relation'); + await expect(callAllOperations( + extension.userToOrg, + 'findMany', + { + include: { + org: { + include: { + mcpServers: { + include: { + userMcpServers: true, + }, + }, + }, + }, + }, + }, + query, + )).rejects.toThrow('cannot access MCP server relations through a parent relation'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('allows JSON metadata payloads with relation-like keys', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { + where: { id: 1 }, + data: { + metadata: { + mcpServers: 'display-state', + userMcpServers: { collapsed: true }, + }, + }, + }; + + await expect(callAllOperations(extension.org, 'update', args)).resolves.toBe(args); + }); + + test('passes safe parent-model operations through the compact hooks', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { where: { orgId: 1 } }; + + await expect(callAllOperations(extension.userToOrg, 'findMany', args)).resolves.toBe(args); + }); + + test('allows single user deletes but blocks bulk user deletes', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { where: { id: 'user-2' } }; + const query = resetQuery(); + + await expect(callAllOperations(extension.user, 'delete', args, query)).resolves.toBe(args); + expect(query).toHaveBeenCalledTimes(1); + query.mockClear(); + + await expect(callAllOperations(extension.user, 'deleteMany', { where: {} }, query)) + .rejects.toThrow('user.deleteMany cannot delete users through a user-scoped client'); + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects shared McpServer deletes through the scoped client', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(callAllOperations( + extension.mcpServer, + 'delete', + { where: { id: 'server-1' } }, + query, + )).rejects.toThrow('cannot delete shared McpServer rows through a user-scoped client'); + await expect(callAllOperations( + extension.mcpServer, + 'deleteMany', + { where: { orgId: 1 } }, + query, + )).rejects.toThrow('cannot delete shared McpServer rows through a user-scoped client'); + + expect(query).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/web/src/features/mcp/prismaScope.ts b/packages/web/src/features/mcp/prismaScope.ts new file mode 100644 index 000000000..9e3089f24 --- /dev/null +++ b/packages/web/src/features/mcp/prismaScope.ts @@ -0,0 +1,366 @@ +import { Prisma, UserWithAccounts } from '@sourcebot/db'; + +type QueryHookParams = { + args: TArgs; + query: (args: TArgs) => Promise; +}; + +type AllOperationsHookParams = { + operation: string; + args: unknown; + query: (args: unknown) => Promise; +}; + +type UserMcpServerWhereArgs = { + where?: Prisma.UserMcpServerWhereInput; +}; + +type UserMcpServerWhereUniqueArgs = { + where: Prisma.UserMcpServerWhereUniqueInput; +}; + +type UserMcpServerCreateArgs = { + data: unknown; +}; + +type UserMcpServerUpdateArgs = UserMcpServerWhereUniqueArgs & { + data: unknown; +}; + +type UserMcpServerUpsertArgs = UserMcpServerWhereUniqueArgs & { + create: unknown; + update: unknown; +}; + +// Deliberately impossible filter — AND-ing two different userId values guarantees zero rows. +// Used as the fallback when no user is authenticated, so anonymous queries see nothing. +// Prisma doesn't expose a "match nothing" primitive, so this is the standard workaround. +const anonymousUserScope: Prisma.UserMcpServerWhereInput = { + AND: [ + { userId: '__sourcebot_anonymous_user__' }, + { userId: '__sourcebot_no_authenticated_user__' }, + ], +}; + +const isRecord = (value: unknown): value is Record => + typeof value === 'object' && value !== null && !Array.isArray(value); + +const userScopeWhere = (user?: UserWithAccounts): Prisma.UserMcpServerWhereInput => + user ? { userId: user.id } : anonymousUserScope; + +export const scopeUserMcpServerWhere = ( + where: Prisma.UserMcpServerWhereInput | undefined, + user?: UserWithAccounts, +): Prisma.UserMcpServerWhereInput => { + const scope = userScopeWhere(user); + return where ? { AND: [where, scope] } : scope; +}; + +const scopeUserMcpServerReadArgs = ( + args: TArgs, + user?: UserWithAccounts, +): TArgs => ({ + ...args, + where: scopeUserMcpServerWhere(args.where, user), +}); + +const requireAuthenticatedUser = ( + user: UserWithAccounts | undefined, + operation: string, +): UserWithAccounts => { + if (!user) { + throw new Error(`${operation} requires an authenticated user.`); + } + return user; +}; + +const uniqueWhereUserId = (where: Prisma.UserMcpServerWhereUniqueInput): string | undefined => { + const compositeKey = where.userId_serverId; + return isRecord(compositeKey) && typeof compositeKey.userId === 'string' + ? compositeKey.userId + : undefined; +}; + +export const isUserMcpServerUniqueWhereForUser = ( + where: Prisma.UserMcpServerWhereUniqueInput, + user?: UserWithAccounts, +) => !!user && uniqueWhereUserId(where) === user.id; + +const assertUserMcpServerUniqueWhereForUser = ( + where: Prisma.UserMcpServerWhereUniqueInput, + user: UserWithAccounts | undefined, + operation: string, +) => { + const authenticatedUser = requireAuthenticatedUser(user, operation); + if (!isUserMcpServerUniqueWhereForUser(where, authenticatedUser)) { + throw new Error(`${operation} cannot access UserMcpServer rows for another user.`); + } +}; + +const assertNoIdentityMutation = (data: unknown, operation: string) => { + if (!isRecord(data)) { + return; + } + + if ('userId' in data || 'user' in data || 'serverId' in data || 'server' in data) { + throw new Error(`${operation} cannot change UserMcpServer identity.`); + } +}; + +// Extracts the userId from a Prisma relation connect object. +// Prisma's connect syntax for a relation looks like: { connect: { id: "some-id" } } +const connectedUserId = (userRelation: unknown): string | undefined => { + if (!isRecord(userRelation) || !('connect' in userRelation)) { + return undefined; + } + + const connect = userRelation.connect; + if (!isRecord(connect) || !('id' in connect) || typeof connect.id !== 'string') { + return undefined; + } + + return connect.id; +}; + +const createDataUserId = (row: unknown): string | undefined => { + if (!isRecord(row)) { + return undefined; + } + const scalarUserId = typeof row.userId === 'string' ? row.userId : undefined; + const relationUserId = row.user === undefined ? undefined : connectedUserId(row.user); + + if (row.user !== undefined && relationUserId === undefined) { + return undefined; + } + if (scalarUserId !== undefined && relationUserId !== undefined && scalarUserId !== relationUserId) { + return undefined; + } + + return relationUserId ?? scalarUserId; +}; + +const assertCreateDataForUser = ( + data: unknown, + user: UserWithAccounts | undefined, + operation: string, +) => { + const authenticatedUser = requireAuthenticatedUser(user, operation); + + const rows = Array.isArray(data) ? data : [data]; + for (const row of rows) { + if (createDataUserId(row) !== authenticatedUser.id) { + throw new Error(`${operation} must create UserMcpServer rows for the authenticated user.`); + } + } +}; + +const scopeUserMcpServerWriteManyArgs = ( + args: TArgs, + user: UserWithAccounts | undefined, + operation: string, +): TArgs => { + const authenticatedUser = requireAuthenticatedUser(user, operation); + return scopeUserMcpServerReadArgs(args, authenticatedUser); +}; + +const PRISMA_SELECTION_KEYS = new Set(['include', 'select']); +const PRISMA_STRUCTURAL_KEYS = new Set([ + ...PRISMA_SELECTION_KEYS, + 'where', + 'orderBy', + 'data', + 'create', + 'connectOrCreate', + 'update', + 'updateMany', + 'upsert', + 'delete', + 'deleteMany', + 'AND', + 'OR', + 'NOT', + 'some', + 'none', + 'every', + 'is', + 'isNot', +]); +const MCP_RELATION_BRIDGE_KEYS = new Set([ + 'user', + 'server', + 'org', + 'orgs', + 'members', +]); + +const containsPrismaRelationAccess = ( + value: unknown, + relationNames: string[], + isSelectionObject = false, +): boolean => { + if (Array.isArray(value)) { + return value.some((item) => containsPrismaRelationAccess(item, relationNames, isSelectionObject)); + } + if (!isRecord(value)) { + return false; + } + if (relationNames.some((relationName) => relationName in value)) { + return true; + } + + return Object.entries(value).some(([key, nestedValue]) => { + if (PRISMA_SELECTION_KEYS.has(key)) { + return containsPrismaRelationAccess(nestedValue, relationNames, true); + } + + if (isSelectionObject || PRISMA_STRUCTURAL_KEYS.has(key) || MCP_RELATION_BRIDGE_KEYS.has(key)) { + return containsPrismaRelationAccess(nestedValue, relationNames); + } + + return false; + }); +}; + +const assertNoUserMcpServerRelationAccess = (args: unknown, operation: string) => { + if (containsPrismaRelationAccess(args, ['userMcpServers'])) { + throw new Error(`${operation} cannot access UserMcpServer rows through a parent relation.`); + } +}; + +const assertNoMcpServerRelationAccess = (args: unknown, operation: string) => { + if (containsPrismaRelationAccess(args, ['mcpServers', 'userMcpServers'])) { + throw new Error(`${operation} cannot access MCP server relations through a parent relation.`); + } +}; + +const rejectSharedMcpServerDelete = (operation: string) => { + throw new Error(`${operation} cannot delete shared McpServer rows through a user-scoped client.`); +}; + +const rejectUserDeleteMany = () => { + throw new Error('user.deleteMany cannot delete users through a user-scoped client.'); +}; + +const guardMcpParentOperation = ( + modelName: string, + guard: (args: unknown, operation: string) => void, +) => async ({ operation, args, query }: AllOperationsHookParams) => { + guard(args, `${modelName}.${operation}`); + return query(args); +}; + +export const getMcpPrismaQueryExtension = (user?: UserWithAccounts) => ({ + userMcpServer: { + async findMany({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findMany'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async findFirst({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findFirst'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async findFirstOrThrow({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findFirstOrThrow'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async findUnique({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findUnique'); + // Preserve Prisma's nullable "not found" semantics for scoped reads. Callers that + // need a hard failure should use findUniqueOrThrow; write paths throw on mismatch. + return isUserMcpServerUniqueWhereForUser(args.where, user) ? query(args) : null; + }, + async findUniqueOrThrow({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findUniqueOrThrow'); + assertUserMcpServerUniqueWhereForUser(args.where, user, 'userMcpServer.findUniqueOrThrow'); + return query(args); + }, + async count({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.count'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async aggregate({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.aggregate'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async groupBy({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.groupBy'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async create({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.create'); + assertCreateDataForUser((args as UserMcpServerCreateArgs).data, user, 'userMcpServer.create'); + return query(args); + }, + async createMany({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.createMany'); + assertCreateDataForUser((args as UserMcpServerCreateArgs).data, user, 'userMcpServer.createMany'); + return query(args); + }, + async createManyAndReturn({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.createManyAndReturn'); + assertCreateDataForUser((args as UserMcpServerCreateArgs).data, user, 'userMcpServer.createManyAndReturn'); + return query(args); + }, + async update({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.update'); + assertUserMcpServerUniqueWhereForUser(args.where, user, 'userMcpServer.update'); + assertNoIdentityMutation((args as UserMcpServerUpdateArgs).data, 'userMcpServer.update'); + return query(args); + }, + async updateMany({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.updateMany'); + requireAuthenticatedUser(user, 'userMcpServer.updateMany'); + assertNoIdentityMutation((args as UserMcpServerUpdateArgs).data, 'userMcpServer.updateMany'); + return query(scopeUserMcpServerWriteManyArgs(args, user, 'userMcpServer.updateMany')); + }, + async updateManyAndReturn({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.updateManyAndReturn'); + requireAuthenticatedUser(user, 'userMcpServer.updateManyAndReturn'); + assertNoIdentityMutation((args as UserMcpServerUpdateArgs).data, 'userMcpServer.updateManyAndReturn'); + return query(scopeUserMcpServerWriteManyArgs(args, user, 'userMcpServer.updateManyAndReturn')); + }, + async delete({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.delete'); + assertUserMcpServerUniqueWhereForUser(args.where, user, 'userMcpServer.delete'); + return query(args); + }, + async deleteMany({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.deleteMany'); + return query(scopeUserMcpServerWriteManyArgs(args, user, 'userMcpServer.deleteMany')); + }, + async upsert({ args, query }: QueryHookParams) { + const upsertArgs = args as UserMcpServerUpsertArgs; + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.upsert'); + assertUserMcpServerUniqueWhereForUser(args.where, user, 'userMcpServer.upsert'); + assertCreateDataForUser(upsertArgs.create, user, 'userMcpServer.upsert'); + assertNoIdentityMutation(upsertArgs.update, 'userMcpServer.upsert'); + return query(args); + }, + }, + user: { + async $allOperations({ operation, args, query }: AllOperationsHookParams) { + if (operation === 'deleteMany') { + rejectUserDeleteMany(); + } + // The owner-only user deletion API intentionally deletes one user and relies on + // cascade to remove that user's rows. Bulk deletes stay blocked above. + assertNoUserMcpServerRelationAccess(args, `user.${operation}`); + return query(args); + }, + }, + mcpServer: { + async $allOperations({ operation, args, query }: AllOperationsHookParams) { + if (operation === 'delete' || operation === 'deleteMany') { + rejectSharedMcpServerDelete(`mcpServer.${operation}`); + } + assertNoUserMcpServerRelationAccess(args, `mcpServer.${operation}`); + return query(args); + }, + }, + org: { + $allOperations: guardMcpParentOperation('org', assertNoMcpServerRelationAccess), + }, + userToOrg: { + $allOperations: guardMcpParentOperation('userToOrg', assertNoMcpServerRelationAccess), + }, +}); diff --git a/packages/web/src/lib/errorCodes.ts b/packages/web/src/lib/errorCodes.ts index 714932c30..fdb09d67d 100644 --- a/packages/web/src/lib/errorCodes.ts +++ b/packages/web/src/lib/errorCodes.ts @@ -35,4 +35,6 @@ export enum ErrorCode { LAST_OWNER_CANNOT_BE_DEMOTED = 'LAST_OWNER_CANNOT_BE_DEMOTED', LAST_OWNER_CANNOT_BE_REMOVED = 'LAST_OWNER_CANNOT_BE_REMOVED', API_KEY_USAGE_DISABLED = 'API_KEY_USAGE_DISABLED', + MCP_SERVER_ALREADY_EXISTS = 'MCP_SERVER_ALREADY_EXISTS', + MCP_SERVER_NOT_FOUND = 'MCP_SERVER_NOT_FOUND', } diff --git a/packages/web/src/middleware/withAuth.test.ts b/packages/web/src/middleware/withAuth.test.ts index 862677df9..6da2a9afe 100644 --- a/packages/web/src/middleware/withAuth.test.ts +++ b/packages/web/src/middleware/withAuth.test.ts @@ -6,6 +6,7 @@ import { MOCK_API_KEY, MOCK_OAUTH_TOKEN, MOCK_ORG, MOCK_USER_WITH_ACCOUNTS, pris import { OrgRole } from '@sourcebot/db'; import { ErrorCode } from '../lib/errorCodes'; import { StatusCodes } from 'http-status-codes'; +import { userScopedPrismaClientExtension } from '@/prisma'; const mocks = vi.hoisted(() => { return { @@ -80,6 +81,7 @@ const createMockSession = (overrides: Partial = {}): Session => ({ beforeEach(() => { vi.clearAllMocks(); + vi.mocked(userScopedPrismaClientExtension).mockReset(); mocks.auth.mockResolvedValue(null); mocks.headers.mockResolvedValue(new Headers()); mocks.hasEntitlement.mockReturnValue(false); @@ -471,6 +473,39 @@ describe('getAuthContext', () => { }); describe('withAuth', () => { + test('should pass the scoped prisma client from $extends to the callback', async () => { + const userId = 'test-user-id'; + const user = { + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }; + const extension = { query: { userMcpServer: {} } }; + const scopedPrisma = { scoped: true }; + + prisma.user.findUnique.mockResolvedValue(user); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + vi.mocked(userScopedPrismaClientExtension).mockResolvedValue(extension as never); + prisma.$extends.mockReturnValue(scopedPrisma as never); + setMockSession(createMockSession({ user: { id: userId } })); + + const cb = vi.fn(); + await withAuth(cb); + + expect(userScopedPrismaClientExtension).toHaveBeenCalledWith(user); + expect(prisma.$extends).toHaveBeenCalledWith(extension); + expect(cb).toHaveBeenCalledWith(expect.objectContaining({ + prisma: scopedPrisma, + })); + }); + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization', async () => { const userId = 'test-user-id'; prisma.user.findUnique.mockResolvedValue({ diff --git a/packages/web/src/prisma.ts b/packages/web/src/prisma.ts index f863f5ef7..0496d8c9b 100644 --- a/packages/web/src/prisma.ts +++ b/packages/web/src/prisma.ts @@ -2,6 +2,7 @@ import 'server-only'; import { env, getDBConnectionString } from "@sourcebot/shared"; import { Prisma, PrismaClient, UserWithAccounts } from "@sourcebot/db"; import { hasEntitlement } from "@/lib/entitlements"; +import { getMcpPrismaQueryExtension } from "@/features/mcp/prismaScope"; // @see: https://authjs.dev/getting-started/adapters/prisma const globalForPrisma = globalThis as unknown as { prisma: PrismaClient } @@ -35,6 +36,7 @@ export const userScopedPrismaClientExtension = async (user?: UserWithAccounts) = (prisma) => { return prisma.$extends({ query: { + ...getMcpPrismaQueryExtension(user), ...(hasPermissionSyncing ? { repo: { async $allOperations({ args, query }) { diff --git a/yarn.lock b/yarn.lock index 7be7eb0ae..4aaba200f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -99,6 +99,19 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/mcp@npm:^2.0.0-beta.11": + version: 2.0.0-beta.11 + resolution: "@ai-sdk/mcp@npm:2.0.0-beta.11" + dependencies: + "@ai-sdk/provider": "npm:4.0.0-beta.5" + "@ai-sdk/provider-utils": "npm:5.0.0-beta.7" + pkce-challenge: "npm:^5.0.0" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/efcc9b9f5f8b20b78b2d0ee6d83b34466b2ec456c3b40b5b8b10af226e7d3f6144f964d87a20c5fc54c24b21f3610cb75cc246c30833b99fb501438a206c9933 + languageName: node + linkType: hard + "@ai-sdk/mistral@npm:^3.0.30": version: 3.0.30 resolution: "@ai-sdk/mistral@npm:3.0.30" @@ -148,6 +161,19 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/provider-utils@npm:5.0.0-beta.7": + version: 5.0.0-beta.7 + resolution: "@ai-sdk/provider-utils@npm:5.0.0-beta.7" + dependencies: + "@ai-sdk/provider": "npm:4.0.0-beta.5" + "@standard-schema/spec": "npm:^1.1.0" + eventsource-parser: "npm:^3.0.6" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/440825f7b599da6a0bd830c905f9ba4f21defcf7068bc98154ea38158c1134b049cb2815047013668f48b679a23de1d3c19eb072a65115dc860070168104c99e + languageName: node + linkType: hard + "@ai-sdk/provider@npm:3.0.8": version: 3.0.8 resolution: "@ai-sdk/provider@npm:3.0.8" @@ -157,6 +183,15 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/provider@npm:4.0.0-beta.5": + version: 4.0.0-beta.5 + resolution: "@ai-sdk/provider@npm:4.0.0-beta.5" + dependencies: + json-schema: "npm:^0.4.0" + checksum: 10c0/886f5892268cc3425130c9b019a9eb1e2acdb5efd05d920b05d1ac1ab49603393d8e509e6e0a3c46dee533a411a51a2af2c6fa0a173b41130f5175a615add7fb + languageName: node + linkType: hard + "@ai-sdk/react@npm:^3.0.169": version: 3.0.169 resolution: "@ai-sdk/react@npm:3.0.169" @@ -9062,6 +9097,7 @@ __metadata: "@ai-sdk/deepseek": "npm:^2.0.29" "@ai-sdk/google": "npm:^3.0.64" "@ai-sdk/google-vertex": "npm:^4.0.111" + "@ai-sdk/mcp": "npm:^2.0.0-beta.11" "@ai-sdk/mistral": "npm:^3.0.30" "@ai-sdk/openai": "npm:^3.0.53" "@ai-sdk/openai-compatible": "npm:^2.0.41" @@ -9274,7 +9310,7 @@ __metadata: vitest: "npm:^4.1.4" vitest-mock-extended: "npm:^4.0.0" vscode-icons-js: "npm:^11.6.1" - zod: "npm:^3.25.74" + zod: "npm:^3.25.76" zod-to-json-schema: "npm:^3.24.5" languageName: unknown linkType: soft @@ -18526,13 +18562,20 @@ __metadata: languageName: node linkType: hard -"picomatch@npm:^4.0.2, picomatch@npm:^4.0.3, picomatch@npm:^4.0.4": +"picomatch@npm:^4.0.2, picomatch@npm:^4.0.4": version: 4.0.4 resolution: "picomatch@npm:4.0.4" checksum: 10c0/e2c6023372cc7b5764719a5ffb9da0f8e781212fa7ca4bd0562db929df8e117460f00dff3cb7509dacfc06b86de924b247f504d0ce1806a37fac4633081466b0 languageName: node linkType: hard +"picomatch@npm:^4.0.3": + version: 4.0.3 + resolution: "picomatch@npm:4.0.3" + checksum: 10c0/9582c951e95eebee5434f59e426cddd228a7b97a0161a375aed4be244bd3fe8e3a31b846808ea14ef2c8a2527a6eeab7b3946a67d5979e81694654f939473ae2 + languageName: node + linkType: hard + "picospinner@npm:^3.0.0": version: 3.0.0 resolution: "picospinner@npm:3.0.0" @@ -23045,7 +23088,7 @@ __metadata: languageName: node linkType: hard -"zod@npm:^3.25.0": +"zod@npm:^3.25.0, zod@npm:^3.25.76": version: 3.25.76 resolution: "zod@npm:3.25.76" checksum: 10c0/5718ec35e3c40b600316c5b4c5e4976f7fee68151bc8f8d90ec18a469be9571f072e1bbaace10f1e85cf8892ea12d90821b200e980ab46916a6166a4260a983c