diff --git a/codex-cli/src/cli.tsx b/codex-cli/src/cli.tsx
index c7e5d9ff3..d24672a99 100644
--- a/codex-cli/src/cli.tsx
+++ b/codex-cli/src/cli.tsx
@@ -43,6 +43,7 @@ import {
 } from "./utils/get-api-key";
 import { createInputItem } from "./utils/input-utils";
 import { initLogger } from "./utils/logger/log";
+import { getModelForProvider } from "./utils/model-selection";
 import { isModelSupportedForResponses } from "./utils/model-utils.js";
 import { parseToolCall } from "./utils/parsers";
 import { onExit, setInkRenderer } from "./utils/terminal";
@@ -289,10 +290,16 @@ let config = loadConfig(undefined, undefined, {
 // via the `--history` flag. Therefore it must be declared with `let` rather
 // than `const`.
 let prompt = cli.input[0];
-const model = cli.flags.model ?? config.model;
-const imagePaths = cli.flags.image;
 const provider = cli.flags.provider ?? config.provider ?? "openai";
+// Get the appropriate model for the selected provider
+const model = getModelForProvider(config, {
+  model: cli.flags.model,
+  provider,
+});
+
+const imagePaths = cli.flags.image;
+
 
 const client = {
   issuer: "https://auth.openai.com",
   client_id: "app_EMoamEEZ73f0CkXaXp7hrann",
diff --git a/codex-cli/src/components/chat/terminal-chat.tsx b/codex-cli/src/components/chat/terminal-chat.tsx
index d41a94990..03e1cef92 100644
--- a/codex-cli/src/components/chat/terminal-chat.tsx
+++ b/codex-cli/src/components/chat/terminal-chat.tsx
@@ -639,11 +639,17 @@ export default function TerminalChat({
           prev && newModel !== model ? null : prev,
         );
 
-        // Save model to config
+        // Save model to config and update the default model for this provider
+        const providerDefaultModels = {
+          ...(config.providerDefaultModels || {}),
+          [provider]: newModel,
+        };
+
         saveConfig({
           ...config,
           model: newModel,
           provider: provider,
+          providerDefaultModels,
         });
 
         setItems((prev) => [
@@ -674,7 +680,8 @@
         setLoading(false);
 
         // Select default model for the new provider.
-        const defaultModel = model;
+        const defaultModel =
+          config.providerDefaultModels?.[newProvider] || model;
 
         // Save provider to config.
         const updatedConfig = {
diff --git a/codex-cli/src/utils/config.ts b/codex-cli/src/utils/config.ts
index 51761bf6d..5c628a732 100644
--- a/codex-cli/src/utils/config.ts
+++ b/codex-cli/src/utils/config.ts
@@ -149,7 +149,18 @@ export type StoredConfig = {
   /** Disable server-side response storage (send full transcript each request) */
   disableResponseStorage?: boolean;
   flexMode?: boolean;
-  providers?: Record<string, { name: string; baseURL: string; envKey: string }>;
+  providers?: Record<
+    string,
+    {
+      name: string;
+      baseURL: string;
+      envKey: string;
+      /** Default model to use for this provider */
+      defaultModel?: string;
+    }
+  >;
+  /** Map of provider IDs to their default models */
+  providerDefaultModels?: Record<string, string>;
   history?: {
     maxSize?: number;
     saveHistory?: boolean;
@@ -202,7 +213,18 @@ export type AppConfig = {
 
   /** Enable the "flex-mode" processing mode for supported models (o3, o4-mini) */
   flexMode?: boolean;
-  providers?: Record<string, { name: string; baseURL: string; envKey: string }>;
+  providers?: Record<
+    string,
+    {
+      name: string;
+      baseURL: string;
+      envKey: string;
+      /** Default model to use for this provider */
+      defaultModel?: string;
+    }
+  >;
+  /** Map of provider IDs to their default models */
+  providerDefaultModels?: Record<string, string>;
   history?: {
     maxSize: number;
     saveHistory: boolean;
@@ -439,6 +461,7 @@ export const loadConfig = (
     disableResponseStorage: storedConfig.disableResponseStorage === true,
     reasoningEffort: storedConfig.reasoningEffort,
     fileOpener: storedConfig.fileOpener,
+    providerDefaultModels: storedConfig.providerDefaultModels,
   };
 
   // -----------------------------------------------------------------------
@@ -560,6 +583,7 @@ export const saveConfig = (
     disableResponseStorage: config.disableResponseStorage,
     flexMode: config.flexMode,
     reasoningEffort: config.reasoningEffort,
+    providerDefaultModels: config.providerDefaultModels,
   };
 
   // Add history settings if they exist
diff --git a/codex-cli/src/utils/model-selection.ts b/codex-cli/src/utils/model-selection.ts
new file mode 100644
index 000000000..7645ad9cd
--- /dev/null
+++ b/codex-cli/src/utils/model-selection.ts
@@ -0,0 +1,34 @@
+import type { AppConfig } from "./config";
+
+/**
+ * Get the appropriate model for the selected provider
+ *
+ * @param config The application configuration
+ * @param cliFlags The CLI flags containing model and provider information
+ * @returns The model to use
+ */
+export function getModelForProvider(
+  config: AppConfig,
+  cliFlags: { model?: string; provider?: string },
+): string {
+  // CLI model flag takes precedence
+  if (cliFlags.model) {
+    return cliFlags.model;
+  }
+
+  // If provider is specified and there's a default model for it in providerDefaultModels, use that
+  if (cliFlags.provider && config.providerDefaultModels?.[cliFlags.provider]) {
+    return config.providerDefaultModels[cliFlags.provider] as string;
+  }
+
+  // If provider is specified and there's a provider config with defaultModel, use that
+  if (
+    cliFlags.provider &&
+    config.providers?.[cliFlags.provider]?.defaultModel
+  ) {
+    return config.providers[cliFlags.provider]?.defaultModel as string;
+  }
+
+  // Fall back to global default model
+  return config.model;
+}
diff --git a/codex-cli/tests/config.test.tsx b/codex-cli/tests/config.test.tsx
index 55c2297fc..030bad364 100644
--- a/codex-cli/tests/config.test.tsx
+++ b/codex-cli/tests/config.test.tsx
@@ -361,3 +361,268 @@ test("loads and saves custom shell config", () => {
   expect(reloadedConfig.tools?.shell?.maxBytes).toBe(updatedMaxBytes);
   expect(reloadedConfig.tools?.shell?.maxLines).toBe(updatedMaxLines);
 });
+
+test("loads and saves provider default models correctly", () => {
+  // Setup config with provider default models
+  const providerDefaultModels = {
+    openai: "o4-mini",
+    ollama: "qwen2.5-coder:14b",
+    mistral: "mistral-large-latest",
+  };
+
+  memfs[testConfigPath] = JSON.stringify(
+    {
+      model: "mymodel",
+      providerDefaultModels,
+    },
+    null,
+    2,
+  );
+  memfs[testInstructionsPath] = "test instructions";
+
+  // Load config and verify provider default models
+  const loadedConfig = loadConfig(testConfigPath, testInstructionsPath, {
+    disableProjectDoc: true,
+  });
+
+  // Check provider default models were loaded correctly
+  expect(loadedConfig.providerDefaultModels).toEqual(providerDefaultModels);
+
+  // Modify provider default models and save
+  const updatedProviderDefaultModels = {
+    ...providerDefaultModels,
+    openai: "gpt-4.1",
+    groq: "llama3-70b",
+  };
+
+  const updatedConfig = {
+    ...loadedConfig,
+    providerDefaultModels: updatedProviderDefaultModels,
+  };
+
+  saveConfig(updatedConfig, testConfigPath, testInstructionsPath);
+
+  // Verify saved config contains updated provider default models
+  expect(memfs[testConfigPath]).toContain(`"providerDefaultModels"`);
+  expect(memfs[testConfigPath]).toContain(`"gpt-4.1"`);
+  expect(memfs[testConfigPath]).toContain(`"llama3-70b"`);
+
+  // Load again and verify updated values
+  const reloadedConfig = loadConfig(testConfigPath, testInstructionsPath, {
+    disableProjectDoc: true,
+  });
+
+  expect(reloadedConfig.providerDefaultModels).toEqual(
+    updatedProviderDefaultModels,
+  );
+});
+
+test("handles empty providerDefaultModels correctly", () => {
+  // Setup config with empty providerDefaultModels
+  memfs[testConfigPath] = JSON.stringify(
+    {
+      model: "mymodel",
+      providerDefaultModels: {},
+    },
+    null,
+    2,
+  );
+  memfs[testInstructionsPath] = "test instructions";
+
+  // Load config and verify empty providerDefaultModels
+  const loadedConfig = loadConfig(testConfigPath, testInstructionsPath, {
+    disableProjectDoc: true,
+  });
+
+  // Check providerDefaultModels is an empty object
+  expect(loadedConfig.providerDefaultModels).toEqual({});
+
+  // Add provider default models and save
+  const updatedConfig = {
+    ...loadedConfig,
+    providerDefaultModels: {
+      openai: "o4-mini",
+      ollama: "qwen2.5-coder:14b",
+    },
+  };
+
+  saveConfig(updatedConfig, testConfigPath, testInstructionsPath);
+
+  // Load again and verify updated values
+  const reloadedConfig = loadConfig(testConfigPath, testInstructionsPath, {
+    disableProjectDoc: true,
+  });
+
+  expect(reloadedConfig.providerDefaultModels).toEqual({
+    openai: "o4-mini",
+    ollama: "qwen2.5-coder:14b",
+  });
+});
+
+test("loads and saves provider with defaultModel property correctly", () => {
+  // Setup config with providers that have defaultModel property
+  const customProviders = {
+    openai: {
+      name: "OpenAI",
+      baseURL: "https://api.openai.com/v1",
+      envKey: "OPENAI_API_KEY",
+      defaultModel: "o4-mini",
+    },
+    ollama: {
+      name: "Ollama",
+      baseURL: "http://localhost:11434/v1",
+      envKey: "OLLAMA_API_KEY",
+      defaultModel: "qwen2.5-coder:14b",
+    },
+  };
+
+  memfs[testConfigPath] = JSON.stringify(
+    {
+      model: "mymodel",
+      providers: customProviders,
+    },
+    null,
+    2,
+  );
+  memfs[testInstructionsPath] = "test instructions";
+
+  // Load config and verify providers with defaultModel
+  const loadedConfig = loadConfig(testConfigPath, testInstructionsPath, {
+    disableProjectDoc: true,
+  });
+
+  // Check providers were loaded correctly with defaultModel
+  expect(loadedConfig.providers?.["openai"]?.defaultModel).toBe("o4-mini");
+  expect(loadedConfig.providers?.["ollama"]?.defaultModel).toBe(
+    "qwen2.5-coder:14b",
+  );
+
+  // Modify providers and save
+  const updatedProviders = {
+    ...loadedConfig.providers,
+    openai: {
+      name: "OpenAI",
+      baseURL: "https://api.openai.com/v1",
+      envKey: "OPENAI_API_KEY",
+      defaultModel: "gpt-4.1",
+    },
+    mistral: {
+      name: "Mistral",
+      baseURL: "https://api.mistral.ai/v1",
+      envKey: "MISTRAL_API_KEY",
+      defaultModel: "mistral-large-latest",
+    },
+  };
+
+  const updatedConfig = {
+    ...loadedConfig,
+    providers: updatedProviders,
+  };
+
+  saveConfig(updatedConfig, testConfigPath, testInstructionsPath);
+
+  // Verify saved config contains updated providers with defaultModel
+  expect(memfs[testConfigPath]).toContain(`"defaultModel": "gpt-4.1"`);
+  expect(memfs[testConfigPath]).toContain(
+    `"defaultModel": "mistral-large-latest"`,
+  );
+
+  // Load again and verify updated values
+  const reloadedConfig = loadConfig(testConfigPath, testInstructionsPath, {
+    disableProjectDoc: true,
+  });
+
+  expect(reloadedConfig.providers?.["openai"]?.defaultModel).toBe("gpt-4.1");
+  expect(reloadedConfig.providers?.["mistral"]?.defaultModel).toBe(
+    "mistral-large-latest",
+  );
+  expect(reloadedConfig.providers?.["ollama"]?.defaultModel).toBe(
+    "qwen2.5-coder:14b",
+  );
+});
+
+test("handles both providerDefaultModels and provider.defaultModel correctly", () => {
+  // Setup config with both providerDefaultModels and provider.defaultModel
+  const providerDefaultModels = {
+    openai: "o4-mini",
+    ollama: "qwen2.5-coder:14b",
+  };
+
+  const customProviders = {
+    openai: {
+      name: "OpenAI",
+      baseURL: "https://api.openai.com/v1",
+      envKey: "OPENAI_API_KEY",
+      defaultModel: "gpt-4.1", // This should take precedence in the provider object
+    },
+    mistral: {
+      name: "Mistral",
+      baseURL: "https://api.mistral.ai/v1",
+      envKey: "MISTRAL_API_KEY",
+      defaultModel: "mistral-large-latest",
+    },
+  };
+
+  memfs[testConfigPath] = JSON.stringify(
+    {
+      model: "mymodel",
+      providerDefaultModels,
+      providers: customProviders,
+    },
+    null,
+    2,
+  );
+  memfs[testInstructionsPath] = "test instructions";
+
+  // Load config and verify both types of default models
+  const loadedConfig = loadConfig(testConfigPath, testInstructionsPath, {
+    disableProjectDoc: true,
+  });
+
+  // Check both types of default models were loaded correctly
+  expect(loadedConfig.providerDefaultModels?.["openai"]).toBe("o4-mini");
+  expect(loadedConfig.providerDefaultModels?.["ollama"]).toBe(
+    "qwen2.5-coder:14b",
+  );
+  expect(loadedConfig.providers?.["openai"]?.defaultModel).toBe("gpt-4.1");
+  expect(loadedConfig.providers?.["mistral"]?.defaultModel).toBe(
+    "mistral-large-latest",
+  );
+});
+
+test("handles missing providerDefaultModels correctly", () => {
+  // Setup config without providerDefaultModels
+  memfs[testConfigPath] = JSON.stringify(
+    {
+      model: "mymodel",
+    },
+    null,
+    2,
+  );
+  memfs[testInstructionsPath] = "test instructions";
+
+  // Load config and verify providerDefaultModels is undefined
+  const loadedConfig = loadConfig(testConfigPath, testInstructionsPath, {
+    disableProjectDoc: true,
+  });
+
+  // Check providerDefaultModels is undefined
+  expect(loadedConfig.providerDefaultModels).toBeUndefined();
+
+  // Add provider default models and save
+  const updatedConfig = {
+    ...loadedConfig,
+    providerDefaultModels: {
+      openai: "o4-mini",
+    },
+  };
+
+  saveConfig(updatedConfig, testConfigPath, testInstructionsPath);
+
+  // Load again and verify updated values
+  const reloadedConfig = loadConfig(testConfigPath, testInstructionsPath, {
+    disableProjectDoc: true,
+  });
+
+  expect(reloadedConfig.providerDefaultModels?.["openai"]).toBe("o4-mini");
+});
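
Illustration, not part of the patch: a minimal sketch of the resolution order that getModelForProvider implements, in decreasing precedence: explicit --model flag, then providerDefaultModels for the selected provider, then the provider entry's own defaultModel, then the global model. The config values below are hypothetical.

  import type { AppConfig } from "./utils/config";
  import { getModelForProvider } from "./utils/model-selection";

  // Only the fields relevant to model selection are shown; the double cast
  // keeps the sketch short, since AppConfig has further required fields.
  const config = {
    model: "o4-mini",
    providerDefaultModels: { ollama: "qwen2.5-coder:14b" },
    providers: {
      mistral: {
        name: "Mistral",
        baseURL: "https://api.mistral.ai/v1",
        envKey: "MISTRAL_API_KEY",
        defaultModel: "mistral-large-latest",
      },
    },
  } as unknown as AppConfig;

  getModelForProvider(config, { model: "gpt-4.1", provider: "ollama" }); // "gpt-4.1" (CLI flag wins)
  getModelForProvider(config, { provider: "ollama" }); // "qwen2.5-coder:14b" (providerDefaultModels)
  getModelForProvider(config, { provider: "mistral" }); // "mistral-large-latest" (provider.defaultModel)
  getModelForProvider(config, { provider: "groq" }); // "o4-mini" (global fallback)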