From b0ec624550b7c25b96d98d23c4ffe163758488c5 Mon Sep 17 00:00:00 2001 From: James Grugett Date: Thu, 30 Apr 2026 12:15:53 -0700 Subject: [PATCH] Add Fireworks fallback for CanopyWave --- .../completions/__tests__/completions.test.ts | 153 ++++++++++++++++++ web/src/app/api/v1/chat/completions/_post.ts | 94 ++++++++--- 2 files changed, 227 insertions(+), 20 deletions(-) diff --git a/web/src/app/api/v1/chat/completions/__tests__/completions.test.ts b/web/src/app/api/v1/chat/completions/__tests__/completions.test.ts index cf846131c..715de7a7c 100644 --- a/web/src/app/api/v1/chat/completions/__tests__/completions.test.ts +++ b/web/src/app/api/v1/chat/completions/__tests__/completions.test.ts @@ -6,6 +6,7 @@ import { FREEBUFF_GLM_MODEL_ID, isFreebuffDeploymentHours, } from '@codebuff/common/constants/freebuff-models' +import { env } from '@codebuff/internal/env' import { formatQuotaResetCountdown, postChatCompletions } from '../_post' import { checkFreeModeRateLimit, @@ -1075,6 +1076,116 @@ describe('/api/v1/chat/completions POST endpoint', () => { }) describe('Successful responses', () => { + const withCanopyWaveApiKey = async (testFn: () => Promise) => { + const previousCanopyWaveApiKey = env.CANOPYWAVE_API_KEY + env.CANOPYWAVE_API_KEY = 'test' + try { + await testFn() + } finally { + env.CANOPYWAVE_API_KEY = previousCanopyWaveApiKey + } + } + + const createCanopyWaveFallbackRequest = (stream: boolean) => + new NextRequest('http://localhost:3000/api/v1/chat/completions', { + method: 'POST', + headers: { Authorization: 'Bearer test-api-key-123' }, + body: JSON.stringify({ + model: 'minimax/minimax-m2.5', + stream, + codebuff_metadata: { + run_id: 'run-123', + client_id: 'test-client-id-123', + client_request_id: 'test-client-session-id-123', + }, + }), + }) + + const createCanopyWaveNoWorkersThenFireworksFetch = (stream: boolean) => { + const fetchedBodies: Record[] = [] + const fetch = mock( + async (_url: string | URL | Request, init?: RequestInit) => { + fetchedBodies.push(JSON.parse(init?.body as string)) + + if (fetchedBodies.length === 1) { + return Response.json( + { + error: { + message: 'No available workers', + code: 'no_available_workers', + }, + }, + { status: 503 }, + ) + } + + if (!stream) { + return Response.json({ + id: 'test-id', + model: 'accounts/fireworks/models/minimax-m2p5', + choices: [{ message: { content: 'fireworks response' } }], + usage: { + prompt_tokens: 10, + completion_tokens: 20, + total_tokens: 30, + }, + }) + } + + const encoder = new TextEncoder() + const fireworksStream = new ReadableStream({ + start(controller) { + controller.enqueue( + encoder.encode( + 'data: {"id":"test-id","model":"accounts/fireworks/models/minimax-m2p5","choices":[{"delta":{"content":"test"}}]}\n\n', + ), + ) + controller.enqueue(encoder.encode('data: [DONE]\n\n')) + controller.close() + }, + }) + + return new Response(fireworksStream, { + status: 200, + headers: { 'Content-Type': 'text/event-stream' }, + }) + }, + ) as unknown as typeof globalThis.fetch + + return { fetch, fetchedBodies } + } + + const postCanopyWaveFallbackRequest = async ({ + fetch, + stream, + }: { + fetch: typeof globalThis.fetch + stream: boolean + }) => + postChatCompletions({ + req: createCanopyWaveFallbackRequest(stream), + getUserInfoFromApiKey: mockGetUserInfoFromApiKey, + logger: mockLogger, + trackEvent: mockTrackEvent, + getUserUsageData: mockGetUserUsageData, + getAgentRunFromId: mockGetAgentRunFromId, + fetch, + insertMessageBigquery: mockInsertMessageBigquery, + loggerWithContext: mockLoggerWithContext, + checkSessionAdmissible: mockCheckSessionAdmissibleAllow, + }) + + const expectCanopyWaveThenFireworks = ( + fetchedBodies: Record[], + ) => { + expect(fetchedBodies).toHaveLength(2) + expect(fetchedBodies[0].model).toBe('minimax/minimax-m2.5') + expect(fetchedBodies[1].model).toBe( + 'accounts/fireworks/models/minimax-m2p5', + ) + expect(mockLogger.warn).toHaveBeenCalled() + } + it('returns stream with correct headers', async () => { const req = new NextRequest( 'http://localhost:3000/api/v1/chat/completions', @@ -1158,6 +1269,48 @@ describe('/api/v1/chat/completions POST endpoint', () => { }, FETCH_PATH_TEST_TIMEOUT_MS, ) + + it( + 'falls back to Fireworks when CanopyWave has no available workers for non-streaming requests', + async () => { + await withCanopyWaveApiKey(async () => { + const { fetch, fetchedBodies } = + createCanopyWaveNoWorkersThenFireworksFetch(false) + const response = await postCanopyWaveFallbackRequest({ + fetch, + stream: false, + }) + + expect(response.status).toBe(200) + expectCanopyWaveThenFireworks(fetchedBodies) + + const body = await response.json() + expect(body.model).toBe('minimax/minimax-m2.5') + expect(body.provider).toBe('Fireworks') + expect(body.choices[0].message.content).toBe('fireworks response') + }) + }, + FETCH_PATH_TEST_TIMEOUT_MS, + ) + + it( + 'falls back to Fireworks when CanopyWave has no available workers for streaming requests', + async () => { + await withCanopyWaveApiKey(async () => { + const { fetch, fetchedBodies } = + createCanopyWaveNoWorkersThenFireworksFetch(true) + const response = await postCanopyWaveFallbackRequest({ + fetch, + stream: true, + }) + + expect(response.status).toBe(200) + expect(response.headers.get('Content-Type')).toBe('text/event-stream') + expectCanopyWaveThenFireworks(fetchedBodies) + }) + }, + FETCH_PATH_TEST_TIMEOUT_MS, + ) }) describe('Subscription limit enforcement', () => { diff --git a/web/src/app/api/v1/chat/completions/_post.ts b/web/src/app/api/v1/chat/completions/_post.ts index 0a7771d46..8a4d620e0 100644 --- a/web/src/app/api/v1/chat/completions/_post.ts +++ b/web/src/app/api/v1/chat/completions/_post.ts @@ -109,6 +109,50 @@ export const formatQuotaResetCountdown = ( return `in ${pluralize(minutes, 'minute')}` } +type ProviderHandlerArgs = Parameters[0] +type ProviderHandler = (args: ProviderHandlerArgs) => Promise + +function shouldFallbackCanopyWaveToFireworks( + error: unknown, + model: string, +): error is CanopyWaveError { + if (!(error instanceof CanopyWaveError) || !isFireworksModel(model)) { + return false + } + const message = error.errorBody.error.message.toLowerCase() + return ( + error.statusCode === 429 || + error.statusCode >= 500 || + message.includes('no available workers') + ) +} + +async function handleCanopyWaveWithFireworksFallback( + args: ProviderHandlerArgs, + handleCanopyWave: ProviderHandler, + handleFireworks: ProviderHandler, +): Promise { + try { + return await handleCanopyWave(args) + } catch (error) { + if (!shouldFallbackCanopyWaveToFireworks(error, args.body.model)) { + throw error + } + + args.logger.warn( + { + error: getErrorObject(error), + model: args.body.model, + providerStatusCode: error.statusCode, + providerStatusText: error.statusText, + }, + 'CanopyWave request failed, falling back to Fireworks', + ) + + return handleFireworks(args) + } +} + export type CheckSessionAdmissibleFn = typeof checkSessionAdmissible type GateRejectCode = Extract['code'] @@ -599,7 +643,8 @@ export async function postChatCompletions(params: { if (bodyStream) { // Streaming request — route to SiliconFlow/CanopyWave/Fireworks for supported models const useSiliconFlow = false // isSiliconFlowModel(typedBody.model) - const useCanopyWave = isCanopyWaveModel(typedBody.model) + const useCanopyWave = + !!env.CANOPYWAVE_API_KEY && isCanopyWaveModel(typedBody.model) const useFireworks = !useCanopyWave && isFireworksModel(typedBody.model) const useOpenAIDirect = !useCanopyWave && @@ -616,15 +661,19 @@ export async function postChatCompletions(params: { insertMessageBigquery, }) : useCanopyWave - ? await handleCanopyWaveStream({ - body: typedBody, - userId, - stripeCustomerId, - agentId, - fetch, - logger, - insertMessageBigquery, - }) + ? await handleCanopyWaveWithFireworksFallback( + { + body: typedBody, + userId, + stripeCustomerId, + agentId, + fetch, + logger, + insertMessageBigquery, + }, + handleCanopyWaveStream, + handleFireworksStream, + ) : useFireworks ? await handleFireworksStream({ body: typedBody, @@ -678,7 +727,8 @@ export async function postChatCompletions(params: { // Non-streaming request — route to SiliconFlow/CanopyWave/Fireworks for supported models const model = typedBody.model const useSiliconFlow = false // isSiliconFlowModel(model) - const useCanopyWave = isCanopyWaveModel(model) + const useCanopyWave = + !!env.CANOPYWAVE_API_KEY && isCanopyWaveModel(model) const useFireworks = !useCanopyWave && isFireworksModel(model) const shouldUseOpenAIEndpoint = !useCanopyWave && !useFireworks && isOpenAIDirectModel(model) @@ -694,15 +744,19 @@ export async function postChatCompletions(params: { insertMessageBigquery, }) : useCanopyWave - ? handleCanopyWaveNonStream({ - body: typedBody, - userId, - stripeCustomerId, - agentId, - fetch, - logger, - insertMessageBigquery, - }) + ? handleCanopyWaveWithFireworksFallback( + { + body: typedBody, + userId, + stripeCustomerId, + agentId, + fetch, + logger, + insertMessageBigquery, + }, + handleCanopyWaveNonStream, + handleFireworksNonStream, + ) : useFireworks ? handleFireworksNonStream({ body: typedBody,