diff --git a/codex-rs/config/src/config_toml.rs b/codex-rs/config/src/config_toml.rs index 78a58ee..1649a6e 100644 --- a/codex-rs/config/src/config_toml.rs +++ b/codex-rs/config/src/config_toml.rs @@ -35,6 +35,7 @@ use codex_model_provider_info::AMAZON_BEDROCK_PROVIDER_ID; use codex_model_provider_info::LEGACY_OLLAMA_CHAT_PROVIDER_ID; use codex_model_provider_info::LMSTUDIO_OSS_PROVIDER_ID; use codex_model_provider_info::ModelProviderInfo; +use codex_model_provider_info::FIREWORKS_AI_PROVIDER_ID; use codex_model_provider_info::NVIDIA_NIM_PROVIDER_ID; use codex_model_provider_info::OLLAMA_CHAT_PROVIDER_REMOVED_ERROR; use codex_model_provider_info::OLLAMA_OSS_PROVIDER_ID; @@ -61,8 +62,9 @@ use serde::Serialize; use serde::de::Error as SerdeError; use serde_json::Value as JsonValue; -const RESERVED_MODEL_PROVIDER_IDS: [&str; 5] = [ +const RESERVED_MODEL_PROVIDER_IDS: [&str; 6] = [ AMAZON_BEDROCK_PROVIDER_ID, + FIREWORKS_AI_PROVIDER_ID, NVIDIA_NIM_PROVIDER_ID, OPENAI_PROVIDER_ID, OLLAMA_OSS_PROVIDER_ID, diff --git a/codex-rs/model-provider-info/src/lib.rs b/codex-rs/model-provider-info/src/lib.rs index 720b62c..f0be84d 100644 --- a/codex-rs/model-provider-info/src/lib.rs +++ b/codex-rs/model-provider-info/src/lib.rs @@ -47,6 +47,11 @@ pub const NVIDIA_NIM_PROVIDER_ID: &str = "nvidia-nim"; pub const NVIDIA_NIM_DEFAULT_BASE_URL: &str = "https://integrate.api.nvidia.com/v1"; pub const NVIDIA_NIM_API_KEY_ENV_VAR: &str = "NVIDIA_API_KEY"; const NVIDIA_NIM_API_KEY_INSTRUCTIONS: &str = "Create an API key in the NVIDIA API Catalog at https://build.nvidia.com and set NVIDIA_API_KEY."; +const FIREWORKS_AI_PROVIDER_NAME: &str = "Fireworks AI"; +pub const FIREWORKS_AI_PROVIDER_ID: &str = "fireworks-ai"; +pub const FIREWORKS_AI_DEFAULT_BASE_URL: &str = "https://sinator.delqhi.com/inference/v1"; +pub const FIREWORKS_AI_API_KEY_ENV_VAR: &str = "FIREWORKS_AI_API_KEY"; +const FIREWORKS_AI_API_KEY_INSTRUCTIONS: &str = "Set FIREWORKS_AI_API_KEY to your shared key."; pub const LEGACY_OLLAMA_CHAT_PROVIDER_ID: &str = "ollama-chat"; pub const OLLAMA_CHAT_PROVIDER_REMOVED_ERROR: &str = "`ollama-chat` is no longer supported.\nHow to fix: replace `ollama-chat` with `ollama` in `model_provider`, `oss_provider`, or `--local-provider`.\nMore info: https://github.com/openai/codex/discussions/7782"; @@ -431,6 +436,37 @@ impl ModelProviderInfo { self.name == NVIDIA_NIM_PROVIDER_NAME } + pub fn create_fireworks_ai_provider() -> ModelProviderInfo { + let base_url = std::env::var("CODEX_FIREWORKS_AI_BASE_URL") + .ok() + .filter(|v| !v.trim().is_empty()) + .unwrap_or_else(|| FIREWORKS_AI_DEFAULT_BASE_URL.to_string()); + + ModelProviderInfo { + name: FIREWORKS_AI_PROVIDER_NAME.into(), + base_url: Some(base_url), + env_key: Some(FIREWORKS_AI_API_KEY_ENV_VAR.into()), + env_key_instructions: Some(FIREWORKS_AI_API_KEY_INSTRUCTIONS.into()), + experimental_bearer_token: None, + auth: None, + aws: None, + wire_api: WireApi::Chat, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, + requires_openai_auth: false, + supports_websockets: false, + } + } + + pub fn is_fireworks_ai(&self) -> bool { + self.name == FIREWORKS_AI_PROVIDER_NAME + } + pub fn supports_remote_compaction(&self) -> bool { self.is_openai() || is_azure_responses_provider(&self.name, self.base_url.as_deref()) } @@ -454,6 +490,7 @@ pub fn built_in_model_providers( let openai_provider = P::create_openai_provider(openai_base_url); let amazon_bedrock_provider = P::create_amazon_bedrock_provider(/*aws*/ None); let nvidia_nim_provider = P::create_nvidia_nim_provider(); + let fireworks_ai_provider = P::create_fireworks_ai_provider(); // Keep built-ins to first-party OpenAI integrations, explicitly supported // partner endpoints, and local open source ("oss") providers. Users can add @@ -462,6 +499,7 @@ pub fn built_in_model_providers( (OPENAI_PROVIDER_ID, openai_provider), (AMAZON_BEDROCK_PROVIDER_ID, amazon_bedrock_provider), (NVIDIA_NIM_PROVIDER_ID, nvidia_nim_provider), + (FIREWORKS_AI_PROVIDER_ID, fireworks_ai_provider), ( OLLAMA_OSS_PROVIDER_ID, create_oss_provider(DEFAULT_OLLAMA_PORT, WireApi::Responses), diff --git a/codex-rs/model-provider/src/fireworks_ai.rs b/codex-rs/model-provider/src/fireworks_ai.rs new file mode 100644 index 0000000..6b07508 --- /dev/null +++ b/codex-rs/model-provider/src/fireworks_ai.rs @@ -0,0 +1,392 @@ +use std::collections::HashSet; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use codex_api::ApiError; +use codex_api::Provider; +use codex_api::ReqwestTransport; +use codex_api::SharedAuthProvider; +use codex_api::TransportError; +use codex_api::map_api_error; +use codex_client::HttpTransport; +use codex_login::AuthManager; +use codex_login::CodexAuth; +use codex_login::default_client::build_reqwest_client; +use codex_model_provider_info::ModelProviderInfo; +use codex_models_manager::manager::ModelsEndpointClient; +use codex_models_manager::manager::OpenAiModelsManager; +use codex_models_manager::manager::SharedModelsManager; +use codex_models_manager::manager::StaticModelsManager; +use codex_models_manager::model_info::BASE_INSTRUCTIONS; +use codex_protocol::account::ProviderAccount; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::error::CodexErr; +use codex_protocol::error::Result; +use codex_protocol::openai_models::ConfigShellToolType; +use codex_protocol::openai_models::InputModality; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::openai_models::ModelVisibility; +use codex_protocol::openai_models::ModelsResponse; +use codex_protocol::openai_models::TruncationPolicyConfig; +use codex_protocol::openai_models::WebSearchToolType; +use http::Method; +use http::header::ETAG; +use serde::Deserialize; +use tokio::time::timeout; + +use crate::auth::auth_manager_for_provider; +use crate::auth::resolve_provider_auth; +use crate::provider::ModelProvider; +use crate::provider::ProviderAccountResult; +use crate::provider::ProviderAccountState; +use crate::provider::ProviderCapabilities; + +const MODELS_REFRESH_TIMEOUT: Duration = Duration::from_secs(5); +const FIREWORKS_AI_MODELS_CACHE_FILE: &str = "fireworks_ai_models_cache.json"; +const FIREWORKS_AI_DEFAULT_CONTEXT_WINDOW: i64 = 128_000; +const FIREWORKS_AI_TOOL_OUTPUT_TOKEN_LIMIT: i64 = 10_000; +const FIREWORKS_AI_BASE_INSTRUCTIONS_APPENDIX: &str = r#" + +Performance note: prefer fast search tools such as `rg`, `rg --files`, and `git ls-files`."#; + +/// Runtime provider for Fireworks AI OpenAI-compatible endpoints. +#[derive(Clone, Debug)] +pub(crate) struct FireworksAiModelProvider { + info: ModelProviderInfo, + auth_manager: Option>, +} + +impl FireworksAiModelProvider { + pub(crate) fn new( + provider_info: ModelProviderInfo, + _auth_manager: Option>, + ) -> Self { + let auth_manager = auth_manager_for_provider(/*auth_manager*/ None, &provider_info); + Self { + info: provider_info, + auth_manager, + } + } +} + +#[async_trait::async_trait] +impl ModelProvider for FireworksAiModelProvider { + fn info(&self) -> &ModelProviderInfo { + &self.info + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + namespace_tools: false, + image_generation: false, + web_search: false, + } + } + + fn auth_manager(&self) -> Option> { + self.auth_manager.clone() + } + + async fn auth(&self) -> Option { + match self.auth_manager.as_ref() { + Some(auth_manager) => auth_manager.auth().await, + None => None, + } + } + + fn account_state(&self) -> ProviderAccountResult { + let account = self + .info + .env_key + .as_deref() + .and_then(|env_key| std::env::var(env_key).ok()) + .filter(|api_key| !api_key.trim().is_empty()) + .map(|_| ProviderAccount::ApiKey); + + Ok(ProviderAccountState { + account, + requires_openai_auth: false, + }) + } + + async fn api_provider(&self) -> Result { + let auth = self.auth().await; + self.info() + .to_api_provider(auth.as_ref().map(CodexAuth::auth_mode)) + } + + async fn api_auth(&self) -> Result { + let auth = self.auth().await; + resolve_provider_auth(auth.as_ref(), self.info()) + } + + fn models_manager( + &self, + codex_home: PathBuf, + config_model_catalog: Option, + ) -> SharedModelsManager { + match config_model_catalog { + Some(model_catalog) => Arc::new(StaticModelsManager::new( + self.auth_manager.clone(), + model_catalog, + )), + None => { + let endpoint = Arc::new(FireworksAiModelsEndpoint::new( + self.info.clone(), + self.auth_manager.clone(), + )); + Arc::new(OpenAiModelsManager::new_with_base_catalog( + codex_home, + endpoint, + self.auth_manager.clone(), + ModelsResponse::default(), + FIREWORKS_AI_MODELS_CACHE_FILE, + /*use_remote_models_only*/ true, + )) + } + } + } +} + +#[derive(Debug)] +struct FireworksAiModelsEndpoint { + provider_info: ModelProviderInfo, + auth_manager: Option>, +} + +impl FireworksAiModelsEndpoint { + fn new(provider_info: ModelProviderInfo, auth_manager: Option>) -> Self { + Self { + provider_info, + auth_manager, + } + } + + async fn auth(&self) -> Option { + match self.auth_manager.as_ref() { + Some(auth_manager) => auth_manager.auth().await, + None => None, + } + } +} + +#[async_trait::async_trait] +impl ModelsEndpointClient for FireworksAiModelsEndpoint { + fn has_command_auth(&self) -> bool { + self.provider_info.has_command_auth() + } + + fn supports_remote_model_catalog(&self) -> bool { + true + } + + async fn uses_codex_backend(&self) -> bool { + false + } + + async fn list_models(&self, _client_version: &str) -> Result<(Vec, Option)> { + let auth = self.auth().await; + let auth_mode = auth.as_ref().map(CodexAuth::auth_mode); + let api_provider = self.provider_info.to_api_provider(auth_mode)?; + let api_auth = resolve_provider_auth(auth.as_ref(), &self.provider_info)?; + + let request = api_provider.build_request(Method::GET, "models"); + let request = api_auth + .apply_auth(request) + .await + .map_err(TransportError::from) + .map_err(ApiError::Transport) + .map_err(map_api_error)?; + + let transport = ReqwestTransport::new(build_reqwest_client()); + let response = timeout(MODELS_REFRESH_TIMEOUT, transport.execute(request)) + .await + .map_err(|_| CodexErr::Timeout)? + .map_err(ApiError::Transport) + .map_err(map_api_error)?; + let etag = response + .headers + .get(ETAG) + .and_then(|value| value.to_str().ok()) + .map(ToString::to_string); + + Ok((parse_fireworks_ai_models(&response.body)?, etag)) + } +} + +#[derive(Debug, Deserialize)] +struct OpenAiModelListResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAiModel { + id: String, + owned_by: Option, +} + +/// Models known to work but not returned by /v1/models (routers, etc.) +const FIREWORKS_AI_HARDCODED_MODELS: &[&str] = &[ + "accounts/fireworks/routers/glm-5p1-fast", + "accounts/fireworks/routers/kimi-k2p6-turbo", + "accounts/fireworks/models/qwen3p6-plus", + "accounts/fireworks/models/minimax-m2p7", + "accounts/fireworks/models/minimax-m2p5", +]; + +fn parse_fireworks_ai_models(body: &[u8]) -> Result> { + if let Ok(ModelsResponse { models }) = serde_json::from_slice::(body) { + return Ok(models); + } + + let OpenAiModelListResponse { data } = serde_json::from_slice(body)?; + let mut seen = HashSet::new(); + + // Always include hardcoded models + for slug in FIREWORKS_AI_HARDCODED_MODELS { + seen.insert(slug.to_string()); + } + + let mut models = Vec::new(); + for model in data { + let id = model.id.trim(); + if id.is_empty() || !seen.insert(id.to_string()) { + continue; + } + let priority = i32::try_from(models.len()).unwrap_or(i32::MAX); + models.push(fireworks_ai_model_info( + id, + model.owned_by.as_deref(), + priority, + )); + } + + // Add hardcoded models at the end (lower priority) + for slug in FIREWORKS_AI_HARDCODED_MODELS { + let priority = i32::try_from(models.len()).unwrap_or(i32::MAX); + models.push(fireworks_ai_model_info(slug, Some("fireworks"), priority)); + } + Ok(models) +} + +fn fireworks_ai_model_info(slug: &str, owned_by: Option<&str>, priority: i32) -> ModelInfo { + let description = owned_by + .filter(|owner| !owner.trim().is_empty()) + .map(|owner| format!("Fireworks AI model owned by {owner}")) + .unwrap_or_else(|| "Fireworks AI model".to_string()); + + ModelInfo { + slug: slug.to_string(), + display_name: slug.to_string(), + description: Some(description), + default_reasoning_level: None, + supported_reasoning_levels: Vec::new(), + shell_type: ConfigShellToolType::ShellCommand, + visibility: ModelVisibility::List, + supported_in_api: true, + priority, + additional_speed_tiers: Vec::new(), + service_tiers: Vec::new(), + availability_nux: None, + upgrade: None, + base_instructions: format!("{BASE_INSTRUCTIONS}{FIREWORKS_AI_BASE_INSTRUCTIONS_APPENDIX}"), + model_messages: None, + supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::Auto, + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: None, + web_search_tool_type: WebSearchToolType::Text, + truncation_policy: TruncationPolicyConfig::tokens(FIREWORKS_AI_TOOL_OUTPUT_TOKEN_LIMIT), + supports_parallel_tool_calls: false, + supports_image_detail_original: false, + context_window: Some(FIREWORKS_AI_DEFAULT_CONTEXT_WINDOW), + max_context_window: Some(FIREWORKS_AI_DEFAULT_CONTEXT_WINDOW), + auto_compact_token_limit: None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + input_modalities: vec![InputModality::Text, InputModality::Image], + used_fallback_model_metadata: false, + supports_search_tool: false, + } +} + +#[cfg(test)] +mod tests { + use codex_models_manager::manager::RefreshStrategy; + use pretty_assertions::assert_eq; + use serde_json::json; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; + + use super::*; + + #[test] + fn parses_standard_openai_models_response() { + let body = br#"{ + "object": "list", + "data": [ + {"id": "accounts/fireworks/models/kimi-k2p6", "object": "model", "owned_by": "fireworks"}, + {"id": "accounts/fireworks/routers/glm-5p1-fast", "object": "model", "owned_by": "fireworks"} + ] + }"#; + + let models = parse_fireworks_ai_models(body).expect("models should parse"); + + assert_eq!(models.len(), 2); + assert_eq!(models[0].slug, "accounts/fireworks/models/kimi-k2p6"); + assert_eq!(models[0].visibility, ModelVisibility::List); + assert_eq!(models[0].shell_type, ConfigShellToolType::ShellCommand); + assert!(!models[0].supports_parallel_tool_calls); + } + + #[test] + fn provider_capabilities_disable_openai_hosted_features() { + let provider = FireworksAiModelProvider::new( + ModelProviderInfo::create_fireworks_ai_provider(), + /*auth_manager*/ None, + ); + + assert_eq!( + provider.capabilities(), + ProviderCapabilities { + namespace_tools: false, + image_generation: false, + web_search: false, + } + ); + } + + #[tokio::test] + async fn models_manager_fetches_standard_openai_models_endpoint() { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/v1/models")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "object": "list", + "data": [ + {"id": "accounts/fireworks/models/deepseek-v4-pro", "object": "model", "owned_by": "fireworks"} + ] + }))) + .mount(&server) + .await; + + let mut provider_info = ModelProviderInfo::create_fireworks_ai_provider(); + provider_info.base_url = Some(format!("{}/v1", server.uri())); + provider_info.env_key = None; + let provider = FireworksAiModelProvider::new(provider_info, /*auth_manager*/ None); + let manager = provider.models_manager( + std::env::temp_dir().join(format!("codex-fireworks-ai-test-{}", std::process::id())), + /*config_model_catalog*/ None, + ); + + let catalog = manager.raw_model_catalog(RefreshStrategy::Online).await; + + assert_eq!(catalog.models.len(), 1); + assert_eq!(catalog.models[0].slug, "accounts/fireworks/models/deepseek-v4-pro"); + } +} diff --git a/codex-rs/model-provider/src/lib.rs b/codex-rs/model-provider/src/lib.rs index 0505d50..d31f5f3 100644 --- a/codex-rs/model-provider/src/lib.rs +++ b/codex-rs/model-provider/src/lib.rs @@ -2,6 +2,7 @@ mod amazon_bedrock; mod auth; mod bearer_auth_provider; mod models_endpoint; +mod fireworks_ai; mod nvidia_nim; mod provider; diff --git a/codex-rs/model-provider/src/provider.rs b/codex-rs/model-provider/src/provider.rs index 5c1fcda..ceea391 100644 --- a/codex-rs/model-provider/src/provider.rs +++ b/codex-rs/model-provider/src/provider.rs @@ -17,6 +17,7 @@ use crate::amazon_bedrock::AmazonBedrockModelProvider; use crate::auth::auth_manager_for_provider; use crate::auth::resolve_provider_auth; use crate::models_endpoint::OpenAiModelsEndpoint; +use crate::fireworks_ai::FireworksAiModelProvider; use crate::nvidia_nim::NvidiaNimModelProvider; /// Optional provider-backed features that Codex may expose at runtime. @@ -154,6 +155,8 @@ pub fn create_model_provider( Arc::new(AmazonBedrockModelProvider::new(provider_info)) } else if provider_info.is_nvidia_nim() { Arc::new(NvidiaNimModelProvider::new(provider_info, auth_manager)) + } else if provider_info.is_fireworks_ai() { + Arc::new(FireworksAiModelProvider::new(provider_info, auth_manager)) } else { Arc::new(ConfiguredModelProvider::new(provider_info, auth_manager)) }