-
-
Notifications
You must be signed in to change notification settings - Fork 2.5k
fix: 避免请求时上下文裁剪覆盖持久化对话历史 #8974
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
fix: 避免请求时上下文裁剪覆盖持久化对话历史 #8974
Changes from all commits
a4fe923
c81262c
36b1743
6564c82
d41658e
a9d543c
57bda47
69b3568
e6cfd6e
877f895
7f028cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,18 +53,50 @@ async def process( | |
| Returns: | ||
| The processed message list. | ||
| """ | ||
| result, _ = await self.process_with_meta(messages, trusted_token_usage) | ||
| return result | ||
|
|
||
| async def process_with_meta( | ||
| self, messages: list[Message], trusted_token_usage: int = 0 | ||
| ) -> tuple[list[Message], bool]: | ||
| try: | ||
| result = messages | ||
| was_hard_truncated = False | ||
|
|
||
| # 1. 基于轮次的截断 (Enforce max turns) | ||
| if self.config.enforce_max_turns != -1: | ||
| result = self.truncator.truncate_by_turns( | ||
| result, | ||
| keep_most_recent_turns=self.config.enforce_max_turns, | ||
| drop_turns=self.config.truncate_turns, | ||
| ) | ||
| non_system_count = len([m for m in result if m.role != "system"]) // 2 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
suggestion (bug_risk): Turn counting via `len(non_system_messages) // 2` is fragile.
This assumes every non-system message forms a strict user/assistant pair and that other roles (tool, function, etc.) either don’t appear or fit that pattern. With odd message counts or extra roles, turn counting can be wrong, leading to premature or delayed compression or truncation.
Suggested implementation:
result = messages
was_hard_truncated = False
if self.config.enforce_max_turns != -1:
# Count turns in a way that is robust to extra roles (tool/function/etc.)
# A "turn" starts with a user message and continues through the
# assistant/tool/function responses until the next user message.
non_system_messages = [m for m in result if m.role != "system"]
turns = 0
in_turn = False
for message in non_system_messages:
if message.role == "user":
if not in_turn:
turns += 1
in_turn = True
elif message.role in ("assistant", "tool", "function"):
# These are considered responses within the current turn.
if in_turn:
in_turn = False
if turns > self.config.enforce_max_turns:
if isinstance(self.compressor, LLMSummaryCompressor):
logger.debug(
"Turn limit (%s) exceeded (%s turns), "
"delegating to LLM summary compressor instead of "
"hard truncation.",
self.config.enforce_max_turns,
turns,
)
compressed = await self.compressor(result)
if self.compressor.last_call_failed:
logger.warning(
…
To fully align …
|
||
| if non_system_count > self.config.enforce_max_turns: | ||
| if isinstance(self.compressor, LLMSummaryCompressor): | ||
| logger.debug( | ||
| "Turn limit (%s) exceeded (%s turns), " | ||
| "delegating to LLM summary compressor instead of " | ||
| "hard truncation.", | ||
| self.config.enforce_max_turns, | ||
| non_system_count, | ||
| ) | ||
| compressed = await self.compressor(result) | ||
| if self.compressor.last_call_failed: | ||
| logger.warning( | ||
| "LLM summary compression failed; falling back " | ||
| "to turn-based truncation to bound context " | ||
| "size.", | ||
| ) | ||
| result = self.truncator.truncate_by_turns( | ||
| result, | ||
| keep_most_recent_turns=self.config.enforce_max_turns, | ||
| drop_turns=self.config.truncate_turns, | ||
| ) | ||
| was_hard_truncated = True | ||
| else: | ||
| result = compressed | ||
| else: | ||
| result = self.truncator.truncate_by_turns( | ||
| result, | ||
| keep_most_recent_turns=self.config.enforce_max_turns, | ||
| drop_turns=self.config.truncate_turns, | ||
| ) | ||
| was_hard_truncated = True | ||
|
|
||
| # 2. 基于 token 的压缩 | ||
| if self.config.max_context_tokens > 0: | ||
| total_tokens = self.token_counter.count_tokens( | ||
| result, trusted_token_usage | ||
|
|
@@ -73,16 +105,19 @@ async def process( | |
| if self.compressor.should_compress( | ||
| result, total_tokens, self.config.max_context_tokens | ||
| ): | ||
| result = await self._run_compression(result, total_tokens) | ||
| result, compression_was_lossy = await self._run_compression( | ||
| result, total_tokens | ||
| ) | ||
| was_hard_truncated = was_hard_truncated or compression_was_lossy | ||
|
|
||
| return result | ||
| return result, was_hard_truncated | ||
| except Exception as e: | ||
| logger.error(f"Error during context processing: {e}", exc_info=True) | ||
| return messages | ||
| return messages, False | ||
|
|
||
| async def _run_compression( | ||
| self, messages: list[Message], prev_tokens: int | ||
| ) -> list[Message]: | ||
| ) -> tuple[list[Message], bool]: | ||
| """ | ||
| Compress/truncate the messages. | ||
|
|
||
|
|
@@ -95,7 +130,12 @@ async def _run_compression( | |
| """ | ||
| logger.debug("Compress triggered, starting compression...") | ||
|
|
||
| messages = await self.compressor(messages) | ||
| compressed = await self.compressor(messages) | ||
| was_lossy = ( | ||
| not isinstance(self.compressor, LLMSummaryCompressor) | ||
| or self.compressor.last_call_failed | ||
| ) | ||
|
lingyun14beta marked this conversation as resolved.
|
||
| messages = compressed | ||
|
|
||
| # double check | ||
| tokens_after_summary = self.token_counter.count_tokens(messages) | ||
|
|
@@ -117,5 +157,6 @@ async def _run_compression( | |
| ) | ||
| # still need compress, truncate by half | ||
| messages = self.truncator.truncate_by_halving(messages) | ||
| was_lossy = True | ||
|
|
||
| return messages | ||
| return messages, was_lossy | ||
Uh oh!
There was an error while loading. Please reload this page.