Skip to content

Commit 734b298

Browse files
declan-scale and claude committed
fix(langgraph): accumulate multi-step usage in LangGraphTurn [greptile]
_capture overwrote self._usage on every AIMessage, so a multi-step turn (text -> tool decision -> final text) reported only the last LLM call's tokens and silently dropped the rest — undercounting in any billing/monitoring that reads turn.usage().

Accumulate additively across calls via _accumulate_turn_usage (None+None stays None; real 0 contributes 0). Add a test asserting summed input/output/total/cache/reasoning tokens across two AIMessages.

The separate 06-18 "TurnResult.usage empty via auto_send_turn" comment is resolved by the foundation (emitter reads turn.usage() after stream exhaustion).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent eab5388 commit 734b298

2 files changed

Lines changed: 79 additions & 5 deletions

File tree

src/agentex/lib/adk/_modules/_langgraph_turn.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,33 @@ def langgraph_usage_to_turn_usage(usage_metadata: Any, model: str | None) -> Tur
6262
)
6363

6464

65+
def _add_optional(a: int | None, b: int | None) -> int | None:
66+
"""Sum two optional token counts; ``None`` means 'not reported' on that side.
67+
68+
``None + None`` stays ``None`` (model never reported usage), while a real 0
69+
contributes 0 (preserving zero counts rather than coercing them away).
70+
"""
71+
if a is None and b is None:
72+
return None
73+
return (a or 0) + (b or 0)
74+
75+
76+
def _accumulate_turn_usage(acc: TurnUsage, call: TurnUsage, model: str | None) -> TurnUsage:
    """Add a single LLM call's usage into the running per-turn total.

    A LangGraph turn can make multiple LLM calls (e.g. text -> tool decision ->
    final text); summing them avoids silently dropping all but the last call.

    Args:
        acc: Usage accumulated so far for the turn.
        call: Usage reported by the LLM call being folded in.
        model: Model name stamped on the returned total; it is taken from this
            argument, not from ``acc`` or ``call``.

    Returns:
        A new ``TurnUsage`` whose token counters are the per-field sums of
        ``acc`` and ``call``. A counter that is ``None`` on both sides stays
        ``None``. Neither input is mutated.
    """
    return TurnUsage(
        model=model,
        input_tokens=_add_optional(acc.input_tokens, call.input_tokens),
        output_tokens=_add_optional(acc.output_tokens, call.output_tokens),
        total_tokens=_add_optional(acc.total_tokens, call.total_tokens),
        cached_input_tokens=_add_optional(acc.cached_input_tokens, call.cached_input_tokens),
        reasoning_tokens=_add_optional(acc.reasoning_tokens, call.reasoning_tokens),
    )
90+
91+
6592
class LangGraphTurn:
6693
"""HarnessTurn wrapping a LangGraph ``astream()`` event stream.
6794
@@ -89,7 +116,8 @@ class LangGraphTurn:
89116
option is needed.
90117
91118
Usage data is captured lazily via the ``on_final_ai_message`` callback and
92-
is only valid after ``events`` has been fully consumed.
119+
is only valid after ``events`` has been fully consumed. Multi-step turns
120+
(more than one LLM call) accumulate usage additively across calls.
93121
"""
94122

95123
def __init__(self, stream: Any, model: str | None = None) -> None:
@@ -105,15 +133,20 @@ async def _generate_events(self) -> AsyncGenerator[StreamTaskMessage, None]:
105133
def _capture(ai_msg: Any) -> None:
106134
usage_metadata = getattr(ai_msg, "usage_metadata", None)
107135
if usage_metadata is not None:
108-
self._usage = langgraph_usage_to_turn_usage(usage_metadata, self._model)
136+
call_usage = langgraph_usage_to_turn_usage(usage_metadata, self._model)
137+
# Accumulate across LLM calls — the callback fires once per agent
138+
# node invocation, so a multi-step turn reports usage more than
139+
# once; overwriting would drop all but the last call.
140+
self._usage = _accumulate_turn_usage(self._usage, call_usage, self._model)
109141

110142
async for ev in convert_langgraph_to_agentex_events(self._stream, on_final_ai_message=_capture):
111143
yield ev
112144

113145
def usage(self) -> TurnUsage:
114-
"""Return the usage captured from the last AIMessage in the stream.
146+
"""Return the usage accumulated across all AIMessages in the stream.
115147
116-
Valid only after ``events`` has been fully consumed.
117-
Returns a zero-usage ``TurnUsage`` if the model did not report usage.
148+
Multi-step turns sum each LLM call's usage. Valid only after ``events``
149+
has been fully consumed. Returns a zero-usage ``TurnUsage`` if the model
150+
did not report usage.
118151
"""
119152
return self._usage

tests/lib/adk/test_langgraph_turn.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,47 @@ async def test_usage_captured_from_ai_message(self):
168168
assert usage.total_tokens == 15
169169
assert usage.model == "gpt-4"
170170

171+
async def test_usage_accumulates_across_multiple_ai_messages(self):
    """A multi-step turn (>1 LLM call) sums usage instead of keeping only the last."""
    from langchain_core.messages import AIMessage

    # Two AIMessages, each carrying its own usage_metadata, delivered as two
    # separate "agent" updates in the same stream.
    first = AIMessage(
        content="thinking",
        usage_metadata={
            "input_tokens": 10,
            "output_tokens": 5,
            "total_tokens": 15,
            "input_token_details": {"cache_read": 2},
            "output_token_details": {"reasoning": 1},
        },
    )
    second = AIMessage(
        content="answer",
        usage_metadata={
            "input_tokens": 20,
            "output_tokens": 7,
            "total_tokens": 27,
            "input_token_details": {"cache_read": 3},
            "output_token_details": {"reasoning": 4},
        },
    )
    stream = _make_stream(
        [
            ("updates", {"agent": {"messages": [first]}}),
            ("updates", {"agent": {"messages": [second]}}),
        ]
    )
    turn = LangGraphTurn(stream, model="gpt-4")
    await _drain(turn)

    # Every counter must be first + second, not just the second message's value.
    usage = turn.usage()
    assert usage.input_tokens == 30
    assert usage.output_tokens == 12
    assert usage.total_tokens == 42
    assert usage.cached_input_tokens == 5
    assert usage.reasoning_tokens == 5
    assert usage.model == "gpt-4"
211+
171212
async def test_usage_not_updated_when_no_usage_metadata(self):
172213
from langchain_core.messages import AIMessage
173214

0 commit comments

Comments
 (0)