diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py index 759604dd93..3ef61bff90 100644 --- a/astrbot/core/agent/context/compressor.py +++ b/astrbot/core/agent/context/compressor.py @@ -144,6 +144,7 @@ def __init__( self.keep_recent_ratio = min(max(float(keep_recent_ratio), 0.0), 0.3) self.compression_threshold = compression_threshold self.token_counter = token_counter or EstimateTokenCounter() + self.last_call_failed = False self.instruction_text = instruction_text or ( "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n" @@ -212,6 +213,8 @@ async def __call__(self, messages: list[Message]) -> list[Message]: """ from .round_utils import split_into_rounds + self.last_call_failed = False + rounds = split_into_rounds(messages) message_rounds = [ [seg for seg in rnd if isinstance(seg, Message)] for rnd in rounds @@ -276,13 +279,19 @@ async def __call__(self, messages: list[Message]) -> list[Message]: response = await self.provider.text_chat( contexts=sanitized_summary_contexts, ) + if response.role == "err": + logger.error(f"Failed to generate summary: {response.completion_text}") + self.last_call_failed = True + return messages summary_content = (response.completion_text or "").strip() except Exception as e: logger.error(f"Failed to generate summary: {e}") + self.last_call_failed = True return messages if not summary_content: logger.warning("LLM context compression returned an empty summary.") + self.last_call_failed = True return messages # Build result: system messages + summary pair + recent rounds diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 1a11ebff96..f324b33184 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -3,6 +3,7 @@ from ..message import Message from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor from .config import ContextConfig +from .round_utils import count_conversation_rounds from .token_counter import EstimateTokenCounter from .truncator import ContextTruncator @@ -53,18 +54,50 @@ async def process( Returns: The processed message list. """ + result, _ = await self.process_with_meta(messages, trusted_token_usage) + return result + + async def process_with_meta( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> tuple[list[Message], bool]: try: result = messages + was_hard_truncated = False - # 1. 基于轮次的截断 (Enforce max turns) if self.config.enforce_max_turns != -1: - result = self.truncator.truncate_by_turns( - result, - keep_most_recent_turns=self.config.enforce_max_turns, - drop_turns=self.config.truncate_turns, - ) + turn_count = count_conversation_rounds(result) + if turn_count > self.config.enforce_max_turns: + if isinstance(self.compressor, LLMSummaryCompressor): + logger.debug( + "Turn limit (%s) exceeded (%s turns), " + "delegating to LLM summary compressor instead of " + "hard truncation.", + self.config.enforce_max_turns, + turn_count, + ) + compressed = await self.compressor(result) + if self.compressor.last_call_failed: + logger.warning( + "LLM summary compression failed; falling back " + "to turn-based truncation to bound context " + "size.", + ) + result = self.truncator.truncate_by_turns( + result, + keep_most_recent_turns=self.config.enforce_max_turns, + drop_turns=self.config.truncate_turns, + ) + was_hard_truncated = True + else: + result = compressed + else: + result = self.truncator.truncate_by_turns( + result, + keep_most_recent_turns=self.config.enforce_max_turns, + drop_turns=self.config.truncate_turns, + ) + was_hard_truncated = True - # 2. 基于 token 的压缩 if self.config.max_context_tokens > 0: total_tokens = self.token_counter.count_tokens( result, trusted_token_usage @@ -73,16 +106,19 @@ async def process( if self.compressor.should_compress( result, total_tokens, self.config.max_context_tokens ): - result = await self._run_compression(result, total_tokens) + result, compression_was_lossy = await self._run_compression( + result, total_tokens + ) + was_hard_truncated = was_hard_truncated or compression_was_lossy - return result + return result, was_hard_truncated except Exception as e: logger.error(f"Error during context processing: {e}", exc_info=True) - return messages + return messages, False async def _run_compression( self, messages: list[Message], prev_tokens: int - ) -> list[Message]: + ) -> tuple[list[Message], bool]: """ Compress/truncate the messages. @@ -95,7 +131,12 @@ async def _run_compression( """ logger.debug("Compress triggered, starting compression...") - messages = await self.compressor(messages) + compressed = await self.compressor(messages) + was_lossy = ( + not isinstance(self.compressor, LLMSummaryCompressor) + or self.compressor.last_call_failed + ) + messages = compressed # double check tokens_after_summary = self.token_counter.count_tokens(messages) @@ -113,9 +154,24 @@ async def _run_compression( messages, tokens_after_summary, self.config.max_context_tokens ): logger.info( - "Context still exceeds max tokens after compression, applying halving truncation..." + "Context still exceeds max tokens after compression, applying hard truncation..." ) - # still need compress, truncate by half - messages = self.truncator.truncate_by_halving(messages) - - return messages + was_lossy = True + while self.compressor.should_compress( + messages, tokens_after_summary, self.config.max_context_tokens + ): + truncated = self.truncator.truncate_by_dropping_oldest_turns( + messages, + drop_turns=self.config.truncate_turns, + ) + if truncated == messages: + truncated = self.truncator.truncate_by_halving(messages) + if truncated == messages: + break + next_tokens = self.token_counter.count_tokens(truncated) + if next_tokens >= tokens_after_summary: + break + messages = truncated + tokens_after_summary = next_tokens + + return messages, was_lossy diff --git a/astrbot/core/agent/context/round_utils.py b/astrbot/core/agent/context/round_utils.py index c93057ef44..0f584345ed 100644 --- a/astrbot/core/agent/context/round_utils.py +++ b/astrbot/core/agent/context/round_utils.py @@ -35,6 +35,22 @@ def split_into_rounds( return rounds +def count_conversation_rounds(contexts: Sequence[RoundSegment]) -> int: + """Count logical user conversation rounds. + + Args: + contexts: Flat message contexts. + + Returns: + Number of rounds that contain a user message. + """ + return sum( + 1 + for round_segments in split_into_rounds(contexts) + if any(_segment_role(seg) == "user" for seg in round_segments) + ) + + def _content_to_text(content: Any) -> str: if isinstance(content, list): normalized = [ diff --git a/astrbot/core/agent/context/truncator.py b/astrbot/core/agent/context/truncator.py index 9abf574336..f6fb53dd49 100644 --- a/astrbot/core/agent/context/truncator.py +++ b/astrbot/core/agent/context/truncator.py @@ -1,4 +1,5 @@ from ..message import Message +from .round_utils import count_conversation_rounds, split_into_rounds class ContextTruncator: @@ -120,15 +121,20 @@ def truncate_by_turns( return messages system_messages, non_system_messages = self._split_system_rest(messages) + rounds = split_into_rounds(non_system_messages) - if len(non_system_messages) // 2 <= keep_most_recent_turns: + if count_conversation_rounds(non_system_messages) <= keep_most_recent_turns: return messages num_to_keep = keep_most_recent_turns - drop_turns + 1 if num_to_keep <= 0: truncated_contexts = [] else: - truncated_contexts = non_system_messages[-num_to_keep * 2 :] + truncated_contexts = [ + segment + for round_segments in rounds[-num_to_keep:] + for segment in round_segments + ] # Find the first user message index = next( @@ -153,11 +159,16 @@ def truncate_by_dropping_oldest_turns( return messages system_messages, non_system_messages = self._split_system_rest(messages) + rounds = split_into_rounds(non_system_messages) - if len(non_system_messages) // 2 <= drop_turns: + if count_conversation_rounds(non_system_messages) <= drop_turns: truncated_non_system = [] else: - truncated_non_system = non_system_messages[drop_turns * 2 :] + truncated_non_system = [ + segment + for round_segments in rounds[drop_turns:] + for segment in round_segments + ] # Find the first user message index = next( diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py index 3c500b2d64..d00982f6cd 100644 --- a/astrbot/core/agent/run_context.py +++ b/astrbot/core/agent/run_context.py @@ -16,6 +16,7 @@ class ContextWrapper(Generic[TContext]): context: TContext messages: list[Message] = Field(default_factory=list) """This field stores the llm message context for the agent run, agent runners will maintain this field automatically.""" + persisted_messages: list[Message] | None = Field(default=None) tool_call_timeout: int = 120 # Default tool call timeout in seconds diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index b56d7e62fb..aa1c0d57af 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -193,7 +193,7 @@ async def _complete_with_assistant_response(self, llm_resp: LLMResponse) -> None parts.append(TextPart(text=llm_resp.completion_text)) if len(parts) == 0: logger.warning("LLM returned empty assistant message with no tool calls.") - self.run_context.messages.append(Message(role="assistant", content=parts)) + self._append_message(Message(role="assistant", content=parts)) try: await self.agent_hooks.on_agent_done(self.run_context, llm_resp) @@ -318,6 +318,7 @@ async def reset( Message(role="system", content=request.system_prompt), ) self.run_context.messages = messages + self.run_context.persisted_messages = list(messages) self.stats = AgentStats() self.stats.start_time = time.time() @@ -327,6 +328,16 @@ def _read_tool_hint(self) -> str: return f"`{self.read_tool.name}`" return "the available file-read tool" + def _append_message(self, message: Message) -> None: + self.run_context.messages.append(message) + if self.run_context.persisted_messages is not None: + self.run_context.persisted_messages.append(message) + + def _extend_messages(self, new_messages: T.Sequence[Message]) -> None: + self.run_context.messages.extend(new_messages) + if self.run_context.persisted_messages is not None: + self.run_context.persisted_messages.extend(new_messages) + async def _assemble_request_context_for_provider( self, request: ProviderRequest, @@ -697,7 +708,8 @@ async def step(self): if not self.req: raise ValueError("Request is not set. Please call reset() first.") - if self._state == AgentState.IDLE: + is_first_step = self._state == AgentState.IDLE + if is_first_step: try: await self.agent_hooks.on_agent_begin(self.run_context) except Exception as e: @@ -707,14 +719,30 @@ async def step(self): self._transition_state(AgentState.RUNNING) llm_resp_result = None - # Process request-time context before sending it to the provider. - token_usage = self.req.conversation.token_usage if self.req.conversation else 0 - self._simple_print_message_role("[BefCompact]", self.run_context.messages) - self.run_context.messages = await self.request_context_manager.process( - self.run_context.messages, trusted_token_usage=token_usage + token_usage = ( + self.req.conversation.token_usage + if is_first_step and self.req.conversation + else 0 + ) + source_messages = ( + self.run_context.persisted_messages or self.run_context.messages + ) + self._simple_print_message_role("[BefCompact]", source_messages) + pre_compaction_count = len(source_messages) + ( + self.run_context.messages, + was_hard_truncated, + ) = await self.request_context_manager.process_with_meta( + source_messages, trusted_token_usage=token_usage ) self._simple_print_message_role("[AftCompact]", self.run_context.messages) + if ( + not was_hard_truncated + and len(self.run_context.messages) != pre_compaction_count + ): + self.run_context.persisted_messages = list(self.run_context.messages) + async for llm_response in self._iter_llm_responses_with_fallback(): if llm_response.is_chunk: if self.stats.time_to_first_token == 0: @@ -903,9 +931,7 @@ async def step(self): tool_calls_result=tool_call_result_blocks, ) # record the assistant message with tool calls - self.run_context.messages.extend( - tool_calls_result.to_openai_messages_model() - ) + self._extend_messages(tool_calls_result.to_openai_messages_model()) # If there are cached images and the model supports image input, # append a user message with images so LLM can see them @@ -937,9 +963,7 @@ async def step(self): ) ) if image_parts: - self.run_context.messages.append( - Message(role="user", content=image_parts) - ) + self._append_message(Message(role="user", content=image_parts)) logger.debug( f"Appended {len(cached_images)} cached image(s) to context for LLM review" ) @@ -965,7 +989,7 @@ async def step_until_done( if self.req: self.req.func_tool = None # 注入提示词 - self.run_context.messages.append( + self._append_message( Message( role="user", content=self.MAX_STEPS_REACHED_PROMPT, @@ -1386,7 +1410,7 @@ async def _finalize_aborted_step( if llm_resp.completion_text: parts.append(TextPart(text=llm_resp.completion_text)) if parts: - self.run_context.messages.append(Message(role="assistant", content=parts)) + self._append_message(Message(role="assistant", content=parts)) try: await self.agent_hooks.on_agent_done(self.run_context, llm_resp) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 0b636b5b2b..fed2504999 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -328,7 +328,10 @@ async def process( event, req, agent_runner.get_final_llm_resp(), - agent_runner.run_context.messages, + agent_runner.run_context.persisted_messages + if agent_runner.run_context.persisted_messages + is not None + else agent_runner.run_context.messages, agent_runner.stats, user_aborted=agent_runner.was_aborted(), ) @@ -403,7 +406,9 @@ async def process( event, req, final_resp, - agent_runner.run_context.messages, + agent_runner.run_context.persisted_messages + if agent_runner.run_context.persisted_messages is not None + else agent_runner.run_context.messages, agent_runner.stats, user_aborted=agent_runner.was_aborted(), ) diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index dad5a53a25..6b043d31ca 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -251,7 +251,7 @@ "provider_settings": { "max_context_length": { "description": "Max Turns Before Compression", - "hint": "Persistent conversation history is truncated or LLM-compressed by the strategy below only after it exceeds this many turns. Request-time contexts are also constrained by this value before sending. -1 means no turn-based limit." + "hint": "When persistent conversation history exceeds this many turns, it is handled by the 'Handling for History Limits or Context Window Pressure' strategy below (truncation or LLM compression — see that setting for details). Request-time contexts are also constrained by this value before sending. -1 means no turn-based limit." }, "dequeue_context_length": { "description": "Turns to Discard When Limit Exceeded", @@ -263,7 +263,7 @@ "Truncate by Turns", "Compress by LLM" ], - "hint": "Persistent conversation history uses this strategy only after exceeding 'Max Turns Before Compression'. Before each request, the same strategy may also protect the in-flight context when tokens approach the model window." + "hint": "This strategy applies whenever either limit is hit: 'Max Turns Before Compression' is exceeded, or the in-flight request's tokens approach the model's context window. Choosing 'Compress by LLM' first attempts to generate a summary that preserves key points; if compression fails (provider unavailable, call error, etc.), it automatically falls back to 'Truncate by Turns' so the context size stays bounded. Either way, this only affects the working set sent to the model for the current request — already-persisted conversation history in the database is never overwritten or lost as a result of this step." }, "llm_compress_instruction": { "description": "Context Compression Instruction", diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index a5efc78335..d9c4b67b0a 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -251,7 +251,7 @@ "provider_settings": { "max_context_length": { "description": "Макс. раундов перед сжатием", - "hint": "Постоянная история диалога обрезается или сжимается LLM по стратегии ниже только после превышения этого числа раундов. Контекст перед запросом также ограничивается этим значением. -1 означает без ограничений по раундам." + "hint": "Когда постоянная история диалога превышает это число раундов, она обрабатывается согласно стратегии «Действие при лимите истории или давлении окна контекста» ниже (обрезка или сжатие LLM — см. описание этого параметра). Контекст перед запросом также ограничивается этим значением. -1 означает без ограничений по раундам." }, "dequeue_context_length": { "description": "Раундов для удаления при превышении лимита", @@ -263,7 +263,7 @@ "Обрезать по раундам", "Сжать с помощью LLM" ], - "hint": "Постоянная история диалога использует эту стратегию только после превышения лимита раундов. Перед каждым запросом та же стратегия может защищать текущий контекст, когда токены приближаются к окну модели." + "hint": "Эта стратегия применяется при превышении любого из двух лимитов: «Макс. раундов перед сжатием» или приближении токенов запроса к окну контекста модели. При выборе «Сжать с помощью LLM» сначала делается попытка создать сводку с сохранением ключевых моментов; если сжатие не удалось (провайдер недоступен, ошибка вызова и т.д.), происходит автоматический откат к «Обрезать по раундам», чтобы ограничить размер контекста. В любом случае этот шаг влияет только на рабочий набор, отправляемый модели для текущего запроса — уже сохранённая история диалога в базе данных никогда не перезаписывается и не теряется в результате этого шага." }, "llm_compress_instruction": { "description": "Инструкция для сжатия контекста", diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index bcfb4e20dc..366dfceaa9 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -253,7 +253,7 @@ "provider_settings": { "max_context_length": { "description": "压缩前最多保留对话轮数", - "hint": "普通会话历史超过该轮数后,才会按下方策略进行持久化截断或 LLM 压缩;请求发送前也会先按该值约束上下文。-1 表示不按轮数限制。" + "hint": "普通会话历史超过该轮数后,会按下方\"历史超限或上下文接近上限时的处理方式\"进行处理(截断或 LLM 压缩,详见该项说明);请求发送前也会先按该值约束上下文。-1 表示不按轮数限制。" }, "dequeue_context_length": { "description": "轮次超限时一次丢弃轮数", @@ -265,7 +265,7 @@ "按对话轮数截断", "由 LLM 压缩上下文" ], - "hint": "普通会话历史仅在超过\"压缩前最多保留对话轮数\"后执行该策略;请求发送前也会在上下文 token 接近模型窗口时使用同一策略保护本次请求。" + "hint": "无论是\"压缩前最多保留对话轮数\"超限,还是请求发送前上下文 token 接近模型窗口,都会按此策略处理:选择\"由 LLM 压缩上下文\"时,会先尝试生成摘要保留要点;若压缩失败(模型不可用、调用异常等),将自动回退为\"按对话轮数截断\",以保证上下文不会无限增长。无论走哪条路径,本次处理只影响\"发给模型的这一次请求\",已经持久化到对话历史(数据库)中的记录不会被这一步的裁剪/压缩结果覆盖或丢失。" }, "llm_compress_instruction": { "description": "上下文压缩提示词", diff --git a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py index 8e9e601b3f..4ad800b249 100644 --- a/tests/agent/test_context_manager.py +++ b/tests/agent/test_context_manager.py @@ -42,6 +42,26 @@ def meta(self): return MagicMock(id="test_provider", type="openai") +class MessageCountTokenCounter: + """Token counter that assigns a fixed cost to each message.""" + + def count_tokens( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> int: + """Count tokens by message count for deterministic tests. + + Args: + messages: The messages to count. + trusted_token_usage: A trusted token count to return when present. + + Returns: + The deterministic token count. + """ + if trusted_token_usage > 0: + return trusted_token_usage + return len(messages) * 100 + + class TestContextManager: """Test suite for ContextManager.""" @@ -467,6 +487,32 @@ async def test_enforce_max_turns_with_system_messages(self): assert len(system_msgs) >= 1 assert system_msgs[0].content == "System instruction" + @pytest.mark.asyncio + async def test_enforce_max_turns_counts_tool_chain_as_one_round(self): + """Tool messages in one round should not inflate turn count.""" + config = ContextConfig(enforce_max_turns=1, truncate_turns=1) + manager = ContextManager(config) + messages = [ + self.create_message("user", "Run a tool"), + Message( + role="assistant", + content="Calling tool", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "lookup", "arguments": "{}"}, + } + ], + ), + Message(role="tool", content="Tool result", tool_call_id="call_1"), + self.create_message("assistant", "Done"), + ] + + result = await manager.process(messages) + + assert result == messages + # ==================== Token-based Compression Tests ==================== @pytest.mark.asyncio @@ -904,11 +950,15 @@ async def test_run_compression_calls_compressor(self): mock_compressor.should_compress = MagicMock(return_value=False) manager.compressor = mock_compressor - result = await manager._run_compression(messages, prev_tokens=100) + result, was_lossy = await manager._run_compression(messages, prev_tokens=100) # Compressor __call__ should be invoked mock_compressor.assert_called_once_with(messages) assert result == compressed + # mock_compressor is not an LLMSummaryCompressor, so this is treated + # as a lossy (hard) truncation, same as the built-in + # TruncateByTurnsCompressor. + assert was_lossy is True @pytest.mark.asyncio async def test_run_compression_applies_compressor_through_process(self): @@ -966,6 +1016,28 @@ async def test_llm_compression_with_mock_provider(self): # Should have been compressed assert len(result) <= len(messages) + @pytest.mark.asyncio + async def test_llm_failure_falls_back_until_token_threshold(self): + """Failed LLM compression should hard truncate until tokens are acceptable.""" + mock_provider = MockProvider() + mock_provider.text_chat = AsyncMock( + return_value=LLMResponse(role="err", completion_text="compress failed") + ) + config = ContextConfig( + max_context_tokens=300, + truncate_turns=1, + llm_compress_provider=mock_provider, # type: ignore[arg-type] + custom_token_counter=MessageCountTokenCounter(), + ) + manager = ContextManager(config) + messages = self.create_messages(10) + + result, was_hard_truncated = await manager.process_with_meta(messages) + + assert was_hard_truncated is True + assert len(result) == 2 + assert manager.token_counter.count_tokens(result) <= 246 + # ==================== split_into_rounds Tests ==================== def test_split_rounds_ensures_user_start(self): diff --git a/tests/agent/test_truncator.py b/tests/agent/test_truncator.py index 7dac80f9ce..e11d29f241 100644 --- a/tests/agent/test_truncator.py +++ b/tests/agent/test_truncator.py @@ -134,6 +134,32 @@ def test_truncate_by_turns_exact_threshold(self): assert len(result) == 6 assert result == messages + def test_truncate_by_turns_counts_tool_chain_as_one_round(self): + """Tool calls/results inside one round should not count as extra turns.""" + truncator = ContextTruncator() + messages = [ + self.create_message("user", "Run a tool"), + Message( + role="assistant", + content="Calling tool", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "lookup", "arguments": "{}"}, + } + ], + ), + Message(role="tool", content="Tool result", tool_call_id="call_1"), + self.create_message("assistant", "Done"), + ] + + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=1, drop_turns=1 + ) + + assert result == messages + def test_truncate_by_turns_ensures_user_first(self): """Test that truncate_by_turns ensures user message comes first.""" truncator = ContextTruncator()