Skip to content

Commit 75df4e6

Browse files
declan-scaleclaude
andcommitted
refactor(pydantic-ai): drop coalesce_tool_requests workaround — foundation auto_send delivers streamed tool requests natively (AGX1-377/378)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent c264e37 commit 75df4e6

6 files changed

Lines changed: 77 additions & 203 deletions

File tree

src/agentex/lib/adk/_modules/_pydantic_ai_async.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,10 @@ async def stream_pydantic_ai_events(
5151
from agentex.lib.core.harness.emitter import UnifiedEmitter
5252
from agentex.lib.adk._modules._pydantic_ai_turn import PydanticAITurn
5353

54-
# coalesce_tool_requests=True is a temporary workaround (AGX1-377): the
55-
# foundation auto_send currently DROPS tool requests delivered as the
56-
# streamed Start+ToolRequestDelta+Done shape, so we collapse them into a
57-
# single Full here. Once auto_send handles the streamed tool-request shape
58-
# natively, this flag should be removed and the default (off) used.
5954
turn = PydanticAITurn(
6055
stream,
6156
model=None,
6257
tracing_handler=tracing_handler,
63-
coalesce_tool_requests=True,
6458
)
6559
emitter = UnifiedEmitter(
6660
task_id=task_id,

src/agentex/lib/adk/_modules/_pydantic_ai_turn.py

Lines changed: 7 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from __future__ import annotations
1717

18-
import json
1918
from typing import TYPE_CHECKING, Any, AsyncIterator
2019

2120
from pydantic_ai.run import AgentRunResultEvent
@@ -74,105 +73,29 @@ def pydantic_ai_usage_to_turn_usage(usage: Any, model: str | None) -> TurnUsage:
7473
)
7574

7675

77-
async def _coalesce_tool_requests(
78-
source: AsyncIterator[StreamTaskMessage],
79-
) -> AsyncIterator[StreamTaskMessage]:
80-
"""Convert Start(tool_request)+deltas+Done into Full(tool_request).
81-
82-
``convert_pydantic_ai_to_agentex_events`` emits ``Start+Done`` for tool
83-
calls (enabling streaming of argument tokens over the sync/HTTP channel).
84-
The async/auto_send delivery path does not stream tool-call arguments —
85-
it uses Option A (full messages). This wrapper coalesces the Start+Done
86-
sequence into a single ``StreamTaskMessageFull``, matching the shape that
87-
``auto_send`` expects and that the harness conformance tests are designed for.
88-
89-
Argument delta fragments (``ToolRequestDelta.arguments_delta``) are
90-
accumulated as a JSON string and parsed back into a dict. If parsing
91-
fails, the raw string is stored under ``"_raw"`` so no information is lost.
92-
"""
93-
from agentex.types.tool_request_delta import ToolRequestDelta
94-
from agentex.types.tool_request_content import ToolRequestContent
95-
96-
# pending[index] = (ToolRequestContent, accumulated_args_delta_str)
97-
pending: dict[Any, tuple[Any, str]] = {}
98-
99-
async for event in source:
100-
if isinstance(event, StreamTaskMessageStart):
101-
ctype = getattr(event.content, "type", None)
102-
if ctype == "tool_request":
103-
# Stage; do not yield — replaced by Full on Done.
104-
pending[event.index] = (event.content, "")
105-
continue
106-
107-
elif isinstance(event, StreamTaskMessageDelta):
108-
if event.index in pending and isinstance(event.delta, ToolRequestDelta):
109-
# Accumulate argument delta fragments; don't yield.
110-
content, accum = pending[event.index]
111-
pending[event.index] = (content, accum + (event.delta.arguments_delta or ""))
112-
continue
113-
114-
elif isinstance(event, StreamTaskMessageDone):
115-
if event.index in pending:
116-
content, args_delta = pending.pop(event.index)
117-
# Build final arguments: merge initial dict with accumulated delta.
118-
base_args: dict[str, Any] = {}
119-
if isinstance(content, ToolRequestContent):
120-
base_args = dict(content.arguments) if content.arguments else {}
121-
122-
if args_delta:
123-
try:
124-
parsed = json.loads(args_delta)
125-
if isinstance(parsed, dict):
126-
base_args.update(parsed)
127-
else:
128-
base_args["_raw"] = args_delta
129-
except json.JSONDecodeError:
130-
base_args["_raw"] = args_delta
131-
132-
# Emit as Full with the complete arguments.
133-
full_content = (
134-
ToolRequestContent(
135-
type="tool_request",
136-
author=content.author if isinstance(content, ToolRequestContent) else "agent",
137-
tool_call_id=content.tool_call_id if isinstance(content, ToolRequestContent) else "",
138-
name=content.name if isinstance(content, ToolRequestContent) else "",
139-
arguments=base_args,
140-
)
141-
if isinstance(content, ToolRequestContent)
142-
else content
143-
)
144-
yield StreamTaskMessageFull(type="full", index=event.index, content=full_content)
145-
continue
146-
147-
yield event
148-
149-
15076
class PydanticAITurn:
15177
"""A single harness turn backed by a pydantic-ai event stream.
15278
15379
Satisfies the ``HarnessTurn`` protocol: ``events`` async-generates the
15480
canonical ``StreamTaskMessage*`` stream; ``usage()`` returns a normalized
15581
``TurnUsage`` (valid only after ``events`` is exhausted).
15682
157-
By default ``events`` is identical to the bare
158-
``convert_pydantic_ai_to_agentex_events`` output (tool calls stream as
159-
``Start + ToolRequestDelta + Done``, preserving argument-token streaming on
160-
the sync/yield channel). When ``coalesce_tool_requests=True``, tool-request
161-
sequences are collapsed into a single ``StreamTaskMessageFull`` (Option A —
162-
no streaming of argument tokens) for the async/auto_send path.
83+
``events`` is identical to the bare ``convert_pydantic_ai_to_agentex_events``
84+
output (tool calls stream as ``Start + ToolRequestDelta + Done``, preserving
85+
argument-token streaming on the sync/yield channel). The foundation
86+
``auto_send`` delivers the streamed tool-request shape natively (AGX1-377),
87+
so no coalescing is needed on either channel.
16388
"""
16489

16590
def __init__(
16691
self,
16792
stream: AsyncIterator[Any],
16893
model: str | None = None,
16994
tracing_handler: "AgentexPydanticAITracingHandler | None" = None,
170-
coalesce_tool_requests: bool = False,
17195
) -> None:
17296
self._stream = stream
17397
self._model = model
17498
self._tracing_handler = tracing_handler
175-
self._coalesce_tool_requests = coalesce_tool_requests
17699
self._usage = TurnUsage(model=model)
177100

178101
@property
@@ -199,12 +122,8 @@ def _capture(result_event: AgentRunResultEvent) -> None:
199122
tracing_handler=self._tracing_handler,
200123
on_result=_capture,
201124
)
202-
if self._coalesce_tool_requests:
203-
async for ev in _coalesce_tool_requests(raw_stream):
204-
yield ev
205-
else:
206-
async for ev in raw_stream:
207-
yield ev
125+
async for ev in raw_stream:
126+
yield ev
208127

