Skip to content
9 changes: 9 additions & 0 deletions astrbot/core/agent/context/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def __init__(
self.keep_recent_ratio = min(max(float(keep_recent_ratio), 0.0), 0.3)
self.compression_threshold = compression_threshold
self.token_counter = token_counter or EstimateTokenCounter()
self.last_call_failed = False

self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
Expand Down Expand Up @@ -212,6 +213,8 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
"""
from .round_utils import split_into_rounds

self.last_call_failed = False

rounds = split_into_rounds(messages)
message_rounds = [
[seg for seg in rnd if isinstance(seg, Message)] for rnd in rounds
Expand Down Expand Up @@ -276,13 +279,19 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
response = await self.provider.text_chat(
contexts=sanitized_summary_contexts,
)
if response.role == "err":
logger.error(f"Failed to generate summary: {response.completion_text}")
self.last_call_failed = True
return messages
summary_content = (response.completion_text or "").strip()
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
self.last_call_failed = True
return messages

if not summary_content:
logger.warning("LLM context compression returned an empty summary.")
self.last_call_failed = True
return messages

# Build result: system messages + summary pair + recent rounds
Expand Down
90 changes: 73 additions & 17 deletions astrbot/core/agent/context/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .config import ContextConfig
from .round_utils import count_conversation_rounds
from .token_counter import EstimateTokenCounter
from .truncator import ContextTruncator

Expand Down Expand Up @@ -53,18 +54,50 @@ async def process(
Returns:
The processed message list.
"""
result, _ = await self.process_with_meta(messages, trusted_token_usage)
return result

async def process_with_meta(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider refactoring the new control flow into a single internal implementation with small helpers for turn limits and truncation while centralizing lossy/flag handling at the top level.

You can keep the new behavior but reduce the branching/flag complexity by:

  1. Centralizing the main flow in a single internal method
  2. Extracting turn-limit and truncation logic into small helpers
  3. Keeping lossy flags managed at the top level, not sprinkled in loops

1. Unify process / process_with_meta via a small result object

Instead of two public entry points with tuple plumbing, move the core logic into a single internal method returning a small result object. process and process_with_meta just adapt that.

from dataclasses import dataclass

@dataclass
class ProcessResult:
    messages: list[Message]
    was_hard_truncated: bool

async def process(self, messages: list[Message], trusted_token_usage: int = 0) -> list[Message]:
    result = await self._process_impl(messages, trusted_token_usage)
    return result.messages

async def process_with_meta(
    self,
    messages: list[Message],
    trusted_token_usage: int = 0,
) -> tuple[list[Message], bool]:
    result = await self._process_impl(messages, trusted_token_usage)
    return result.messages, result.was_hard_truncated

That keeps the behavior identical, but there’s only one place (_process_impl) where control flow and flags are managed.

2. Extract turn-limit logic into a helper

This pulls the policy (LLM vs truncation, last_call_failed) out of the main flow and returns both the messages and whether we hard truncated.

async def _apply_turn_limit(
    self, messages: list[Message]
) -> ProcessResult:
    result = ProcessResult(messages=messages, was_hard_truncated=False)

    if self.config.enforce_max_turns == -1:
        return result

    turn_count = count_conversation_rounds(result.messages)
    if turn_count <= self.config.enforce_max_turns:
        return result

    if isinstance(self.compressor, LLMSummaryCompressor):
        logger.debug(
            "Turn limit (%s) exceeded (%s turns), delegating to LLM summary compressor.",
            self.config.enforce_max_turns,
            turn_count,
        )
        compressed = await self.compressor(result.messages)
        if self.compressor.last_call_failed:
            logger.warning(
                "LLM summary compression failed; falling back to turn-based truncation."
            )
            result.messages = self.truncator.truncate_by_turns(
                result.messages,
                keep_most_recent_turns=self.config.enforce_max_turns,
                drop_turns=self.config.truncate_turns,
            )
            result.was_hard_truncated = True
        else:
            result.messages = compressed
    else:
        result.messages = self.truncator.truncate_by_turns(
            result.messages,
            keep_most_recent_turns=self.config.enforce_max_turns,
            drop_turns=self.config.truncate_turns,
        )
        result.was_hard_truncated = True

    return result

Then _process_impl just wires this together:

async def _process_impl(
    self, messages: list[Message], trusted_token_usage: int = 0
) -> ProcessResult:
    try:
        result = await self._apply_turn_limit(messages)

        if self.config.max_context_tokens > 0:
            total_tokens = self.token_counter.count_tokens(
                result.messages, trusted_token_usage
            )
            if self.compressor.should_compress(
                result.messages, total_tokens, self.config.max_context_tokens
            ):
                compressed, was_lossy = await self._run_compression(
                    result.messages, total_tokens
                )
                result.messages = compressed
                result.was_hard_truncated |= was_lossy

        return result
    except Exception as e:
        logger.error(f"Error during context processing: {e}", exc_info=True)
        return ProcessResult(messages=messages, was_hard_truncated=False)

This flattens process_with_meta and keeps policy logic isolated.

3. Extract truncation loop into a helper and keep was_lossy top-level

You can keep _run_compression focused on “run compressor + decide if truncation is needed”, and move the truncation loop into a small helper. Also, set was_lossy once when you decide to truncate, instead of toggling inside the loop.

async def _run_compression(
    self, messages: list[Message], prev_tokens: int
) -> tuple[list[Message], bool]:
    logger.debug("Compress triggered, starting compression...")
    compressed = await self.compressor(messages)

    was_lossy = (
        not isinstance(self.compressor, LLMSummaryCompressor)
        or self.compressor.last_call_failed
    )

    messages = compressed
    tokens_after_summary = self.token_counter.count_tokens(messages)

    compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
    logger.info(
        f"Compress completed."
        f" {prev_tokens} -> {tokens_after_summary} tokens,"
        f" compression rate: {compress_rate:.2f}%.",
    )

    if not self.compressor.should_compress(
        messages, tokens_after_summary, self.config.max_context_tokens
    ):
        return messages, was_lossy

    logger.info(
        "Context still exceeds max tokens after compression, applying hard truncation..."
    )
    # any truncation is lossy by definition
    was_lossy = True
    messages = self._progressive_truncate_until_within_limit(messages, tokens_after_summary)
    return messages, was_lossy

And then the truncation loop becomes a self-contained helper with clear guards:

def _progressive_truncate_until_within_limit(
    self, messages: list[Message], current_tokens: int
) -> list[Message]:
    while self.compressor.should_compress(
        messages, current_tokens, self.config.max_context_tokens
    ):
        truncated = self.truncator.truncate_by_dropping_oldest_turns(
            messages,
            drop_turns=self.config.truncate_turns,
        )

        if truncated == messages:
            truncated = self.truncator.truncate_by_halving(messages)

        if truncated == messages:
            break

        next_tokens = self.token_counter.count_tokens(truncated)
        if next_tokens >= current_tokens:
            break

        messages = truncated
        current_tokens = next_tokens

    return messages

Behavior stays the same, but:

  • _run_compression is now a linear sequence with one loop delegated to a named helper.
  • was_lossy is only managed in _run_compression at two points: post-LLM decision and when truncation kicks in.
  • The truncation-specific concerns (no-op checks, non-decreasing tokens, strategy choice) are isolated in _progressive_truncate_until_within_limit, which can be tested independently.

self, messages: list[Message], trusted_token_usage: int = 0
) -> tuple[list[Message], bool]:
try:
result = messages
was_hard_truncated = False

# 1. 基于轮次的截断 (Enforce max turns)
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
turn_count = count_conversation_rounds(result)
if turn_count > self.config.enforce_max_turns:
if isinstance(self.compressor, LLMSummaryCompressor):
logger.debug(
"Turn limit (%s) exceeded (%s turns), "
"delegating to LLM summary compressor instead of "
"hard truncation.",
self.config.enforce_max_turns,
turn_count,
)
compressed = await self.compressor(result)
if self.compressor.last_call_failed:
logger.warning(
"LLM summary compression failed; falling back "
"to turn-based truncation to bound context "
"size.",
)
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
was_hard_truncated = True
else:
result = compressed
else:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
was_hard_truncated = True
Comment on lines +60 to +99

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When both enforce_max_turns and max_context_tokens are exceeded, the context manager will trigger LLM compression twice in a single processing run. This is highly inefficient and increases API costs. We should track if LLM compression has already been performed during the turn-limit check and pass this state to avoid redundant compression.

    async def process_with_meta(
        self, messages: list[Message], trusted_token_usage: int = 0
    ) -> tuple[list[Message], bool]:
        try:
            result = messages
            was_hard_truncated = False
            already_compressed = False

            if self.config.enforce_max_turns != -1:
                turn_count = count_conversation_rounds(result)
                if turn_count > self.config.enforce_max_turns:
                    if isinstance(self.compressor, LLMSummaryCompressor):
                        logger.debug(
                            "Turn limit (%s) exceeded (%s turns), "
                            "delegating to LLM summary compressor instead of "
                            "hard truncation.",
                            self.config.enforce_max_turns,
                            turn_count,
                        )
                        compressed = await self.compressor(result)
                        if self.compressor.last_call_failed:
                            logger.warning(
                                "LLM summary compression failed; falling back "
                                "to turn-based truncation to bound context "
                                "size.",
                            )
                            result = self.truncator.truncate_by_turns(
                                result,
                                keep_most_recent_turns=self.config.enforce_max_turns,
                                drop_turns=self.config.truncate_turns,
                            )
                            was_hard_truncated = True
                        else:
                            result = compressed
                            already_compressed = True
                    else:
                        result = self.truncator.truncate_by_turns(
                            result,
                            keep_most_recent_turns=self.config.enforce_max_turns,
                            drop_turns=self.config.truncate_turns,
                        )
                        was_hard_truncated = True


# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
Expand All @@ -73,16 +106,19 @@ async def process(
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
result, compression_was_lossy = await self._run_compression(
result, total_tokens
)
was_hard_truncated = was_hard_truncated or compression_was_lossy
Comment on lines 106 to +112

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Pass the already_compressed flag to _run_compression to prevent duplicate LLM compression calls.

Suggested change
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
result, compression_was_lossy = await self._run_compression(
result, total_tokens
)
was_hard_truncated = was_hard_truncated or compression_was_lossy
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result, compression_was_lossy = await self._run_compression(
result, total_tokens, force_hard_truncation=already_compressed
)
was_hard_truncated = was_hard_truncated or compression_was_lossy


return result
return result, was_hard_truncated
except Exception as e:
logger.error(f"Error during context processing: {e}", exc_info=True)
return messages
return messages, False

async def _run_compression(
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
) -> tuple[list[Message], bool]:
Comment on lines 119 to +121

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Update the signature of _run_compression to accept force_hard_truncation.

Suggested change
async def _run_compression(
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
) -> tuple[list[Message], bool]:
async def _run_compression(
self, messages: list[Message], prev_tokens: int, force_hard_truncation: bool = False
) -> tuple[list[Message], bool]:

"""
Compress/truncate the messages.

Expand All @@ -95,7 +131,12 @@ async def _run_compression(
"""
logger.debug("Compress triggered, starting compression...")

messages = await self.compressor(messages)
compressed = await self.compressor(messages)
was_lossy = (
not isinstance(self.compressor, LLMSummaryCompressor)
or self.compressor.last_call_failed
)
messages = compressed
Comment on lines 132 to +139

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Skip LLM compression and directly mark as lossy if force_hard_truncation is enabled.

Suggested change
logger.debug("Compress triggered, starting compression...")
messages = await self.compressor(messages)
compressed = await self.compressor(messages)
was_lossy = (
not isinstance(self.compressor, LLMSummaryCompressor)
or self.compressor.last_call_failed
)
messages = compressed
if not force_hard_truncation:
logger.debug("Compress triggered, starting compression...")
compressed = await self.compressor(messages)
was_lossy = (
not isinstance(self.compressor, LLMSummaryCompressor)
or self.compressor.last_call_failed
)
messages = compressed
else:
was_lossy = True


# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
Expand All @@ -113,9 +154,24 @@ async def _run_compression(
messages, tokens_after_summary, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
"Context still exceeds max tokens after compression, applying hard truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)

return messages
was_lossy = True
while self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
):
truncated = self.truncator.truncate_by_dropping_oldest_turns(
messages,
drop_turns=self.config.truncate_turns,
)
if truncated == messages:
truncated = self.truncator.truncate_by_halving(messages)
if truncated == messages:
break
next_tokens = self.token_counter.count_tokens(truncated)
if next_tokens >= tokens_after_summary:
break
messages = truncated
tokens_after_summary = next_tokens

return messages, was_lossy
16 changes: 16 additions & 0 deletions astrbot/core/agent/context/round_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,22 @@ def split_into_rounds(
return rounds


def count_conversation_rounds(contexts: Sequence[RoundSegment]) -> int:
"""Count logical user conversation rounds.

Args:
contexts: Flat message contexts.

Returns:
Number of rounds that contain a user message.
"""
return sum(
1
for round_segments in split_into_rounds(contexts)
if any(_segment_role(seg) == "user" for seg in round_segments)
)


def _content_to_text(content: Any) -> str:
if isinstance(content, list):
normalized = [
Expand Down
19 changes: 15 additions & 4 deletions astrbot/core/agent/context/truncator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ..message import Message
from .round_utils import count_conversation_rounds, split_into_rounds


class ContextTruncator:
Expand Down Expand Up @@ -120,15 +121,20 @@ def truncate_by_turns(
return messages

system_messages, non_system_messages = self._split_system_rest(messages)
rounds = split_into_rounds(non_system_messages)

if len(non_system_messages) // 2 <= keep_most_recent_turns:
if count_conversation_rounds(non_system_messages) <= keep_most_recent_turns:
return messages
Comment on lines +124 to 127

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling count_conversation_rounds immediately after split_into_rounds is redundant and inefficient because both functions internally call split_into_rounds. We can calculate the user rounds count directly from the already computed rounds list.

Suggested change
rounds = split_into_rounds(non_system_messages)
if len(non_system_messages) // 2 <= keep_most_recent_turns:
if count_conversation_rounds(non_system_messages) <= keep_most_recent_turns:
return messages
rounds = split_into_rounds(non_system_messages)
user_rounds_count = sum(1 for rnd in rounds if any(seg.role == "user" for seg in rnd))
if user_rounds_count <= keep_most_recent_turns:
return messages


num_to_keep = keep_most_recent_turns - drop_turns + 1
if num_to_keep <= 0:
truncated_contexts = []
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
truncated_contexts = [
segment
for round_segments in rounds[-num_to_keep:]
for segment in round_segments
]

# Find the first user message
index = next(
Expand All @@ -153,11 +159,16 @@ def truncate_by_dropping_oldest_turns(
return messages

system_messages, non_system_messages = self._split_system_rest(messages)
rounds = split_into_rounds(non_system_messages)

if len(non_system_messages) // 2 <= drop_turns:
if count_conversation_rounds(non_system_messages) <= drop_turns:
truncated_non_system = []
Comment on lines +162 to 165

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Avoid redundant split_into_rounds calls by calculating the user rounds count directly from the already computed rounds list.

Suggested change
rounds = split_into_rounds(non_system_messages)
if len(non_system_messages) // 2 <= drop_turns:
if count_conversation_rounds(non_system_messages) <= drop_turns:
truncated_non_system = []
rounds = split_into_rounds(non_system_messages)
user_rounds_count = sum(1 for rnd in rounds if any(seg.role == "user" for seg in rnd))
if user_rounds_count <= drop_turns:
truncated_non_system = []

else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
truncated_non_system = [
segment
for round_segments in rounds[drop_turns:]
for segment in round_segments
]

# Find the first user message
index = next(
Expand Down
1 change: 1 addition & 0 deletions astrbot/core/agent/run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ContextWrapper(Generic[TContext]):
context: TContext
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
persisted_messages: list[Message] | None = Field(default=None)
tool_call_timeout: int = 120 # Default tool call timeout in seconds


Expand Down
54 changes: 39 additions & 15 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ async def _complete_with_assistant_response(self, llm_resp: LLMResponse) -> None
parts.append(TextPart(text=llm_resp.completion_text))
if len(parts) == 0:
logger.warning("LLM returned empty assistant message with no tool calls.")
self.run_context.messages.append(Message(role="assistant", content=parts))
self._append_message(Message(role="assistant", content=parts))

try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
Expand Down Expand Up @@ -318,6 +318,7 @@ async def reset(
Message(role="system", content=request.system_prompt),
)
self.run_context.messages = messages
self.run_context.persisted_messages = list(messages)

self.stats = AgentStats()
self.stats.start_time = time.time()
Expand All @@ -327,6 +328,16 @@ def _read_tool_hint(self) -> str:
return f"`{self.read_tool.name}`"
return "the available file-read tool"

def _append_message(self, message: Message) -> None:
self.run_context.messages.append(message)
if self.run_context.persisted_messages is not None:
self.run_context.persisted_messages.append(message)

def _extend_messages(self, new_messages: T.Sequence[Message]) -> None:
self.run_context.messages.extend(new_messages)
if self.run_context.persisted_messages is not None:
self.run_context.persisted_messages.extend(new_messages)

async def _assemble_request_context_for_provider(
self,
request: ProviderRequest,
Expand Down Expand Up @@ -697,7 +708,8 @@ async def step(self):
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")

if self._state == AgentState.IDLE:
is_first_step = self._state == AgentState.IDLE
if is_first_step:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
Expand All @@ -707,14 +719,30 @@ async def step(self):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None

# Process request-time context before sending it to the provider.
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self._simple_print_message_role("[BefCompact]", self.run_context.messages)
self.run_context.messages = await self.request_context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
token_usage = (
self.req.conversation.token_usage
if is_first_step and self.req.conversation
else 0
)
source_messages = (
self.run_context.persisted_messages or self.run_context.messages
)
self._simple_print_message_role("[BefCompact]", source_messages)
pre_compaction_count = len(source_messages)
(
self.run_context.messages,
was_hard_truncated,
) = await self.request_context_manager.process_with_meta(
source_messages, trusted_token_usage=token_usage
)
self._simple_print_message_role("[AftCompact]", self.run_context.messages)

if (
not was_hard_truncated
and len(self.run_context.messages) != pre_compaction_count
):
self.run_context.persisted_messages = list(self.run_context.messages)

async for llm_response in self._iter_llm_responses_with_fallback():
if llm_response.is_chunk:
if self.stats.time_to_first_token == 0:
Expand Down Expand Up @@ -903,9 +931,7 @@ async def step(self):
tool_calls_result=tool_call_result_blocks,
)
# record the assistant message with tool calls
self.run_context.messages.extend(
tool_calls_result.to_openai_messages_model()
)
self._extend_messages(tool_calls_result.to_openai_messages_model())

# If there are cached images and the model supports image input,
# append a user message with images so LLM can see them
Expand Down Expand Up @@ -937,9 +963,7 @@ async def step(self):
)
)
if image_parts:
self.run_context.messages.append(
Message(role="user", content=image_parts)
)
self._append_message(Message(role="user", content=image_parts))
logger.debug(
f"Appended {len(cached_images)} cached image(s) to context for LLM review"
)
Expand All @@ -965,7 +989,7 @@ async def step_until_done(
if self.req:
self.req.func_tool = None
# 注入提示词
self.run_context.messages.append(
self._append_message(
Message(
role="user",
content=self.MAX_STEPS_REACHED_PROMPT,
Expand Down Expand Up @@ -1386,7 +1410,7 @@ async def _finalize_aborted_step(
if llm_resp.completion_text:
parts.append(TextPart(text=llm_resp.completion_text))
if parts:
self.run_context.messages.append(Message(role="assistant", content=parts))
self._append_message(Message(role="assistant", content=parts))

try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,10 @@ async def process(
event,
req,
agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages,
agent_runner.run_context.persisted_messages
if agent_runner.run_context.persisted_messages
is not None
else agent_runner.run_context.messages,
agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
)
Expand Down Expand Up @@ -403,7 +406,9 @@ async def process(
event,
req,
final_resp,
agent_runner.run_context.messages,
agent_runner.run_context.persisted_messages
if agent_runner.run_context.persisted_messages is not None
else agent_runner.run_context.messages,
agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
)
Expand Down
Loading
Loading