diff --git a/src/cortex-cli/src/agent_cmd/tests.rs b/src/cortex-cli/src/agent_cmd/tests.rs index e2ff07f9f..18f7ba753 100644 --- a/src/cortex-cli/src/agent_cmd/tests.rs +++ b/src/cortex-cli/src/agent_cmd/tests.rs @@ -3,10 +3,9 @@ #[cfg(test)] mod tests { use crate::agent_cmd::cli::{CopyArgs, ExportArgs}; - use crate::agent_cmd::loader::{ - load_builtin_agents, parse_frontmatter, read_file_with_encoding, - }; + use crate::agent_cmd::loader::{load_builtin_agents, parse_frontmatter}; use crate::agent_cmd::types::AgentMode; + use crate::utils::file::read_file_with_encoding; #[test] fn test_read_file_with_utf8() { diff --git a/src/cortex-cli/src/run_cmd/cli.rs b/src/cortex-cli/src/run_cmd/cli.rs index 9fcc0457b..72012b540 100644 --- a/src/cortex-cli/src/run_cmd/cli.rs +++ b/src/cortex-cli/src/run_cmd/cli.rs @@ -161,7 +161,7 @@ pub struct RunCli { /// Maximum tokens for response (used for token validation). /// If specified, cortex will validate that prompt + max_tokens /// does not exceed the model's context limit before making the API call. - #[arg(long = "max-tokens")] + #[arg(long = "max-tokens", value_parser = clap::value_parser!(u32).range(1..))] pub max_tokens: Option<u32>, /// Custom system prompt to use instead of the default. diff --git a/src/cortex-cli/src/run_cmd/execution.rs b/src/cortex-cli/src/run_cmd/execution.rs index 18bb58d5b..0b0d7790e 100644 --- a/src/cortex-cli/src/run_cmd/execution.rs +++ b/src/cortex-cli/src/run_cmd/execution.rs @@ -17,6 +17,18 @@ use super::output::{copy_to_clipboard, send_notification}; use super::session::{SessionMode, resolve_session_id}; use super::system::check_file_descriptor_limits; +#[derive(Debug, PartialEq, Eq)] +struct DryRunTokenEstimate { + user_prompt_tokens: u32, + attachment_tokens: u32, + system_prompt_tokens: u32, + tool_tokens: u32, + tool_count: u32, + total_input_tokens: u32, + max_response_tokens: Option<u32>, + total_with_max_response: Option<u32>, +} + impl RunCli { /// Run the command. 
pub async fn run(self) -> Result<()> { @@ -784,8 +796,6 @@ impl RunCli { /// Run in dry-run mode - show token estimates without executing. async fn run_dry_run(&self, message: &str, attachments: &[FileAttachment]) -> Result<()> { - use cortex_engine::tokenizer::TokenCounter; - let config = cortex_engine::Config::default(); let model = self .model @@ -793,34 +803,7 @@ impl RunCli { .map(|m| resolve_model_alias(m).to_string()) .unwrap_or_else(|| config.model.clone()); - let mut counter = TokenCounter::for_model(&model); - - // Count user prompt tokens - let user_prompt_tokens = counter.count(message); - - // Count attachment tokens - let mut attachment_tokens = 0u32; - for attachment in attachments { - let content = - std::fs::read_to_string(&attachment.path).unwrap_or_else(|_| String::new()); - attachment_tokens += counter.count(&content); - // Add overhead for file markers - attachment_tokens += 20; // Approximate overhead for "--- File: ... ---" markers - } - - // Estimate system prompt tokens (typical system prompt is ~500-2000 tokens) - // This is an approximation as the actual system prompt varies - let system_prompt_tokens = 1500u32; - - // Estimate tool definition tokens - // Each tool definition is approximately 100-200 tokens on average - // Common tools: Execute, Read, Write, Edit, LS, Grep, Glob, etc. 
- let tool_count = 15; // Approximate number of default tools - let tool_tokens = tool_count * 150; // ~150 tokens per tool definition - - // Calculate totals - let total_input_tokens = - user_prompt_tokens + attachment_tokens + system_prompt_tokens + tool_tokens; + let estimate = self.estimate_dry_run_tokens(&model, message, attachments); // Output based on format if matches!(self.format, OutputFormat::Json | OutputFormat::Jsonl) { @@ -828,11 +811,13 @@ impl RunCli { "dry_run": true, "model": model, "token_estimates": { - "user_prompt": user_prompt_tokens, - "attachments": attachment_tokens, - "system_prompt": system_prompt_tokens, - "tool_definitions": tool_tokens, - "total_input": total_input_tokens, + "user_prompt": estimate.user_prompt_tokens, + "attachments": estimate.attachment_tokens, + "system_prompt": estimate.system_prompt_tokens, + "tool_definitions": estimate.tool_tokens, + "total_input": estimate.total_input_tokens, + "max_response": estimate.max_response_tokens, + "total_with_max_response": estimate.total_with_max_response, }, "message_preview": if message.len() > 100 { format!("{}...", &message[..100]) @@ -849,24 +834,36 @@ impl RunCli { println!("Model: {}", model); println!(); println!("Token Breakdown:"); - println!(" User prompt: {:>8} tokens", user_prompt_tokens); + println!( + " User prompt: {:>8} tokens", + estimate.user_prompt_tokens + ); if !attachments.is_empty() { println!( " Attachments: {:>8} tokens ({} files)", - attachment_tokens, + estimate.attachment_tokens, attachments.len() ); } println!( " System prompt: {:>8} tokens (estimated)", - system_prompt_tokens + estimate.system_prompt_tokens ); println!( " Tool definitions: {:>8} tokens (estimated, {} tools)", - tool_tokens, tool_count + estimate.tool_tokens, estimate.tool_count ); println!(" {}", "-".repeat(30)); - println!(" Total input: {:>8} tokens", total_input_tokens); + println!( + " Total input: {:>8} tokens", + estimate.total_input_tokens + ); + if let Some(max_tokens) = 
estimate.max_response_tokens { + println!(" Max response: {:>8} tokens", max_tokens); + if let Some(total_with_max_response) = estimate.total_with_max_response { + println!(" Input + response: {:>8} tokens", total_with_max_response); + } + } println!(); println!("Note: System prompt and tool definition token counts are estimates."); println!("Actual counts may vary based on agent configuration."); @@ -884,4 +881,84 @@ impl RunCli { Ok(()) } + + fn estimate_dry_run_tokens( + &self, + model: &str, + message: &str, + attachments: &[FileAttachment], + ) -> DryRunTokenEstimate { + use cortex_engine::tokenizer::TokenCounter; + + let mut counter = TokenCounter::for_model(&model); + + // Count the tokens in the user prompt. + let user_prompt_tokens = counter.count(message); + + // Count the tokens in the attachments. + let mut attachment_tokens = 0u32; + for attachment in attachments { + let content = + std::fs::read_to_string(&attachment.path).unwrap_or_else(|_| String::new()); + attachment_tokens += counter.count(&content); + // Add the approximate overhead of the file markers. + attachment_tokens += 20; + } + + // The system prompt varies with configuration; use an approximation within the typical range. + let system_prompt_tokens = 1500u32; + + // Estimate tool definitions from the default tool count and the average tokens per tool. + let tool_count = 15; + let tool_tokens = tool_count * 150; + + // Compute the total input and, when max_tokens is provided, include the response limit. + let total_input_tokens = + user_prompt_tokens + attachment_tokens + system_prompt_tokens + tool_tokens; + let total_with_max_response = self + .max_tokens + .map(|max| total_input_tokens.saturating_add(max)); + + DryRunTokenEstimate { + user_prompt_tokens, + attachment_tokens, + system_prompt_tokens, + tool_tokens, + tool_count, + total_input_tokens, + max_response_tokens: self.max_tokens, + total_with_max_response, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::Parser; + + #[test] + fn dry_run_estimate_includes_max_response_tokens() { + let cli = RunCli::try_parse_from(["run", "--dry-run", "--max-tokens", "4096", "Long task"]) + .expect("max tokens above zero should parse"); + + let estimate = 
cli.estimate_dry_run_tokens("gpt-4o", "Long task", &[]); + + assert_eq!(estimate.max_response_tokens, Some(4096)); + assert_eq!( + estimate.total_with_max_response, + Some(estimate.total_input_tokens + 4096) + ); + } + + #[test] + fn run_rejects_zero_max_tokens() { + let error = RunCli::try_parse_from(["run", "--max-tokens", "0", "Long task"]) + .expect_err("zero max tokens should be rejected"); + + assert!( + error.to_string().contains("invalid value"), + "unexpected error: {error}" + ); + } }