209128
def usage(self) -> TurnUsage:
210129
"""Return the normalized usage for this turn.

tests/lib/adk/test_pydantic_ai_async.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ class FakeStreamingModule:
8282
def __init__(self) -> None:
8383
self.contexts: list[FakeContext] = []
8484

85-
def streaming_task_message_context(self, *, task_id: str, initial_content: Any) -> FakeContext:
85+
def streaming_task_message_context(
86+
self, *, task_id: str, initial_content: Any, streaming_mode: str = "coalesced", created_at: Any = None
87+
) -> FakeContext:
8688
tm = TaskMessage(
8789
id=f"m{len(self.contexts) + 1}",
8890
task_id=task_id,
@@ -255,26 +257,25 @@ async def test_empty_thinking_delta_is_skipped(
255257

256258

257259
class TestToolCallEmission:
258-
async def test_tool_call_emits_full_tool_request_message_on_part_end(
260+
async def test_tool_call_opens_streaming_context_with_identity(
259261
self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule]
260262
) -> None:
261-
"""Tool requests are posted as full messages (open+close on streaming context).
263+
"""Tool requests are delivered as a streaming context (Start+Delta+Done).
262264
263-
AGX1-373 envelope change: tool messages now arrive via
264-
streaming_task_message_context (open+close pair) instead of
265-
adk.messages.create. The logical content (tool_call_id, name,
266-
arguments, author) is identical; only the delivery channel changed.
265+
AGX1-377 fix: auto_send now delivers streamed tool-request messages
266+
natively (Start+ToolRequestDelta+Done). The streaming context is opened
267+
at the Start event with the initial ToolRequestContent (tool_call_id +
268+
name + empty arguments), argument tokens are streamed as deltas, and the
269+
context is closed on Done.
267270
268271
This test uses a realistic pydantic-ai event sequence: args arrive as a
269272
PartDeltaEvent fragment (the way OpenAI/Anthropic actually stream JSON
270-
tool-call arguments). The new implementation accumulates them correctly.
271-
272-
Parts-manager invariant: PartEnd.part is the accumulated snapshot; real
273-
pydantic-ai conveys args via PartStart + PartDeltaEvent, so a
274-
PartStart(None)+PartEnd(json) with no delta is not realizable.
273+
tool-call arguments).
275274
"""
276275
from pydantic_ai.messages import ToolCallPartDelta
277276

