diff --git a/packages/api-client/src/fetcher.test.ts b/packages/api-client/src/fetcher.test.ts index b75e4bdff2..6c0352e30f 100644 --- a/packages/api-client/src/fetcher.test.ts +++ b/packages/api-client/src/fetcher.test.ts @@ -26,7 +26,6 @@ describe("buildApiFetcher", () => { }; return response; }; - beforeEach(() => { vi.resetAllMocks(); vi.stubGlobal("fetch", mockFetch); diff --git a/packages/api-client/src/fetcher.ts b/packages/api-client/src/fetcher.ts index b9a7a4ed50..802260973f 100644 --- a/packages/api-client/src/fetcher.ts +++ b/packages/api-client/src/fetcher.ts @@ -10,7 +10,6 @@ export const buildApiFetcher: ( config: ApiFetcherConfig, ) => Parameters[0] = (config) => { const userAgent = `posthog/desktop.hog.dev; version: ${config.appVersion}`; - const makeRequest = async ( input: Parameters[0]["fetch"]>[0], token: string, diff --git a/packages/api-client/src/posthog-client.test.ts b/packages/api-client/src/posthog-client.test.ts index cd660438df..22041f0d4d 100644 --- a/packages/api-client/src/posthog-client.test.ts +++ b/packages/api-client/src/posthog-client.test.ts @@ -219,6 +219,49 @@ describe("PostHogAPIClient", () => { ); }); + it("returns the redirect URL when authorizing an MCP installation", async () => { + const fetch = vi.fn().mockResolvedValue({ + ok: true, + status: 200, + json: async () => ({ + redirect_url: "https://auth.example.com/authorize?state=abc", + }), + }); + const client = new PostHogAPIClient( + "http://localhost:8000", + async () => "token", + async () => "token", + 123, + ); + + ( + client as unknown as { + api: { baseUrl: string; fetcher: { fetch: typeof fetch } }; + } + ).api = { + baseUrl: "http://localhost:8000", + fetcher: { fetch }, + }; + + await expect( + client.authorizeMcpInstallation({ + installation_id: "inst-123", + install_source: "posthog-code", + posthog_code_callback_url: "posthog-code://mcp-oauth-complete", + }), + ).resolves.toEqual({ + redirect_url: "https://auth.example.com/authorize?state=abc", + }); + + expect(fetch).toHaveBeenCalledWith( + expect.objectContaining({ + method: "get", + path: "/api/environments/123/mcp_server_installations/authorize/", + }), + ); + expect(fetch.mock.calls[0][0]).not.toHaveProperty("overrides"); + }); + describe("warmTask", () => { function makeClient(fetch: ReturnType) { const client = new PostHogAPIClient( diff --git a/packages/core/src/auth/auth.test.ts b/packages/core/src/auth/auth.test.ts index 1b9448e04a..80384696ee 100644 --- a/packages/core/src/auth/auth.test.ts +++ b/packages/core/src/auth/auth.test.ts @@ -708,6 +708,72 @@ describe("AuthService", () => { expect(service.getState().status).toBe("restoring"); expect(oauthFlow.refreshToken).toHaveBeenCalledTimes(3); }); + + it("uses the current access token when a preemptive refresh fails before expiry", async () => { + vi.useFakeTimers(); + try { + oauthFlow.startFlow.mockResolvedValue( + mockTokenResponse({ + accessToken: "current-access-token", + refreshToken: "current-refresh-token", + }), + ); + stubAuthFetch(); + + await service.initialize(); + await service.login("us"); + + oauthFlow.refreshToken.mockReset(); + oauthFlow.refreshToken.mockResolvedValue({ + success: false, + error: "Token refresh failed: 500 Internal Server Error", + errorCode: "server_error", + }); + + await vi.advanceTimersByTimeAsync(3_599_500); + + await expect(service.getValidAccessToken()).resolves.toMatchObject({ + accessToken: "current-access-token", + }); + expect(oauthFlow.refreshToken).toHaveBeenCalledTimes(3); + expect(service.getState().status).toBe("authenticated"); + } finally { + vi.useRealTimers(); + } + }); + + it("does not use the current access token when refresh token auth fails", async () => { + vi.useFakeTimers(); + try { + oauthFlow.startFlow.mockResolvedValue( + mockTokenResponse({ + accessToken: "current-access-token", + refreshToken: "current-refresh-token", + }), + ); + stubAuthFetch(); + + await service.initialize(); + await service.login("us"); + + oauthFlow.refreshToken.mockReset(); + oauthFlow.refreshToken.mockResolvedValue({ + success: false, + error: "Token revoked", + errorCode: "auth_error", + }); + + await vi.advanceTimersByTimeAsync(3_599_500); + + await expect(service.getValidAccessToken()).rejects.toThrow( + "Token revoked", + ); + expect(service.getState().status).toBe("anonymous"); + expect(sessionPort.getCurrent()).toBeNull(); + } finally { + vi.useRealTimers(); + } + }); }); describe("transient org fetch failures", () => { diff --git a/packages/core/src/auth/auth.ts b/packages/core/src/auth/auth.ts index b7b47b424c..77e5111734 100644 --- a/packages/core/src/auth/auth.ts +++ b/packages/core/src/auth/auth.ts @@ -487,12 +487,13 @@ export class AuthService extends TypedEventEmitter { private async ensureValidSession( forceRefresh = false, ): Promise { + const currentSession = this.session; if ( - this.session && + currentSession && !forceRefresh && - !this.isSessionExpiring(this.session) + !this.isSessionExpiring(currentSession) ) { - return this.session; + return currentSession; } if (this.refreshPromise) { @@ -502,7 +503,24 @@ export class AuthService extends TypedEventEmitter { const sessionInput = this.getSessionInputForRefresh(); const refreshAndSync = async (): Promise => { - const session = await this.refreshSession(sessionInput); + let session: InMemorySession; + try { + session = await this.refreshSession(sessionInput); + } catch (error) { + if ( + currentSession && + this.session === currentSession && + !forceRefresh && + !this.isSessionExpired(currentSession) + ) { + this.logger.warn( + "Preemptive session refresh failed; using current access token", + { error }, + ); + return currentSession; + } + throw error; + } await this.syncAuthenticatedSession(session); return session; }; @@ -833,6 +851,9 @@ export class AuthService extends TypedEventEmitter { private isSessionExpiring(session: InMemorySession): boolean { return session.accessTokenExpiresAt - Date.now() <= TOKEN_EXPIRY_SKEW_MS; } + private isSessionExpired(session: InMemorySession): boolean { + return session.accessTokenExpiresAt <= Date.now(); + } private async fetchUserContext( accessToken: string, cloudRegion: CloudRegion, diff --git a/packages/ui/src/features/mcp-servers/hooks/useMcpServers.test.tsx b/packages/ui/src/features/mcp-servers/hooks/useMcpServers.test.tsx new file mode 100644 index 0000000000..a61175365f --- /dev/null +++ b/packages/ui/src/features/mcp-servers/hooks/useMcpServers.test.tsx @@ -0,0 +1,112 @@ +import type { McpRecommendedServer } from "@posthog/api-client/posthog-client"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { act, renderHook, waitFor } from "@testing-library/react"; +import type { ReactNode } from "react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const mockClient = vi.hoisted(() => ({ + getMcpServerInstallations: vi.fn(), + getMcpServers: vi.fn(), + installMcpTemplate: vi.fn(), + installCustomMcpServer: vi.fn(), + uninstallMcpServer: vi.fn(), + updateMcpServerInstallation: vi.fn(), + authorizeMcpInstallation: vi.fn(), +})); + +const mockTrpcClient = vi.hoisted(() => ({ + mcpCallback: { + getCallbackUrl: { query: vi.fn() }, + openAndWaitForCallback: { mutate: vi.fn() }, + }, +})); + +const mockTrpc = vi.hoisted(() => ({ + mcpCallback: { + onOAuthComplete: { + subscriptionOptions: vi.fn(() => ({})), + }, + }, +})); + +vi.mock("@posthog/ui/features/auth/authClient", () => ({ + useOptionalAuthenticatedClient: () => mockClient, +})); + +vi.mock("@posthog/host-router/react", () => ({ + useHostTRPC: () => mockTrpc, + useHostTRPCClient: () => mockTrpcClient, +})); + +vi.mock("@trpc/tanstack-react-query", () => ({ + useSubscription: vi.fn(), +})); + +vi.mock("sonner", () => ({ + toast: { + error: vi.fn(), + success: vi.fn(), + }, +})); + +import { useMcpServers } from "./useMcpServers"; + +function wrapper({ children }: { children: ReactNode }) { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, + }); + return ( + {children} + ); +} + +const template = { + id: "granola", + name: "Granola", + auth_type: "oauth", +} as McpRecommendedServer; + +describe("useMcpServers", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockClient.getMcpServerInstallations.mockResolvedValue([]); + mockClient.getMcpServers.mockResolvedValue([]); + mockTrpcClient.mcpCallback.getCallbackUrl.query.mockResolvedValue({ + callbackUrl: "posthog-code://mcp-oauth-complete", + }); + }); + + it("reverts template connect loading state after a failed install", async () => { + let rejectInstall!: (error: Error) => void; + mockClient.installMcpTemplate.mockReturnValue( + new Promise((_resolve, reject) => { + rejectInstall = reject; + }), + ); + + const { result } = renderHook(() => useMcpServers(), { wrapper }); + + act(() => { + result.current.installTemplate(template); + }); + + await waitFor(() => expect(result.current.installingId).toBe("granola")); + await waitFor(() => + expect(mockClient.installMcpTemplate).toHaveBeenCalledWith({ + template_id: "granola", + install_source: "posthog-code", + posthog_code_callback_url: "posthog-code://mcp-oauth-complete", + api_key: undefined, + }), + ); + + await act(async () => { + rejectInstall(new Error("Connection failed")); + }); + + await waitFor(() => expect(result.current.installingId).toBeNull()); + }); +}); diff --git a/packages/ui/src/features/mcp-servers/hooks/useMcpServers.ts b/packages/ui/src/features/mcp-servers/hooks/useMcpServers.ts index 50ca0dac80..758651b0e1 100644 --- a/packages/ui/src/features/mcp-servers/hooks/useMcpServers.ts +++ b/packages/ui/src/features/mcp-servers/hooks/useMcpServers.ts @@ -14,7 +14,7 @@ import { useAuthenticatedMutation } from "@posthog/ui/hooks/useAuthenticatedMuta import { useAuthenticatedQuery } from "@posthog/ui/hooks/useAuthenticatedQuery"; import { useQueryClient } from "@tanstack/react-query"; import { useSubscription } from "@trpc/tanstack-react-query"; -import { useCallback, useMemo, useState } from "react"; +import { useCallback, useMemo } from "react"; import { toast } from "sonner"; export const mcpKeys = { @@ -38,7 +38,6 @@ export function useMcpServers() { const trpc = useHostTRPC(); const trpcClient = useHostTRPCClient(); const oauth = useMemo(() => createOAuthCallback(trpcClient), [trpcClient]); - const [installingId, setInstallingId] = useState(null); const queryClient = useQueryClient(); const { data: installations, isLoading: installationsLoading } = @@ -120,18 +119,15 @@ export function useMcpServers() { toast.error(data.error); } invalidateInstallations(); - setInstallingId(null); }, onError: (error: Error) => { toast.error(error.message || "Failed to connect server"); - setInstallingId(null); }, }, ); const installTemplate = useCallback( (template: McpRecommendedServer, opts?: { api_key?: string }) => { - setInstallingId(template.id); installTemplateMutation.mutate({ template_id: template.id, api_key: opts?.api_key, @@ -140,6 +136,10 @@ export function useMcpServers() { [installTemplateMutation], ); + const installingId = installTemplateMutation.isPending + ? (installTemplateMutation.variables?.template_id ?? null) + : null; + const installCustomMutation = useAuthenticatedMutation( ( client,