Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions astrbot/core/agent/context/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def __init__(
self.keep_recent_ratio = min(max(float(keep_recent_ratio), 0.0), 0.3)
self.compression_threshold = compression_threshold
self.token_counter = token_counter or EstimateTokenCounter()
self.last_call_failed = False

self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
Expand Down Expand Up @@ -212,6 +213,8 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
"""
from .round_utils import split_into_rounds

self.last_call_failed = False

rounds = split_into_rounds(messages)
message_rounds = [
[seg for seg in rnd if isinstance(seg, Message)] for rnd in rounds
Expand Down Expand Up @@ -276,13 +279,19 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
response = await self.provider.text_chat(
contexts=sanitized_summary_contexts,
)
if response.role == "err":
logger.error(f"Failed to generate summary: {response.completion_text}")
self.last_call_failed = True
return messages
summary_content = (response.completion_text or "").strip()
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
self.last_call_failed = True
return messages

if not summary_content:
Comment thread
lingyun14beta marked this conversation as resolved.
logger.warning("LLM context compression returned an empty summary.")
self.last_call_failed = True
return messages

# Build result: system messages + summary pair + recent rounds
Expand Down
67 changes: 54 additions & 13 deletions astrbot/core/agent/context/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,50 @@ async def process(
Returns:
The processed message list.
"""
result, _ = await self.process_with_meta(messages, trusted_token_usage)
return result

async def process_with_meta(
self, messages: list[Message], trusted_token_usage: int = 0
) -> tuple[list[Message], bool]:
try:
result = messages
was_hard_truncated = False

# 1. 基于轮次的截断 (Enforce max turns)
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
non_system_count = len([m for m in result if m.role != "system"]) // 2

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Turn counting via len(non_system_msgs) // 2 is brittle when roles are not strictly user/assistant pairs.

This assumes every non-system message forms a strict user/assistant pair and that other roles (tool, function, etc.) either don’t appear or fit that pattern. With odd message counts or extra roles, turn counting can be wrong, leading to premature or delayed enforce_max_turns behavior. It would be safer to derive the turn count using the same logic as your truncator (or a shared split_into_rounds-style helper) so enforcement matches how truncation is actually applied.

Suggested implementation:

            result = messages
            was_hard_truncated = False

            if self.config.enforce_max_turns != -1:
                # Count turns in a way that is robust to extra roles (tool/function/etc.)
                # A "turn" starts with a user message and continues through the
                # assistant/tool/function responses until the next user message.
                non_system_messages = [m for m in result if m.role != "system"]
                turns = 0
                in_turn = False
                for message in non_system_messages:
                    if message.role == "user":
                        if not in_turn:
                            turns += 1
                            in_turn = True
                    elif message.role in ("assistant", "tool", "function"):
                        # These are considered responses within the current turn.
                        if in_turn:
                            in_turn = False

                if turns > self.config.enforce_max_turns:
                    if isinstance(self.compressor, LLMSummaryCompressor):
                        logger.debug(
                            "Turn limit (%s) exceeded (%s turns), "
                            "delegating to LLM summary compressor instead of "
                            "hard truncation.",
                            self.config.enforce_max_turns,
                            turns,
                        )
                        compressed = await self.compressor(result)
                        if self.compressor.last_call_failed:
                            logger.warning(

To fully align enforce_max_turns behavior with actual truncation:

  1. Extract the turn-counting logic here into a shared helper function (e.g. _count_turns(messages: Sequence[Message]) -> int) on the same class or in a shared module.
  2. Update the truncation logic (wherever self.config.truncate_turns is applied) to call this helper instead of duplicating or diverging in its own turn-counting approach.
  3. If your truncator already has its own notion of "rounds" or "turns" (e.g. via a split_into_rounds helper), prefer moving that logic into the shared helper and call it both from the truncator and from this enforcement block.
  4. Ensure any additional roles you use in your system (beyond user, assistant, tool, function) are appropriately classified as "turn-starting" (like user) or "response" roles inside a turn so counting is consistent.

if non_system_count > self.config.enforce_max_turns:
if isinstance(self.compressor, LLMSummaryCompressor):
logger.debug(
"Turn limit (%s) exceeded (%s turns), "
"delegating to LLM summary compressor instead of "
"hard truncation.",
self.config.enforce_max_turns,
non_system_count,
)
compressed = await self.compressor(result)
if self.compressor.last_call_failed:
logger.warning(
"LLM summary compression failed; falling back "
"to turn-based truncation to bound context "
"size.",
)
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
was_hard_truncated = True
else:
result = compressed
else:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
was_hard_truncated = True

# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
Expand All @@ -73,16 +105,19 @@ async def process(
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
result, compression_was_lossy = await self._run_compression(
result, total_tokens
)
was_hard_truncated = was_hard_truncated or compression_was_lossy

return result
return result, was_hard_truncated
except Exception as e:
logger.error(f"Error during context processing: {e}", exc_info=True)
return messages
return messages, False

async def _run_compression(
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
) -> tuple[list[Message], bool]:
"""
Compress/truncate the messages.

Expand All @@ -95,7 +130,12 @@ async def _run_compression(
"""
logger.debug("Compress triggered, starting compression...")

messages = await self.compressor(messages)
compressed = await self.compressor(messages)
was_lossy = (
not isinstance(self.compressor, LLMSummaryCompressor)
or self.compressor.last_call_failed
)
Comment thread
lingyun14beta marked this conversation as resolved.
messages = compressed

# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
Expand All @@ -117,5 +157,6 @@ async def _run_compression(
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
was_lossy = True

return messages
return messages, was_lossy
1 change: 1 addition & 0 deletions astrbot/core/agent/run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ContextWrapper(Generic[TContext]):
context: TContext
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
persisted_messages: list[Message] | None = Field(default=None)
tool_call_timeout: int = 120 # Default tool call timeout in seconds


Expand Down
54 changes: 37 additions & 17 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ async def _complete_with_assistant_response(self, llm_resp: LLMResponse) -> None
parts.append(TextPart(text=llm_resp.completion_text))
if len(parts) == 0:
logger.warning("LLM returned empty assistant message with no tool calls.")
self.run_context.messages.append(Message(role="assistant", content=parts))
self._append_message(Message(role="assistant", content=parts))

try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
Expand Down Expand Up @@ -318,6 +318,7 @@ async def reset(
Message(role="system", content=request.system_prompt),
)
self.run_context.messages = messages
self.run_context.persisted_messages = list(messages)

self.stats = AgentStats()
self.stats.start_time = time.time()
Expand All @@ -327,6 +328,16 @@ def _read_tool_hint(self) -> str:
return f"`{self.read_tool.name}`"
return "the available file-read tool"

def _append_message(self, message: Message) -> None:
self.run_context.messages.append(message)
if self.run_context.persisted_messages is not None:
self.run_context.persisted_messages.append(message)

def _extend_messages(self, new_messages: T.Sequence[Message]) -> None:
self.run_context.messages.extend(new_messages)
if self.run_context.persisted_messages is not None:
self.run_context.persisted_messages.extend(new_messages)

async def _assemble_request_context_for_provider(
self,
request: ProviderRequest,
Expand Down Expand Up @@ -697,7 +708,8 @@ async def step(self):
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")

if self._state == AgentState.IDLE:
is_first_step = self._state == AgentState.IDLE
if is_first_step:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
Expand All @@ -707,13 +719,25 @@ async def step(self):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None

# Process request-time context before sending it to the provider.
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self._simple_print_message_role("[BefCompact]", self.run_context.messages)
self.run_context.messages = await self.request_context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
self._simple_print_message_role("[AftCompact]", self.run_context.messages)
if is_first_step:
token_usage = (
self.req.conversation.token_usage if self.req.conversation else 0
)
self._simple_print_message_role("[BefCompact]", self.run_context.messages)
pre_compaction_count = len(self.run_context.messages)
(
self.run_context.messages,
was_hard_truncated,
) = await self.request_context_manager.process_with_meta(
self.run_context.messages, trusted_token_usage=token_usage
)
self._simple_print_message_role("[AftCompact]", self.run_context.messages)

if (
not was_hard_truncated
and len(self.run_context.messages) != pre_compaction_count
):
self.run_context.persisted_messages = list(self.run_context.messages)

async for llm_response in self._iter_llm_responses_with_fallback():
if llm_response.is_chunk:
Expand Down Expand Up @@ -903,9 +927,7 @@ async def step(self):
tool_calls_result=tool_call_result_blocks,
)
# record the assistant message with tool calls
self.run_context.messages.extend(
tool_calls_result.to_openai_messages_model()
)
self._extend_messages(tool_calls_result.to_openai_messages_model())

# If there are cached images and the model supports image input,
# append a user message with images so LLM can see them
Expand Down Expand Up @@ -937,9 +959,7 @@ async def step(self):
)
)
if image_parts:
self.run_context.messages.append(
Message(role="user", content=image_parts)
)
self._append_message(Message(role="user", content=image_parts))
logger.debug(
f"Appended {len(cached_images)} cached image(s) to context for LLM review"
)
Expand All @@ -965,7 +985,7 @@ async def step_until_done(
if self.req:
self.req.func_tool = None
# 注入提示词
self.run_context.messages.append(
self._append_message(
Message(
role="user",
content=self.MAX_STEPS_REACHED_PROMPT,
Expand Down Expand Up @@ -1386,7 +1406,7 @@ async def _finalize_aborted_step(
if llm_resp.completion_text:
parts.append(TextPart(text=llm_resp.completion_text))
if parts:
self.run_context.messages.append(Message(role="assistant", content=parts))
self._append_message(Message(role="assistant", content=parts))

try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,10 @@ async def process(
event,
req,
agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages,
agent_runner.run_context.persisted_messages
if agent_runner.run_context.persisted_messages
is not None
else agent_runner.run_context.messages,
agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
)
Expand Down Expand Up @@ -403,7 +406,9 @@ async def process(
event,
req,
final_resp,
agent_runner.run_context.messages,
agent_runner.run_context.persisted_messages
if agent_runner.run_context.persisted_messages is not None
else agent_runner.run_context.messages,
agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@
"provider_settings": {
"max_context_length": {
"description": "Max Turns Before Compression",
"hint": "Persistent conversation history is truncated or LLM-compressed by the strategy below only after it exceeds this many turns. Request-time contexts are also constrained by this value before sending. -1 means no turn-based limit."
"hint": "When persistent conversation history exceeds this many turns, it is handled by the 'Handling for History Limits or Context Window Pressure' strategy below (truncation or LLM compression — see that setting for details). Request-time contexts are also constrained by this value before sending. -1 means no turn-based limit."
},
"dequeue_context_length": {
"description": "Turns to Discard When Limit Exceeded",
Expand All @@ -263,7 +263,7 @@
"Truncate by Turns",
"Compress by LLM"
],
"hint": "Persistent conversation history uses this strategy only after exceeding 'Max Turns Before Compression'. Before each request, the same strategy may also protect the in-flight context when tokens approach the model window."
"hint": "This strategy applies whenever either limit is hit: 'Max Turns Before Compression' is exceeded, or the in-flight request's tokens approach the model's context window. Choosing 'Compress by LLM' first attempts to generate a summary that preserves key points; if compression fails (provider unavailable, call error, etc.), it automatically falls back to 'Truncate by Turns' so the context size stays bounded. Either way, this only affects the working set sent to the model for the current request — already-persisted conversation history in the database is never overwritten or lost as a result of this step."
},
"llm_compress_instruction": {
"description": "Context Compression Instruction",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@
"provider_settings": {
"max_context_length": {
"description": "Макс. раундов перед сжатием",
"hint": "Постоянная история диалога обрезается или сжимается LLM по стратегии ниже только после превышения этого числа раундов. Контекст перед запросом также ограничивается этим значением. -1 означает без ограничений по раундам."
"hint": "Когда постоянная история диалога превышает это число раундов, она обрабатывается согласно стратегии «Действие при лимите истории или давлении окна контекста» ниже (обрезка или сжатие LLM — см. описание этого параметра). Контекст перед запросом также ограничивается этим значением. -1 означает без ограничений по раундам."
},
"dequeue_context_length": {
"description": "Раундов для удаления при превышении лимита",
Expand All @@ -263,7 +263,7 @@
"Обрезать по раундам",
"Сжать с помощью LLM"
],
"hint": "Постоянная история диалога использует эту стратегию только после превышения лимита раундов. Перед каждым запросом та же стратегия может защищать текущий контекст, когда токены приближаются к окну модели."
"hint": "Эта стратегия применяется при превышении любого из двух лимитов: «Макс. раундов перед сжатием» или приближении токенов запроса к окну контекста модели. При выборе «Сжать с помощью LLM» сначала делается попытка создать сводку с сохранением ключевых моментов; если сжатие не удалось (провайдер недоступен, ошибка вызова и т.д.), происходит автоматический откат к «Обрезать по раундам», чтобы ограничить размер контекста. В любом случае этот шаг влияет только на рабочий набор, отправляемый модели для текущего запроса — уже сохранённая история диалога в базе данных никогда не перезаписывается и не теряется в результате этого шага."
},
"llm_compress_instruction": {
"description": "Инструкция для сжатия контекста",
Expand Down
Loading
Loading