277+
from agentex.types.tool_request_delta import ToolRequestDelta
278+
278279
streaming, messages = fake_adk
279280
events = [
280281
PartStartEvent(
@@ -293,23 +294,27 @@ async def test_tool_call_emits_full_tool_request_message_on_part_end(
293294
]
294295
await stream_pydantic_ai_events(_aiter(events), TASK_ID)
295296

296-
# AGX1-373: tool messages arrive via streaming_task_message_context,
297-
# NOT via adk.messages.create.
298-
assert messages.created == [], "adk.messages.create must not be called after reimplementation"
299-
assert len(streaming.contexts) == 1, "tool_request opens a streaming context (open+close)"
297+
# AGX1-373: tool messages arrive via streaming_task_message_context.
298+
assert messages.created == [], "adk.messages.create must not be called"
299+
assert len(streaming.contexts) == 1, "tool_request opens a streaming context"
300300
ctx = streaming.contexts[0]
301301
assert ctx.closed is True
302302
content = ctx.initial_content
303303
assert isinstance(content, ToolRequestContent)
304304
assert content.tool_call_id == "c1"
305305
assert content.name == "get_weather"
306-
assert content.arguments == {"city": "Paris"}
307306
assert content.author == "agent"
308-
assert ctx.updates == [], "open+close only — no deltas for tool messages"
307+
# AGX1-377 streamed shape: initial_content has empty args (args come via delta)
308+
assert content.arguments == {}
309+
# The arg delta is delivered as a stream_update
310+
assert len(ctx.updates) == 1
311+
assert isinstance(ctx.updates[0].delta, ToolRequestDelta)
312+
assert ctx.updates[0].delta.arguments_delta == '{"city":"Paris"}'
309313

310314
async def test_tool_call_with_dict_args_passes_through(
311315
self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule]
312316
) -> None:
317+
"""When args arrive pre-populated as a dict in PartStart, they're in initial_content."""
313318
streaming, messages = fake_adk
314319
events = [
315320
PartStartEvent(
@@ -326,25 +331,26 @@ async def test_tool_call_with_dict_args_passes_through(
326331
# AGX1-373: tool messages via streaming_task_message_context
327332
assert messages.created == []
328333
assert len(streaming.contexts) == 1
334+
# Dict args present at PartStart land directly in initial_content.arguments
329335
assert streaming.contexts[0].initial_content.arguments == {"q": "weather"}
336+
assert streaming.contexts[0].updates == [], "no delta for pre-populated dict args"
330337

331338
async def test_tool_call_with_invalid_json_args_surfaces_raw(
332339
self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule]
333340
) -> None:
334-
"""Don't drop the tool call when the model emits malformed JSON args.
341+
"""Malformed JSON arg delta is surfaced as a ToolRequestDelta with the raw string.
335342
336-
The arguments field is preserved under ``_raw`` so the failure is
337-
visible to the UI rather than silently truncated.
338-
339-
Uses a PartDeltaEvent to deliver the invalid string (the way pydantic-ai
340-
actually surfaces arg tokens) so the coalescer picks it up.
343+
The argument delta is delivered as-is by auto_send; the client-side
344+
accumulator or the streaming backend handles malformed JSON gracefully.
341345
342346
Parts-manager invariant: PartEnd.part is the accumulated snapshot; real
343347
pydantic-ai conveys args via PartStart + PartDeltaEvent, so a
344348
PartStart(None)+PartEnd(json) with no delta is not realizable.
345349
"""
346350
from pydantic_ai.messages import ToolCallPartDelta
347351

352+
from agentex.types.tool_request_delta import ToolRequestDelta
353+
348354
streaming, messages = fake_adk
349355
events = [
350356
PartStartEvent(
@@ -366,7 +372,13 @@ async def test_tool_call_with_invalid_json_args_surfaces_raw(
366372
# AGX1-373: tool messages via streaming_task_message_context
367373
assert messages.created == []
368374
assert len(streaming.contexts) == 1
369-
assert streaming.contexts[0].initial_content.arguments == {"_raw": "not-json{"}
375+
ctx = streaming.contexts[0]
376+
# Initial content has empty args (args come via delta)
377+
assert ctx.initial_content.arguments == {}
378+
# The malformed JSON is surfaced verbatim in the ToolRequestDelta
379+
assert len(ctx.updates) == 1
380+
assert isinstance(ctx.updates[0].delta, ToolRequestDelta)
381+
assert ctx.updates[0].delta.arguments_delta == "not-json{"
370382

371383
async def test_tool_call_with_none_args_defaults_to_empty_dict(
372384
self, fake_adk: tuple[FakeStreamingModule, FakeMessagesModule]
@@ -388,6 +400,7 @@ async def test_tool_call_with_none_args_defaults_to_empty_dict(
388400
assert messages.created == []
389401
assert len(streaming.contexts) == 1
390402
assert streaming.contexts[0].initial_content.arguments == {}
403+
assert streaming.contexts[0].updates == [], "no delta when args are absent"
391404

392405

393406
class TestToolResult:

tests/lib/adk/test_pydantic_ai_turn.py

Lines changed: 11 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -230,21 +230,17 @@ async def test_no_usage_event_leaves_default_usage(self):
230230
assert usage.num_llm_calls == 0
231231

232232

233-
class TestToolRequestCoalescing:
234-
"""The ``coalesce_tool_requests`` flag controls tool-call delivery shape.
233+
class TestToolRequestStreaming:
234+
"""PydanticAITurn.events equals the bare converter output unconditionally.
235235
236-
Default (off): tool calls stream as Start + ToolRequestDelta + Done,
237-
matching the bare converter and preserving argument-token streaming on the
238-
sync/yield channel.
239-
240-
On: tool-request sequences collapse into a single StreamTaskMessageFull
241-
(Option A) for the async/auto_send path (a temporary AGX1-377 workaround).
236+
The foundation auto_send delivers Start+ToolRequestDelta+Done natively
237+
(AGX1-377), so no coalescing is needed on either channel.
242238
"""
243239

244-
async def test_default_off_matches_bare_converter_for_streamed_tool_call(self):
245-
"""Default Turn (coalesce off) yields a ToolRequestDelta for a streamed-args
246-
tool call — i.e. it is byte-for-byte the bare converter output, so the
247-
sync/yield channel keeps its argument-token streaming."""
240+
async def test_events_match_bare_converter_for_streamed_tool_call(self):
241+
"""PydanticAITurn yields a ToolRequestDelta for a streamed-args tool call
242+
— i.e. it is byte-for-byte the bare converter output, preserving
243+
argument-token streaming on the sync/yield channel."""
248244
from pydantic_ai.messages import ToolCallPart, ToolCallPartDelta
249245

250246
from agentex.types.tool_request_delta import ToolRequestDelta
@@ -265,57 +261,16 @@ async def test_default_off_matches_bare_converter_for_streamed_tool_call(self):
265261

266262
bare_out = await _collect(convert_pydantic_ai_to_agentex_events(_aiter(tool_events)))
267263

268-
# Default Turn is identical to the bare converter.
264+
# Turn is identical to the bare converter.
269265
assert len(turn_out) == len(bare_out)
270266
for a, b in zip(turn_out, bare_out):
271267
assert type(a) is type(b)
272268
assert a.model_dump() == b.model_dump()
273269

274-
# And the arg-streaming delta is present (not coalesced away).
270+
# The arg-streaming delta is present.
275271
deltas = [
276272
e for e in turn_out if isinstance(e, StreamTaskMessageDelta) and isinstance(e.delta, ToolRequestDelta)
277273
]
278-
assert len(deltas) == 1, "streamed tool-call args must surface as a ToolRequestDelta when coalesce is off"
274+
assert len(deltas) == 1, "streamed tool-call args must surface as a ToolRequestDelta"
279275
assert isinstance(deltas[0].delta, ToolRequestDelta)
280276
assert deltas[0].delta.arguments_delta == '{"city":"Paris"}'
281-
282-
async def test_coalesce_on_emits_single_full_with_accumulated_args_and_no_delta(self):
283-
"""coalesce_tool_requests=True yields one StreamTaskMessageFull(tool_request)
284-
with fully-accumulated arguments and NO ToolRequestDelta."""
285-
from pydantic_ai.messages import ToolCallPart, ToolCallPartDelta
286-
287-
from agentex.types.tool_request_delta import ToolRequestDelta
288-
from agentex.types.task_message_update import (
289-
StreamTaskMessageDone,
290-
StreamTaskMessageFull,
291-
StreamTaskMessageDelta,
292-
StreamTaskMessageStart,
293-
)
294-
from agentex.types.tool_request_content import ToolRequestContent
295-
296-
tool_events = [
297-
PartStartEvent(index=0, part=ToolCallPart(tool_name="get_weather", args=None, tool_call_id="c1")),
298-
PartDeltaEvent(index=0, delta=ToolCallPartDelta(args_delta='{"city":"Paris"}')),
299-
PartEndEvent(
300-
index=0,
301-
part=ToolCallPart(tool_name="get_weather", args='{"city":"Paris"}', tool_call_id="c1"),
302-
),
303-
]
304-
305-
turn = PydanticAITurn(_aiter(tool_events), model="openai:gpt-4o", coalesce_tool_requests=True)
306-
turn_out = await _collect(turn.events)
307-
308-
# Exactly one event: a Full(tool_request). No Start/Delta/Done leak through.
309-
assert len(turn_out) == 1
310-
full = turn_out[0]
311-
assert isinstance(full, StreamTaskMessageFull)
312-
assert isinstance(full.content, ToolRequestContent)
313-
assert full.content.tool_call_id == "c1"
314-
assert full.content.name == "get_weather"
315-
assert full.content.arguments == {"city": "Paris"}, "args must be fully accumulated"
316-
317-
assert not any(isinstance(e, StreamTaskMessageStart) for e in turn_out)
318-
assert not any(isinstance(e, StreamTaskMessageDone) for e in turn_out)
319-
assert not any(
320-
isinstance(e, StreamTaskMessageDelta) and isinstance(e.delta, ToolRequestDelta) for e in turn_out
321-
), "no ToolRequestDelta when coalescing is on"

0 commit comments

Comments
 (0)