diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/runtime.ts b/src/main/presenter/llmProviderPresenter/aiSdk/runtime.ts index 8ae3b368c..56cff686e 100644 --- a/src/main/presenter/llmProviderPresenter/aiSdk/runtime.ts +++ b/src/main/presenter/llmProviderPresenter/aiSdk/runtime.ts @@ -21,7 +21,8 @@ import { } from '@shared/imageGenerationSettings' import { isChatAudioTtsModel, - isStandardTtsModel, + isGeminiGenerateContentTtsModel, + isTtsModelId, isTtsModelConfig, normalizeTtsSettings, ttsFormatToMimeType @@ -34,7 +35,11 @@ import { mapMessagesToModelMessages } from './messageMapper' import { buildProviderOptions } from './providerOptionsMapper' import { ProxyAgent } from 'undici' import { proxyConfig } from '../../proxyConfig' -import { type AiSdkProviderKind, createAiSdkProviderContext } from './providerFactory' +import { + type AiSdkProviderKind, + createAiSdkProviderContext, + normalizeGeminiBaseUrl +} from './providerFactory' import { adaptAiSdkStream } from './streamAdapter' type ImageGenerationProviderPayload = Record @@ -43,6 +48,10 @@ type ImageGenerationRequestOptions = { providerOptions?: Record } +const DEFAULT_GEMINI_TTS_VOICE = 'Kore' +const DEFAULT_GEMINI_PCM_SAMPLE_RATE = 24000 +const DEFAULT_GEMINI_PCM_BITS_PER_SAMPLE = 16 + export interface AiSdkRuntimeContext { providerKind: AiSdkProviderKind provider: LLM_PROVIDER @@ -180,11 +189,91 @@ function shouldUseTtsRuntime( return ( modelConfig.apiEndpoint === ApiEndpointType.AudioSpeech || isTtsModelConfig(modelConfig) || - isStandardTtsModel(modelId) || - isChatAudioTtsModel(modelId) + isTtsModelId(modelId) ) } +function buildGeminiTtsPrompt(text: string, instructions?: string): string { + if (instructions?.trim()) { + return `${instructions.trim()}\n\n${text}`.trim() + } + + return text.trim() +} + +function resolveGeminiTtsBaseUrl(provider: LLM_PROVIDER): string { + const rawBaseUrl = (provider.baseUrl || '').trim() + + if (provider.apiType === 'gemini' || provider.id === 'gemini') { + return normalizeGeminiBaseUrl(rawBaseUrl || undefined) + } + + if (rawBaseUrl) { + try { + const parsed = new URL(rawBaseUrl.includes('://') ? rawBaseUrl : `https://${rawBaseUrl}`) + if (provider.id === 'aihubmix' || /(^|\.)aihubmix\.com$/i.test(parsed.hostname)) { + return normalizeGeminiBaseUrl(`${parsed.origin}/gemini`) + } + } catch { + // Fall through to provider-specific fallback below. + } + } + + if (provider.id === 'aihubmix') { + return normalizeGeminiBaseUrl('https://aihubmix.com/gemini') + } + + return normalizeGeminiBaseUrl(rawBaseUrl || undefined) +} + +function normalizeGeminiTtsResponseAudio( + base64: string, + mimeType: string | undefined +): { base64: string; mimeType: string } { + const normalizedMimeType = (mimeType || '').trim() + const lowerMimeType = normalizedMimeType.toLowerCase() + + if (!lowerMimeType || !(lowerMimeType.includes('l16') || lowerMimeType.includes('audio/pcm'))) { + return { + base64, + mimeType: normalizedMimeType || 'audio/wav' + } + } + + const sampleRate = Number(/(?:rate|samplerate)=(\d+)/i.exec(normalizedMimeType)?.[1]) + const bitsPerSample = Number(/(?:bits|bitspersample)=(\d+)/i.exec(normalizedMimeType)?.[1]) + const pcmBuffer = Buffer.from(base64, 'base64') + const resolvedSampleRate = + Number.isFinite(sampleRate) && sampleRate > 0 ? sampleRate : DEFAULT_GEMINI_PCM_SAMPLE_RATE + const resolvedBitsPerSample = + Number.isFinite(bitsPerSample) && bitsPerSample > 0 + ? bitsPerSample + : DEFAULT_GEMINI_PCM_BITS_PER_SAMPLE + const blockAlign = resolvedBitsPerSample / 8 + const byteRate = resolvedSampleRate * blockAlign + const wavBuffer = Buffer.alloc(44 + pcmBuffer.length) + + wavBuffer.write('RIFF', 0) + wavBuffer.writeUInt32LE(36 + pcmBuffer.length, 4) + wavBuffer.write('WAVE', 8) + wavBuffer.write('fmt ', 12) + wavBuffer.writeUInt32LE(16, 16) + wavBuffer.writeUInt16LE(1, 20) + wavBuffer.writeUInt16LE(1, 22) + wavBuffer.writeUInt32LE(resolvedSampleRate, 24) + wavBuffer.writeUInt32LE(byteRate, 28) + wavBuffer.writeUInt16LE(blockAlign, 32) + wavBuffer.writeUInt16LE(resolvedBitsPerSample, 34) + wavBuffer.write('data', 36) + wavBuffer.writeUInt32LE(pcmBuffer.length, 40) + pcmBuffer.copy(wavBuffer, 44) + + return { + base64: wavBuffer.toString('base64'), + mimeType: 'audio/wav' + } +} + /** * Extracts the text to be synthesized from the last user message in the conversation. */ @@ -279,7 +368,10 @@ async function executeTtsPatternB( const body: Record = { model: modelId, - messages: [{ role: 'user', content: text }], + messages: [ + { role: 'user', content: text }, + { role: 'assistant', content: text } + ], modalities: ['text', 'audio'], audio: { format, @@ -333,6 +425,94 @@ async function executeTtsPatternB( } } +async function executeTtsPatternC( + provider: LLM_PROVIDER, + defaultHeaders: Record, + text: string, + modelId: string, + modelConfig: ModelConfig, + timeout: number | undefined +): Promise<{ base64: string; mimeType: string }> { + const tts = normalizeTtsSettings(modelConfig.tts) + const baseUrl = resolveGeminiTtsBaseUrl(provider) + const requestModelId = modelId.trim().split('/').at(-1) || modelId + const url = `${baseUrl}/models/${encodeURIComponent(requestModelId)}:generateContent` + const body: Record = { + contents: [ + { + role: 'user', + parts: [ + { + text: buildGeminiTtsPrompt(text, tts?.instructions) + } + ] + } + ], + generationConfig: { + responseModalities: ['AUDIO'], + speechConfig: { + voiceConfig: { + prebuiltVoiceConfig: { + voiceName: tts?.voice ?? DEFAULT_GEMINI_TTS_VOICE + } + } + } + } + } + + const controller = new AbortController() + const timeoutId = timeout ? setTimeout(() => controller.abort(), timeout) : undefined + const proxyUrl = proxyConfig.getProxyUrl() + const dispatcher = proxyUrl ? new ProxyAgent(proxyUrl) : undefined + + try { + const fetchInit: RequestInit & { dispatcher?: ProxyAgent } = { + method: 'POST', + headers: { + ...defaultHeaders, + 'Content-Type': 'application/json', + 'x-goog-api-key': provider.oauthToken || provider.apiKey || '' + }, + body: JSON.stringify(body), + signal: controller.signal + } + if (dispatcher) fetchInit.dispatcher = dispatcher + const response = await fetch(url, fetchInit) + + if (!response.ok) { + const errText = await response.text().catch(() => '') + throw new Error(`TTS (gemini) request failed (${response.status}): ${errText}`) + } + + const json = (await response.json()) as { + candidates?: Array<{ + content?: { + parts?: Array<{ + inlineData?: { data?: string; mimeType?: string } + inline_data?: { data?: string; mime_type?: string } + }> + } + }> + } + const firstPart = json.candidates?.[0]?.content?.parts?.find( + (part) => part.inlineData?.data || part.inline_data?.data + ) + const inlineData = firstPart?.inlineData + const legacyInlineData = firstPart?.inline_data + const audioData = inlineData?.data ?? legacyInlineData?.data + if (!audioData) { + throw new Error('TTS response missing inline audio data in candidates[0].content.parts') + } + + return normalizeGeminiTtsResponseAudio( + audioData, + inlineData?.mimeType ?? legacyInlineData?.mime_type + ) + } finally { + if (timeoutId !== undefined) clearTimeout(timeoutId) + } +} + function resolveRequestTimeout(modelConfig: ModelConfig): number | undefined { const timeout = modelConfig.timeout if (typeof timeout !== 'number' || !Number.isFinite(timeout) || timeout <= 0) { @@ -576,17 +756,10 @@ export async function* runAiSdkCoreStream( if (shouldUseTtsRuntime(context, modelId, normalizedModelConfig)) { const text = extractTtsText(messages) const usePatternB = isChatAudioTtsModel(modelId) + const usePatternC = isGeminiGenerateContentTtsModel(modelId) - const { base64, mimeType } = usePatternB - ? await executeTtsPatternB( - context.provider, - context.defaultHeaders, - text, - modelId, - normalizedModelConfig, - timeout - ) - : await executeTtsPatternA( + const { base64, mimeType } = usePatternC + ? await executeTtsPatternC( context.provider, context.defaultHeaders, text, @@ -594,6 +767,23 @@ export async function* runAiSdkCoreStream( normalizedModelConfig, timeout ) + : usePatternB + ? await executeTtsPatternB( + context.provider, + context.defaultHeaders, + text, + modelId, + normalizedModelConfig, + timeout + ) + : await executeTtsPatternA( + context.provider, + context.defaultHeaders, + text, + modelId, + normalizedModelConfig, + timeout + ) const dataUrl = `data:${mimeType};base64,${base64}` const cachedAudio = await presenter.devicePresenter.cacheImage(dataUrl) diff --git a/src/main/presenter/llmProviderPresenter/providers/aiSdkProvider.ts b/src/main/presenter/llmProviderPresenter/providers/aiSdkProvider.ts index 67380a16c..0cb7ec041 100644 --- a/src/main/presenter/llmProviderPresenter/providers/aiSdkProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/aiSdkProvider.ts @@ -7,7 +7,7 @@ import { resolveProviderCapabilityProviderId, type NewApiEndpointType } from '@shared/model' -import { isChatAudioTtsModel, isStandardTtsModel, isTtsModelConfig } from '@shared/ttsSettings' +import { isTtsModelConfig, isTtsModelId } from '@shared/ttsSettings' import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS, @@ -99,8 +99,7 @@ const shouldUseOpenAIImageGenerationRoute = (modelId: string, modelConfig: Model const shouldUseOpenAITtsRoute = (modelId: string, modelConfig: ModelConfig): boolean => isTtsModelConfig(modelConfig) || modelConfig.apiEndpoint === ApiEndpointType.AudioSpeech || - isStandardTtsModel(modelId) || - isChatAudioTtsModel(modelId) + isTtsModelId(modelId) export function normalizeExtractedImageText(content: string): string { const normalized = content diff --git a/src/shared/ttsSettings.ts b/src/shared/ttsSettings.ts index 7f50b1c86..d7f5255aa 100644 --- a/src/shared/ttsSettings.ts +++ b/src/shared/ttsSettings.ts @@ -13,12 +13,15 @@ export interface TtsSettings { /** * Standard OpenAI-style TTS models that use the /audio/speech endpoint (Pattern A). */ -export const OPENAI_STANDARD_TTS_MODELS = [ - 'tts-1', - 'tts-1-hd', - 'gpt-4o-mini-tts', +export const OPENAI_STANDARD_TTS_MODELS = ['tts-1', 'tts-1-hd', 'gpt-4o-mini-tts'] as const + +/** + * Gemini TTS models that use the generateContent endpoint with AUDIO output. + */ +export const GEMINI_GENERATE_CONTENT_TTS_MODELS = [ 'gemini-2.5-flash-preview-tts', - 'gemini-2.5-pro-preview-tts' + 'gemini-2.5-pro-preview-tts', + 'gemini-3.1-flash-tts-preview' ] as const /** @@ -42,6 +45,14 @@ export function isStandardTtsModel(modelId: string): boolean { return (OPENAI_STANDARD_TTS_MODELS as readonly string[]).includes(id) } +/** + * Returns true if the model uses the Gemini generateContent endpoint for TTS. + */ +export function isGeminiGenerateContentTtsModel(modelId: string): boolean { + const id = normalizeTtsModelId(modelId) + return (GEMINI_GENERATE_CONTENT_TTS_MODELS as readonly string[]).includes(id) +} + /** * Returns true if the model produces TTS audio via the chat completions endpoint (Pattern B). */ @@ -57,7 +68,11 @@ export function isChatAudioTtsModel(modelId: string): boolean { * Returns true if the model is any kind of TTS model (either pattern). */ export function isTtsModelId(modelId: string): boolean { - return isStandardTtsModel(modelId) || isChatAudioTtsModel(modelId) + return ( + isStandardTtsModel(modelId) || + isChatAudioTtsModel(modelId) || + isGeminiGenerateContentTtsModel(modelId) + ) } /** diff --git a/test/main/presenter/llmProviderPresenter/aiSdkRuntime.test.ts b/test/main/presenter/llmProviderPresenter/aiSdkRuntime.test.ts index e6481ff83..683749d26 100644 --- a/test/main/presenter/llmProviderPresenter/aiSdkRuntime.test.ts +++ b/test/main/presenter/llmProviderPresenter/aiSdkRuntime.test.ts @@ -1,4 +1,4 @@ -import { beforeEach, describe, expect, it, vi } from 'vitest' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' const { mockGenerateImage, @@ -31,7 +31,20 @@ vi.mock('@/presenter', () => ({ })) vi.mock('@/presenter/llmProviderPresenter/aiSdk/providerFactory', () => ({ - createAiSdkProviderContext: mockCreateAiSdkProviderContext + createAiSdkProviderContext: mockCreateAiSdkProviderContext, + normalizeGeminiBaseUrl: vi.fn((baseUrl?: string) => { + const normalized = (baseUrl || '').trim().replace(/\/+$/, '') + if (!normalized) { + return 'https://generativelanguage.googleapis.com/v1beta' + } + if (/\/v1beta1$/i.test(normalized) || /\/v1beta$/i.test(normalized)) { + return normalized + } + if (/\/v1$/i.test(normalized)) { + return normalized.replace(/\/v1$/i, '/v1beta') + } + return `${normalized}/v1beta` + }) })) import { @@ -73,6 +86,10 @@ describe('AI SDK runtime', () => { mockCacheImage.mockResolvedValue('cached://image') }) + afterEach(() => { + vi.unstubAllGlobals() + }) + it('builds image prompts from text-like content instead of object stringification', async () => { const context = { providerKind: 'openai-compatible', @@ -350,6 +367,197 @@ describe('AI SDK runtime', () => { expect(request).not.toHaveProperty('providerOptions') }) + it('includes an assistant role message for chat-audio TTS requests', async () => { + const fetchMock = vi.fn().mockResolvedValue( + new Response( + JSON.stringify({ + choices: [ + { + message: { + audio: { + data: 'ZmFrZS1hdWRpby1iYXNlNjQ=' + } + } + } + ] + }), + { + status: 200, + headers: { + 'Content-Type': 'application/json' + } + } + ) + ) + vi.stubGlobal('fetch', fetchMock) + + const context = { + providerKind: 'openai-compatible', + provider: { + id: 'xiaomimimo', + apiType: 'openai-compatible', + baseUrl: 'https://example.com/v1', + apiKey: 'test-key' + }, + configPresenter: {}, + defaultHeaders: {}, + shouldUseTts: () => true + } as any + + const events = [] + for await (const event of runAiSdkCoreStream( + context, + [{ role: 'user', content: 'hello tts' }], + 'mimo-v2.5-tts', + { + apiEndpoint: 'chat', + tts: { + responseFormat: 'wav', + voice: 'alloy' + } + } as any, + 0.7, + 1024, + [] + )) { + events.push(event) + } + + expect(fetchMock).toHaveBeenCalledTimes(1) + expect(fetchMock.mock.calls[0]?.[0]).toBe('https://example.com/v1/chat/completions') + + const requestInit = fetchMock.mock.calls[0]?.[1] as RequestInit + const payload = JSON.parse(String(requestInit.body)) as { + messages?: Array<{ role?: string; content?: string }> + } + expect(payload.messages).toEqual([ + { role: 'user', content: 'hello tts' }, + { role: 'assistant', content: 'hello tts' } + ]) + + expect(events).toEqual([ + { + type: 'image_data', + image_data: { + data: 'cached://image', + mimeType: 'audio/wav' + } + }, + { + type: 'stop', + stop_reason: 'complete' + } + ]) + }) + + it('uses Gemini generateContent compatibility mode for AIHubMix Gemini TTS models', async () => { + const pcmBase64 = Buffer.from([0, 0, 255, 127]).toString('base64') + const fetchMock = vi.fn().mockResolvedValue( + new Response( + JSON.stringify({ + candidates: [ + { + content: { + parts: [ + { + inlineData: { + mimeType: 'audio/L16;rate=24000', + data: pcmBase64 + } + } + ] + } + } + ] + }), + { + status: 200, + headers: { + 'Content-Type': 'application/json' + } + } + ) + ) + vi.stubGlobal('fetch', fetchMock) + + const context = { + providerKind: 'openai-compatible', + provider: { + id: 'aihubmix', + apiType: 'openai-compatible', + baseUrl: 'https://aihubmix.com/v1', + apiKey: 'test-key' + }, + configPresenter: {}, + defaultHeaders: { + 'APP-Code': 'SMUE7630' + }, + shouldUseTts: () => true + } as any + + const events = [] + for await (const event of runAiSdkCoreStream( + context, + [{ role: 'user', content: 'Have a wonderful day!' }], + 'gemini-2.5-flash-preview-tts', + { + apiEndpoint: 'audio-speech', + tts: { + voice: 'Kore', + instructions: 'Say cheerfully:' + } + } as any, + 0.7, + 1024, + [] + )) { + events.push(event) + } + + expect(fetchMock).toHaveBeenCalledTimes(1) + expect(fetchMock.mock.calls[0]?.[0]).toBe( + 'https://aihubmix.com/gemini/v1beta/models/gemini-2.5-flash-preview-tts:generateContent' + ) + + const requestInit = fetchMock.mock.calls[0]?.[1] as RequestInit + const headers = new Headers(requestInit.headers) + expect(headers.get('x-goog-api-key')).toBe('test-key') + expect(headers.get('Authorization')).toBeNull() + + const payload = JSON.parse(String(requestInit.body)) as { + contents?: Array<{ parts?: Array<{ text?: string }> }> + generationConfig?: { + responseModalities?: string[] + speechConfig?: { + voiceConfig?: { + prebuiltVoiceConfig?: { + voiceName?: string + } + } + } + } + } + expect(payload.contents?.[0]?.parts?.[0]?.text).toBe('Say cheerfully:\n\nHave a wonderful day!') + expect(payload.generationConfig?.responseModalities).toEqual(['AUDIO']) + expect( + payload.generationConfig?.speechConfig?.voiceConfig?.prebuiltVoiceConfig?.voiceName + ).toBe('Kore') + + expect(events).toEqual([ + { + type: 'image_data', + image_data: { + data: 'cached://image', + mimeType: 'audio/wav' + } + }, + { + type: 'stop', + stop_reason: 'complete' + } + ]) + }) + it('omits temperature for anthropic models that disable temperature control', async () => { const tracePayloads: Array<{ body?: Record }> = [] const context = {