Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 20 additions & 31 deletions agentrun/server/agui_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@
import pydash

from ..utils.helper import merge, MergeOptions
from ..utils.reasoning import (
get_reasoning_content,
is_thinking_enabled_from_env,
)
from ..utils.reasoning import get_reasoning_content
from .model import (
AgentEvent,
AgentRequest,
Expand Down Expand Up @@ -466,8 +463,6 @@ def _process_event_with_boundaries(
ToolCallStartEvent,
)

thinking_enabled = is_thinking_enabled_from_env()

# RAW 事件直接透传
if event.event == EventType.RAW:
raw_data = event.data.get("raw", "")
Expand All @@ -478,34 +473,31 @@ def _process_event_with_boundaries(
return

if event.event == EventType.REASONING:
if thinking_enabled:
reasoning_content = (
event.data.get("delta")
or get_reasoning_content(event.data)
or ""
)
if reasoning_content:
for sse_data in state.end_text_if_open(self._encoder):
yield sse_data
for sse_data in state.end_all_tools(self._encoder):
yield sse_data
for sse_data in state.ensure_reasoning_started():
yield sse_data
yield _encode_reasoning_event(
"REASONING_MESSAGE_CONTENT",
messageId=state.reasoning.message_id,
delta=reasoning_content,
)
reasoning_content = (
event.data.get("delta")
or get_reasoning_content(event.data)
or ""
)
if reasoning_content:
for sse_data in state.end_text_if_open(self._encoder):
yield sse_data
for sse_data in state.end_all_tools(self._encoder):
yield sse_data
for sse_data in state.ensure_reasoning_started():
yield sse_data
yield _encode_reasoning_event(
"REASONING_MESSAGE_CONTENT",
messageId=state.reasoning.message_id,
delta=reasoning_content,
)
return

# TEXT 事件:在首个 TEXT 前注入 TEXT_MESSAGE_START
# AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先结束所有未结束的 TOOL_CALL
if event.event == EventType.TEXT:
addition = self._strip_reasoning_from_addition(
event.addition, thinking_enabled
)
addition = self._strip_reasoning_from_addition(event.addition)
addition_reasoning = get_reasoning_content(event.addition or {})
if thinking_enabled and addition_reasoning:
if addition_reasoning:
for sse_data in state.ensure_reasoning_started():
yield sse_data
yield _encode_reasoning_event(
Expand Down Expand Up @@ -874,7 +866,6 @@ def _apply_addition(
def _strip_reasoning_from_addition(
self,
addition: Optional[Dict[str, Any]],
thinking_enabled: bool,
) -> Optional[Dict[str, Any]]:
if not addition:
return addition
Expand All @@ -890,8 +881,6 @@ def _strip_reasoning_from_addition(
else:
stripped.pop("additional_kwargs", None)

if not thinking_enabled:
return stripped
return stripped or None

async def _error_stream(self, message: str) -> AsyncIterator[str]:
Expand Down
48 changes: 19 additions & 29 deletions agentrun/server/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
from fastapi.responses import JSONResponse, StreamingResponse
import pydash

from ..utils.reasoning import (
get_reasoning_content,
is_thinking_enabled_from_env,
)
from ..utils.reasoning import get_reasoning_content
from ..utils.helper import merge, MergeOptions
from .model import (
AgentEvent,
Expand Down Expand Up @@ -304,7 +301,6 @@ async def _format_stream(
# 状态追踪
sent_role = False
has_text = False
thinking_enabled = is_thinking_enabled_from_env()
tool_call_index = -1 # 从 -1 开始,第一个工具调用时变为 0
# 工具调用状态:{tool_id: {"started": bool, "index": int}}
tool_call_states: Dict[str, Dict[str, Any]] = {}
Expand Down Expand Up @@ -341,19 +337,18 @@ async def _format_stream(
event.addition_merge_options,
)

self._apply_reasoning_gate(delta, thinking_enabled)
self._promote_reasoning_content(delta)
yield self._build_chunk(context, delta)
continue

if event.event == EventType.REASONING:
if thinking_enabled:
reasoning_content = event.data.get("delta", "")
if reasoning_content:
has_text = True
yield self._build_chunk(
context,
{"reasoning_content": reasoning_content},
)
reasoning_content = event.data.get("delta", "")
if reasoning_content:
has_text = True
yield self._build_chunk(
Comment on lines 344 to +348
context,
{"reasoning_content": reasoning_content},
)
continue

# TOOL_CALL_CHUNK 事件
Expand Down Expand Up @@ -401,7 +396,7 @@ async def _format_stream(
event.addition_merge_options,
)

self._apply_reasoning_gate(delta, thinking_enabled)
self._promote_reasoning_content(delta)
yield self._build_chunk(context, delta)
continue

Expand Down Expand Up @@ -477,7 +472,6 @@ def _format_non_stream(
"""
content_parts: List[str] = []
reasoning_parts: List[str] = []
thinking_enabled = is_thinking_enabled_from_env()
# 工具调用状态:{tool_id: {id, name, arguments}}
tool_call_map: Dict[str, Dict[str, Any]] = {}
has_tool_calls = False
Expand All @@ -486,12 +480,12 @@ def _format_non_stream(
if event.event == EventType.TEXT:
content_parts.append(event.data.get("delta", ""))
reasoning_content = get_reasoning_content(event.addition or {})
if thinking_enabled and reasoning_content:
if reasoning_content:
reasoning_parts.append(reasoning_content)
Comment on lines 480 to 484

elif event.event == EventType.REASONING:
reasoning_content = event.data.get("delta", "")
if thinking_enabled and reasoning_content:
if reasoning_content:
reasoning_parts.append(reasoning_content)

elif event.event == EventType.TOOL_CALL_CHUNK:
Expand Down Expand Up @@ -564,18 +558,14 @@ def _apply_addition(

return merge(delta, addition, **(merge_options or {}))

def _apply_reasoning_gate(
self,
payload: Dict[str, Any],
thinking_enabled: bool,
) -> None:
if thinking_enabled:
reasoning_content = get_reasoning_content(payload)
if reasoning_content is not None:
payload["reasoning_content"] = reasoning_content
return

def _promote_reasoning_content(self, payload: Dict[str, Any]) -> None:
reasoning_content = get_reasoning_content(payload)
payload.pop("reasoning_content", None)
additional_kwargs = payload.get("additional_kwargs")
if isinstance(additional_kwargs, dict):
additional_kwargs.pop("reasoning_content", None)
if not additional_kwargs:
payload.pop("additional_kwargs", None)

if reasoning_content:
payload["reasoning_content"] = reasoning_content
25 changes: 7 additions & 18 deletions tests/e2e/test_reasoning_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,13 @@ async def test_openai_stream_reasoning_content_gate(
assert response.status_code == 200
events = _parse_sse_events(response.text)
deltas = [
(event.get("choices") or [{}])[0].get("delta") or {}
for event in events
(event.get("choices") or [{}])[0].get("delta") or {} for event in events
]
reasoning = "".join(delta.get("reasoning_content", "") for delta in deltas)
content = "".join(delta.get("content", "") for delta in deltas)

assert content == "answer"
assert reasoning == ("thinking" if thinking_enabled else "")
assert reasoning == "thinking"
assert all("additional_kwargs" not in delta for delta in deltas)


Expand All @@ -104,10 +103,7 @@ async def test_openai_non_stream_reasoning_content_gate(
assert response.status_code == 200
message = response.json()["choices"][0]["message"]
assert message["content"] == "answer"
if thinking_enabled:
assert message["reasoning_content"] == "thinking"
else:
assert "reasoning_content" not in message
assert message["reasoning_content"] == "thinking"


@pytest.mark.parametrize("thinking_enabled", [True, False])
Expand Down Expand Up @@ -140,14 +136,7 @@ async def test_agui_reasoning_events_gate(
)

assert content == "answer"
if thinking_enabled:
assert reasoning == "thinking"
assert event_types.index("REASONING_MESSAGE_CONTENT") < event_types.index(
"TEXT_MESSAGE_START"
)
else:
assert reasoning == ""
assert all(
not event_type.startswith("REASONING")
for event_type in event_types
)
assert reasoning == "thinking"
assert event_types.index("REASONING_MESSAGE_CONTENT") < event_types.index(
"TEXT_MESSAGE_START"
)
29 changes: 18 additions & 11 deletions tests/unittests/server/test_agui_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@ async def invoke_agent(request: AgentRequest):


class TestAGUIReasoningContent:
"""测试 AG-UI reasoning 事件输出开关"""
"""测试 AG-UI reasoning 事件输出"""

def get_client(self, invoke_agent):
server = AgentRunServer(invoke_agent=invoke_agent)
Expand Down Expand Up @@ -1228,7 +1228,7 @@ async def invoke_agent(request: AgentRequest):
assert reasoning_event["delta"] == "thinking"
assert "TEXT_MESSAGE_CONTENT" in types

def test_stream_suppresses_reasoning_when_thinking_disabled(
def test_stream_includes_reasoning_when_thinking_disabled(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')
Expand All @@ -1246,9 +1246,14 @@ async def invoke_agent(request: AgentRequest):
)

events = _agui_sse_events(response)
assert "REASONING_MESSAGE_CONTENT" not in [
event["type"] for event in events
]
types = [event["type"] for event in events]
reasoning_event = next(
event
for event in events
if event["type"] == "REASONING_MESSAGE_CONTENT"
)
assert "REASONING_START" in types
assert reasoning_event["delta"] == "thinking"
text_event = next(
event for event in events if event["type"] == "TEXT_MESSAGE_CONTENT"
)
Expand All @@ -1257,7 +1262,7 @@ async def invoke_agent(request: AgentRequest):
def test_stream_promotes_chunk_additional_kwargs_reasoning(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}')
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')

async def invoke_agent(request: AgentRequest):
yield SimpleNamespace(
Expand All @@ -1282,9 +1287,7 @@ async def invoke_agent(request: AgentRequest):
assert reasoning_event["delta"] == "thinking"
assert text_event["delta"] == "answer"

def test_text_addition_reasoning_is_emitted_before_text(
self, monkeypatch
):
def test_text_addition_reasoning_is_emitted_before_text(self, monkeypatch):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}')

async def invoke_agent(request: AgentRequest):
Expand Down Expand Up @@ -1314,7 +1317,7 @@ async def invoke_agent(request: AgentRequest):
assert text_event["delta"] == "answer"
assert "additional_kwargs" not in text_event

def test_text_addition_reasoning_is_stripped_when_thinking_disabled(
def test_text_addition_reasoning_is_emitted_when_thinking_disabled(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')
Expand All @@ -1335,7 +1338,11 @@ async def invoke_agent(request: AgentRequest):

events = _agui_sse_events(response)
types = [event["type"] for event in events]
assert all(not event_type.startswith("REASONING") for event_type in types)
assert types.index("REASONING_MESSAGE_CONTENT") < types.index(
"TEXT_MESSAGE_START"
)
assert "REASONING_MESSAGE_END" in types
assert "REASONING_END" in types
text_event = next(
event for event in events if event["type"] == "TEXT_MESSAGE_CONTENT"
)
Expand Down
Loading
Loading