From a4851cf1116e7cee72ae80a10e21b7fea3c69d73 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Tue, 9 Jun 2026 09:23:09 -0700 Subject: [PATCH 1/2] [INITIAL] Update [ghstack-poisoned] --- examples/models/qwen3_5_moe/README.md | 10 + extension/llm/server/cpp/worker_loop.h | 83 +++++-- extension/llm/server/python/README.md | 1 + .../llm/server/python/openai_transcript.py | 175 +++++++++++++++ extension/llm/server/python/serving_chat.py | 70 +++++- .../llm/server/python/tests/test_sessions.py | 203 ++++++++++++++++++ 6 files changed, 520 insertions(+), 22 deletions(-) create mode 100644 extension/llm/server/python/openai_transcript.py diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index 0583765cb77..899c816e859 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -243,6 +243,16 @@ Each `done` event reports (`new`/`exact_prefix`/`dirty`/`mismatch`/`equal`) for measuring the hit rate. `--no-warm-resume` forces a full prefill every request (for A/B comparison). +**Tool-call turns (token-ID continuation):** an assistant turn re-rendered from +its parsed tool call rarely re-tokenizes to the tokens the model actually +generated, so plain warm resume misses on agent loops. The server stores the +exact generated token ids per session and, on the next turn, sends the prompt as +segments (`{"text"}` / `{"ids"}`) that splice those ids back in for prior +assistant turns instead of re-rendering them — so the resident state stays an +exact token prefix and resume hits. Tool *results* remain text (re-tokenized +deterministically). The worker's exact-token check still backstops everything, so +a mismatch just falls back to a full prefill. + This is **isolation + warm resume, not concurrency**: execution is still synchronous (one in-flight request; `--num-runners > 1` is rejected since more workers would duplicate the weights). Fair interleaving across in-flight requests diff --git a/extension/llm/server/cpp/worker_loop.h b/extension/llm/server/cpp/worker_loop.h index 7f92e60371e..f580d21d356 100644 --- a/extension/llm/server/cpp/worker_loop.h +++ b/extension/llm/server/cpp/worker_loop.h @@ -42,8 +42,11 @@ // worker -> stdout, once: {"ready": true, "max_sessions": int, // "max_named_sessions": int} // client -> stdin: -// generate: {"prompt": str, "max_new_tokens": int, "temperature": float, -// "stop": [str, ...], "session_id"?: str} +// generate: {"max_new_tokens": int, "temperature": float, +// "stop": [str, ...], "session_id"?: str, +// and exactly one prompt form: +// "prompt": str +// "prompt_segments": [{"text": str} | {"ids": [int, ...]}]} // open: {"op": "open", "session_id": str} // close: {"op": "close", "session_id": str} // reset: {"op": "reset", "session_id": str} // clear context, keep @@ -55,7 +58,9 @@ // "finish_reason": "stop"|"length", // "reused_prompt_tokens": int, "prefilled_prompt_tokens": int, // "session_reset_reason": "new"|"exact_prefix"|"dirty"| -// "mismatch"|"equal"} +// "mismatch"|"equal", +// "generated_token_ids"?: [int, ...]} // omitted if +// stop-trimmed // open: {"opened": true, "session_id": str} // close: {"closed": true, "session_id": str} // reset: {"reset": true, "session_id": str} @@ -122,7 +127,6 @@ inline void worker_handle_request( const std::unordered_map& metadata, const nlohmann::json& req) { LLMSession& session = *st.session; - const std::string prompt = req.at("prompt").get(); int64_t max_new = req.value("max_new_tokens", static_cast(-1)); const float temperature = req.value("temperature", 0.0f); // Stop strings (the request's `stop` sequences): terminate at the token @@ -131,13 +135,43 @@ inline void worker_handle_request( const std::vector stops = req.value("stop", std::vector{}); - // No special tokens: the prompt is already rendered (the control plane - // applied the chat template), matching the runner's own encode path. - auto encode_result = tokenizer.encode(prompt, /*bos=*/0, /*eos=*/0); - if (!encode_result.ok()) { - throw std::runtime_error("prompt encode failed"); + // The prompt is either a single rendered string ("prompt") or an ordered list + // of segments ("prompt_segments"), each a {"text": ...} chunk to tokenize or + // a + // {"ids": [...]} run of literal token ids. Segments let the control plane + // splice the exact generated token ids of prior assistant turns back in, + // instead of re-tokenizing the chat template's lossy re-rendering of them (so + // warm resume can hit on tool-use turns). Text is encoded with no special + // tokens (already rendered), matching the runner's own encode path. + const bool has_prompt = req.contains("prompt"); + const bool has_segments = req.contains("prompt_segments"); + if (has_prompt == has_segments) { + throw std::runtime_error( + "exactly one of prompt / prompt_segments is required"); + } + std::vector ids; + auto encode_text = [&](const std::string& text) { + auto enc = tokenizer.encode(text, /*bos=*/0, /*eos=*/0); + if (!enc.ok()) { + throw std::runtime_error("prompt encode failed"); + } + ids.insert(ids.end(), enc->begin(), enc->end()); + }; + if (has_segments) { + for (const auto& seg : req.at("prompt_segments")) { + if (seg.contains("ids")) { + for (const auto& id : seg.at("ids")) { + ids.push_back(id.get()); + } + } else if (seg.contains("text")) { + encode_text(seg.at("text").get()); + } else { + throw std::runtime_error("prompt_segment needs `text` or `ids`"); + } + } + } else { + encode_text(req.at("prompt").get()); } - std::vector ids = std::move(*encode_result); if (ids.empty()) { throw std::runtime_error("empty prompt"); } @@ -249,14 +283,27 @@ inline void worker_handle_request( // "length" -- it ran to max_new (possibly clamped to the context window). // reused/prefilled sum to prompt_tokens; session_reset_reason explains the // prefill plan (for measuring warm-resume hit rate). - worker_emit( - {{"done", true}, - {"prompt_tokens", num_prompt}, - {"completion_tokens", num_generated}, - {"finish_reason", finish}, - {"reused_prompt_tokens", reused}, - {"prefilled_prompt_tokens", prefilled}, - {"session_reset_reason", plan.reason}}); + nlohmann::json done = { + {"done", true}, + {"prompt_tokens", num_prompt}, + {"completion_tokens", num_generated}, + {"finish_reason", finish}, + {"reused_prompt_tokens", reused}, + {"prefilled_prompt_tokens", prefilled}, + {"session_reset_reason", plan.reason}}; + // generated_token_ids = the (non-terminal) tokens made resident this turn, + // for the control plane to splice back as an `ids` segment. Only emit them + // when they faithfully decode to the emitted text: a stop-string trim kept + // the post-stop tokens resident but dropped them from the output, so splicing + // them would inject text the client never saw. Omitting them makes the + // control plane record this turn as not resumable (falls back to a text + // re-render). + if (!stop_string) { + done["generated_token_ids"] = std::vector( + st.resident_token_ids.end() - num_generated, + st.resident_token_ids.end()); + } + worker_emit(done); } // Owns the engine's sessions for one worker: named sessions keyed by id plus a diff --git a/extension/llm/server/python/README.md b/extension/llm/server/python/README.md index e14e6176c81..4b17dc1f37f 100644 --- a/extension/llm/server/python/README.md +++ b/extension/llm/server/python/README.md @@ -137,6 +137,7 @@ does blocking pipe I/O on its executor thread. | `chat_template.py` | messages (+tools) → prompt string | | `worker_client.py` | spawn a worker process + drive it over JSONL (raw transport) | | `session_runtime.py` | stateful runtime over one worker: open/generate/reset/close + streaming bridge | +| `openai_transcript.py` | OpenAI token-ID warm-resume state (fingerprints + sentinel splicing) | | `serving_chat.py` | `/v1/chat/completions` OpenAI adapter (streaming + non-streaming, stop, tools) | | `tool_parsers/` | Hermes/Qwen `` parser only | | `cpp/text_llm_worker.cpp` | the generic C++ worker binary | diff --git a/extension/llm/server/python/openai_transcript.py b/extension/llm/server/python/openai_transcript.py new file mode 100644 index 00000000000..7f128bbad4b --- /dev/null +++ b/extension/llm/server/python/openai_transcript.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI/chat-template transcript state for token-ID warm resume (V2b.1.5). + +This is the OpenAI-adapter-specific glue that makes warm resume work across the +chat template's lossy re-render of prior assistant turns (especially tool calls, +which re-render from parsed structure and don't re-tokenize to what the model +generated). It is NOT generic runtime infrastructure: it knows ChatMessages, +tool_calls, the ChatTemplate, sentinels, and assistant fingerprints. The runtime +(session_runtime) only sees PromptInput. + +Per session we keep one record per assistant turn we produced, in order: +{"fp": fingerprint of the response we returned, "ids": exact generated token ids +| None}. On the next request each prior assistant turn is replaced with a unique +sentinel, the conversation is rendered once, and the rendered text is split on +the sentinels with the stored ids spliced back in -- but only for turns whose +fingerprint matches the incoming message (so an edited/branched history, or a +session reused for another conversation, is never substituted with stale ids) +and whose ids are present (a stop-trimmed turn has None and is left as text). +Everything is backstopped by the worker's exact-token prefix check. +""" + +import hashlib +import json +import re +import uuid +from typing import Optional + +from .chat_template import ChatTemplate +from .protocol import ChatMessage +from .session_runtime import PromptInput + + +class OpenAITranscriptState: + def __init__(self, template: ChatTemplate): + self._template = template + # session_id -> [{"fp": str, "ids": list[int] | None}, ...] (one per + # assistant turn we produced, in order). Cleared on reset/close. + self._turns: dict[str, list[dict]] = {} + + @staticmethod + def _assistant_fingerprint(content, tool_calls) -> str: + """Stable fingerprint of an assistant turn's semantic payload (content + + each tool call's name/arguments; the random call id is ignored). Used to + confirm an incoming assistant message is the one we generated before + splicing its stored token ids.""" + norm = [] + for tc in tool_calls or []: + fn = getattr(tc, "function", None) + if fn is not None: + norm.append([getattr(fn, "name", None), getattr(fn, "arguments", None)]) + elif isinstance(tc, dict): + f = tc.get("function", {}) + norm.append([f.get("name"), f.get("arguments")]) + blob = json.dumps([content or "", norm], sort_keys=True, ensure_ascii=False) + return hashlib.sha1(blob.encode("utf-8")).hexdigest() + + @staticmethod + def _split_on_sentinels(rendered: str, sub: dict[str, list[int]]) -> list[dict]: + """Split `rendered` on the sentinels into alternating {"text"} chunks and + {"ids"} runs (each sentinel -> its stored id list).""" + pattern = re.compile("|".join(re.escape(s) for s in sub)) + segments: list[dict] = [] + pos = 0 + for mobj in pattern.finditer(rendered): + if mobj.start() > pos: + segments.append({"text": rendered[pos : mobj.start()]}) + segments.append({"ids": sub[mobj.group()]}) + pos = mobj.end() + if pos < len(rendered): + segments.append({"text": rendered[pos:]}) + return segments + + def build_prompt_input( + self, + *, + session_id: Optional[str], + messages: list[ChatMessage], + rendered_prompt: str, + tools, + template_kwargs, + ) -> PromptInput: + """Return a PromptInput: token-ID segments when this session has faithful + stored ids for matching prior assistant turns, else the plain rendered + text. Each incoming assistant turn is matched IN ORDER against the stored + records and only spliced when (a) its fingerprint matches what we returned + (else the history diverged -> stop, splice nothing further) and (b) we + kept faithful ids for it (a stop-trimmed turn's None -> rendered as text). + Falls back to text on a sentinel collision or a render that + dropped/duplicated a sentinel.""" + stored = self._turns.get(session_id or "") + if not stored: + return PromptInput(text=rendered_prompt) + positions = [i for i, m in enumerate(messages) if m.role == "assistant"] + splice: dict[int, list[int]] = {} # message index -> exact ids + diverged_at = None + for k, pos in enumerate(positions): + if k >= len(stored): + break + m = messages[pos] + if self._assistant_fingerprint(m.content, m.tool_calls) != stored[k]["fp"]: + diverged_at = k # this stored turn and every later one are stale + break + if stored[k]["ids"] is not None: + splice[pos] = stored[k]["ids"] + if diverged_at is not None: + # Drop the stale tail from the first mismatch so an edited/branched + # earlier turn can't keep shadowing future requests; the matched + # prefix [:diverged_at] is untouched and still splices. We have no + # exact ids for the edited turn itself (the client authored it, we + # didn't generate it), so warm resume for that turn and the ones after + # it stays text until the session is reset/closed. Safe regardless: + # stale ids are never spliced and the worker's exact-token prefix + # check backstops correctness. + del stored[diverged_at:] + if not splice: + return PromptInput(text=rendered_prompt) + token = uuid.uuid4().hex + sentinel_at = {pos: f"<>" for j, pos in enumerate(splice)} + sub = {sentinel_at[pos]: ids for pos, ids in splice.items()} + # A sentinel must not already occur in the rendered output. + if any(s in rendered_prompt for s in sub): + return PromptInput(text=rendered_prompt) + modified = [ + ( + ChatMessage(role="assistant", content=sentinel_at[i]) + if i in sentinel_at + else m + ) + for i, m in enumerate(messages) + ] + rendered = self._template.render( + modified, tools=tools, template_kwargs=template_kwargs + ) + # Each sentinel must survive templating exactly once, else fall back. + if any(rendered.count(s) != 1 for s in sub): + return PromptInput(text=rendered_prompt) + return PromptInput(segments=self._split_on_sentinels(rendered, sub)) + + def record_assistant_turn( + self, + *, + session_id: Optional[str], + content, + tool_calls, + generated_token_ids: list, + prior_turns: int, + ) -> None: + """Record this turn's {fingerprint, exact generated ids} at position + `prior_turns` -- the count of assistant turns in the request this + response answers. Stored records at/after that index are dropped first, so + a regenerated or branched turn under the same session_id replaces stale + records instead of leaving them to shadow future warm-resume hits with a + stale fingerprint. ids is None when the worker omitted them (stop-trimmed + turn) -- recorded as non-resumable but kept for positional alignment.""" + if not session_id: + return + turns = self._turns.setdefault(session_id, []) + del turns[prior_turns:] + turns.append( + { + "fp": self._assistant_fingerprint(content, tool_calls), + "ids": list(generated_token_ids) if generated_token_ids else None, + } + ) + + def reset(self, session_id: str) -> None: + self._turns.pop(session_id, None) + + def close(self, session_id: str) -> None: + self._turns.pop(session_id, None) diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py index 3b552228980..da83990fe58 100644 --- a/extension/llm/server/python/serving_chat.py +++ b/extension/llm/server/python/serving_chat.py @@ -22,6 +22,7 @@ InvalidSessionId, SessionCapacity, ) +from .openai_transcript import OpenAITranscriptState from .protocol import ( _new_id, ChatCompletionChunk, @@ -73,6 +74,10 @@ def __init__( # Special tokens (e.g. <|im_end|>) the runner decodes to text; we cut the # visible content at the first one so they don't leak into responses. self._stops = template.special_tokens() + # OpenAI/chat-template token-ID warm-resume state (V2b.1.5). Adapter-side, + # not runtime; kept in lockstep with the worker's session state by + # clearing both on reset/close. + self._transcript = OpenAITranscriptState(template) @staticmethod def _tool_schemas(req: ChatCompletionRequest) -> dict[str, dict]: @@ -217,18 +222,24 @@ async def _preflight_session(self, session_id: str) -> None: raise GenerationError(str(e)) async def close_session(self, session_id: str) -> None: + # Lockstep: do the fallible worker op FIRST, then clear the (best-effort, + # can't-fail) transcript. If the worker op fails both retain old state, + # so they never drift. self._validate_session_id(session_id) try: await self._runtime.close(session_id) except WorkerError as e: raise GenerationError(str(e)) + self._transcript.close(session_id) async def reset_session(self, session_id: str) -> None: + # Lockstep: worker op first (fallible), then clear the transcript. self._validate_session_id(session_id) try: await self._runtime.reset(session_id) except WorkerError as e: raise GenerationError(str(e)) + self._transcript.reset(session_id) def _finish_reason( self, @@ -321,6 +332,24 @@ def _reject_unsupported_params(req: ChatCompletionRequest) -> None: "unsupported_parameter", ) + def _count_prompt_tokens(self, prompt: PromptInput) -> Optional[int]: + """Token count of what the worker will actually assemble: the rendered + text, or for token-ID segments sum(len(ids)) for {ids} runs + the + tokenized length of {text} chunks. None when no tokenizer is available to + count text (the worker still enforces the real context limit).""" + if prompt.text is not None: + return self._template.count_tokens(prompt.text) + total = 0 + for seg in prompt.segments: + if "ids" in seg: + total += len(seg["ids"]) + else: + c = self._template.count_tokens(seg["text"]) + if c is None: + return None + total += c + return total + async def create(self, req: ChatCompletionRequest): self._reject_invalid_values(req) self._reject_unsupported_params(req) @@ -333,10 +362,24 @@ async def create(self, req: ChatCompletionRequest): prompt = self._template.render( req.messages, tools=template_tools, template_kwargs=req.chat_template_kwargs ) - # Pre-flight context check: reject cleanly instead of failing mid-generation - # (only possible when a tokenizer is available to count, e.g. --hf-tokenizer). + # Build the prompt input first: token-ID segments (V2b.1.5) splice this + # session's prior assistant turns' exact ids so warm resume stays exact + # across the chat template's lossy re-render of tool-call turns; plain + # rendered text when there's nothing to splice / on any ambiguity (the + # worker verifies the exact-token prefix regardless). + prompt_input = self._transcript.build_prompt_input( + session_id=req.session_id, + messages=req.messages, + rendered_prompt=prompt, + tools=template_tools, + template_kwargs=req.chat_template_kwargs, + ) + # Pre-flight context check against the tokens the worker will actually + # assemble: for segments that is sum(len(ids)) + tokenized text, not the + # rendered string, so a near-limit prompt agrees with the worker rather + # than false-400ing or failing mid-decode. Only when a tokenizer exists. if self._max_context: - count = self._template.count_tokens(prompt) + count = self._count_prompt_tokens(prompt_input) if count is not None: if count >= self._max_context: raise ContextLengthExceeded(count, self._max_context) @@ -347,7 +390,6 @@ async def create(self, req: ChatCompletionRequest): if requested > 0 and count + requested > self._max_context: raise ContextLengthExceeded(count, self._max_context, requested) options = self._options(req) - prompt_input = PromptInput(text=prompt) # Admit the session up front (before the stream's first chunk) so a # capacity refusal is an HTTP status, not a mid-stream error event. if req.session_id is not None: @@ -375,6 +417,16 @@ async def _complete( # Bound the raw output at the first stop/special token BEFORE tool # parsing, so a call after the stop boundary is not parsed/emitted. tool_calls, content = self._extract_tools(req, self._truncate_raw(text, req)) + # Record after the response is finalized: the fingerprint is of exactly + # what we return (content + tool_calls), so the next turn can confirm the + # client echoed this turn before splicing its ids. + self._transcript.record_assistant_turn( + session_id=req.session_id, + content=content, + tool_calls=tool_calls, + generated_token_ids=stats.generated_token_ids, + prior_turns=sum(1 for m in req.messages if m.role == "assistant"), + ) finish = self._finish_reason( req, stats.completion_tokens, tool_calls, stopped, stats.finish_reason ) @@ -439,6 +491,7 @@ def on_stop(): stop_hit[0] = True self._runtime.stop() + streamed: list[str] = [] async for token in self._clean( self._runtime.generate_stream( req.session_id, prompt, options, stats @@ -446,7 +499,9 @@ def on_stop(): stops, on_stop=on_stop, ): + streamed.append(token) yield chunk(DeltaMessage(content=token)) + content = "".join(streamed) # for the session fingerprint except ( Exception ) as e: # noqa: BLE001 - emit a structured error event, never drop the socket @@ -461,6 +516,13 @@ def on_stop(): yield f"data: {json.dumps({'error': err})}\n\n" yield "data: [DONE]\n\n" return + self._transcript.record_assistant_turn( + session_id=req.session_id, + content=content, + tool_calls=tool_calls, + generated_token_ids=stats.generated_token_ids, + prior_turns=sum(1 for m in req.messages if m.role == "assistant"), + ) if use_tools: if content: diff --git a/extension/llm/server/python/tests/test_sessions.py b/extension/llm/server/python/tests/test_sessions.py index c8a9cda3c57..6324edff740 100644 --- a/extension/llm/server/python/tests/test_sessions.py +++ b/extension/llm/server/python/tests/test_sessions.py @@ -139,6 +139,96 @@ def test_session_header_precedence(make_client): assert fake.opened_log == ["xet"] +def _chat_msgs(client, messages, session_id): + return client.post( + "/v1/chat/completions", + json={"model": "test-model", "session_id": session_id, "messages": messages}, + ) + + +# The fake worker streams tokens ("Hello", ", ", "world"), so the assistant +# content we return (and the client must echo back to match the fingerprint) is: +_FAKE_REPLY = "Hello, world" + + +def test_token_id_segments_splice_prior_assistant_turn(make_client): + # V2b.1.5: the server stores turn-1's generated ids and, on turn 2, sends + # prompt_segments that splice them back as an exact {ids} run (not text) -- + # but only because the client echoes back the assistant turn we generated. + client, fake = make_client(max_named_sessions=2, gen_ids=[7, 8, 9]) + assert ( + _chat_msgs(client, [{"role": "user", "content": "hi"}], "s").status_code == 200 + ) + # First turn has no prior assistant turn -> plain text prompt. + assert fake.captured_config.prompt_segments is None + + turn2 = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": _FAKE_REPLY}, # matches what we returned + {"role": "user", "content": "more"}, + ] + assert _chat_msgs(client, turn2, "s").status_code == 200 + segs = fake.captured_config.prompt_segments + assert segs is not None, "expected token-ID segments on the second turn" + # The stored generated ids are spliced in as an exact id run... + assert any(s.get("ids") == [7, 8, 9] for s in segs) + # ...bracketed by text segments (template glue + the new user turn). + assert any("text" in s for s in segs) + + +def test_edited_assistant_turn_not_spliced(make_client): + # P1 guard: if the client edits a prior assistant turn (or reuses the session + # for a different conversation), the stale ids must NOT be spliced -- the + # fingerprint mismatches and we fall back to text. + client, fake = make_client(max_named_sessions=2, gen_ids=[7, 8, 9]) + _chat_msgs(client, [{"role": "user", "content": "hi"}], "s") + turn2 = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "EDITED - not what the model generated"}, + {"role": "user", "content": "more"}, + ] + assert _chat_msgs(client, turn2, "s").status_code == 200 + assert fake.captured_config.prompt_segments is None + + +def test_stop_trimmed_turn_not_spliced(make_client): + # P1/P2 guard: a stop-trimmed turn (worker omits generated_token_ids -> + # recorded ids=None) is never spliced, even when the turn fingerprint matches, + # so unseen post-stop tokens can't be injected into a later prompt. + client, fake = make_client(max_named_sessions=2, gen_ids=[]) # [] => ids None + _chat_msgs(client, [{"role": "user", "content": "hi"}], "s") + turn2 = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": _FAKE_REPLY}, # fingerprint matches + {"role": "user", "content": "more"}, + ] + assert _chat_msgs(client, turn2, "s").status_code == 200 + assert fake.captured_config.prompt_segments is None + + +def test_no_segments_for_anonymous_requests(make_client): + client, fake = make_client(max_named_sessions=2, gen_ids=[1, 2]) + client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert fake.captured_config.prompt_segments is None + + +def test_reset_clears_stored_ids(make_client): + # After reset, the next turn has no stored ids to splice -> plain text again. + client, fake = make_client(max_named_sessions=2, gen_ids=[5, 6]) + _chat_msgs(client, [{"role": "user", "content": "hi"}], "s") + client.post("/v1/sessions/s/reset") + turn2 = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "reply"}, + {"role": "user", "content": "more"}, + ] + _chat_msgs(client, turn2, "s") + assert fake.captured_config.prompt_segments is None + + def test_reset_endpoint_clears_context_but_keeps_slot(make_client): # max_named=1: open "a", reset it, then a *different* id must still 429 — # proving reset cleared context without freeing the slot (unlike DELETE). @@ -156,3 +246,116 @@ def test_reset_invalid_session_id_rejected(make_client): r = client.post("/v1/sessions/has%20space/reset") assert r.status_code == 400 assert r.json()["error"]["code"] == "invalid_session_id" + + +class _RaisingRuntime: + """Runtime whose worker ops fail, to exercise the lockstep invariant.""" + + async def open(self, sid): + pass + + async def reset(self, sid): + raise WorkerError("worker down") + + async def close(self, sid): + raise WorkerError("worker down") + + +@pytest.mark.parametrize("op", ["reset_session", "close_session"]) +def test_worker_op_failure_keeps_transcript(op): + # Lockstep invariant: if the worker reset/close fails, the adapter transcript + # must NOT be cleared -- both retain old state so they never drift. + template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + serving = ServingChat(_RaisingRuntime(), template, "test-model") + serving._transcript.record_assistant_turn( + session_id="s", + content="hi", + tool_calls=None, + generated_token_ids=[1, 2], + prior_turns=0, + ) + + async def go(): + with pytest.raises(GenerationError): + await getattr(serving, op)("s") + + asyncio.run(go()) + assert serving._transcript._turns.get( + "s" + ), "transcript cleared despite worker failure" + + +def test_record_assistant_turn_replaces_stale_at_position(): + # A regenerated/branched turn under the same session_id must REPLACE the + # record at its position (prior_turns), not append, so a later turn can still + # splice the regenerated ids instead of breaking on a stale fingerprint. + from executorch.extension.llm.server.python.openai_transcript import ( + OpenAITranscriptState, + ) + + t = OpenAITranscriptState(ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)) + t.record_assistant_turn( + session_id="s", + content="a0", + tool_calls=None, + generated_token_ids=[1], + prior_turns=0, + ) + t.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[2], + prior_turns=1, + ) + assert [r["ids"] for r in t._turns["s"]] == [[1], [2]] + # regenerate turn 2 (same prior_turns) -> replaces stale [2], no stale tail + t.record_assistant_turn( + session_id="s", + content="a1b", + tool_calls=None, + generated_token_ids=[3], + prior_turns=1, + ) + assert [r["ids"] for r in t._turns["s"]] == [[1], [3]] + + +def test_divergence_truncates_stale_tail(): + # Editing an EARLIER assistant turn (divergence at k) prunes the stale tail + # from k so it can't keep shadowing future requests; nothing is spliced and + # the matched prefix is kept. (Restoring hits for the edited turn isn't + # possible -- we never generated its ids -- but staleness is bounded.) + from executorch.extension.llm.server.python.openai_transcript import ( + OpenAITranscriptState, + ) + from executorch.extension.llm.server.python.protocol import ChatMessage + + t = OpenAITranscriptState(ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)) + t.record_assistant_turn( + session_id="s", + content="a0", + tool_calls=None, + generated_token_ids=[1], + prior_turns=0, + ) + t.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[2], + prior_turns=1, + ) + msgs = [ + ChatMessage(role="user", content="u0"), + ChatMessage(role="assistant", content="a0-EDITED"), + ChatMessage(role="user", content="u1"), + ] + out = t.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt="X", + tools=None, + template_kwargs=None, + ) + assert out.text == "X" # diverged -> plain text fallback + assert t._turns["s"] == [] # stale tail pruned from the first mismatch From a120263f4fab713272df4f6bcb38db91748e2fe4 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Tue, 9 Jun 2026 15:43:23 -0700 Subject: [PATCH 2/2] [UPDATE] Update [ghstack-poisoned] --- extension/llm/server/python/chat_template.py | 33 ++ .../llm/server/python/openai_transcript.py | 112 +++- extension/llm/server/python/serving_chat.py | 13 +- .../python/tests/test_warm_resume_scaffold.py | 509 ++++++++++++++++++ 4 files changed, 647 insertions(+), 20 deletions(-) create mode 100644 extension/llm/server/python/tests/test_warm_resume_scaffold.py diff --git a/extension/llm/server/python/chat_template.py b/extension/llm/server/python/chat_template.py index cbb3eff80bf..7f1c8f386d9 100644 --- a/extension/llm/server/python/chat_template.py +++ b/extension/llm/server/python/chat_template.py @@ -91,6 +91,9 @@ def __init__( # Server-level defaults (e.g. {"enable_thinking": False}); per-request # chat_template_kwargs override these. self._defaults = default_template_kwargs or {} + # Cache of the (deterministic) generation scaffold per resolved mode, so + # warm-resume bookkeeping doesn't re-render a probe prompt every request. + self._preamble_cache: dict[tuple, str] = {} self._hf = None if hf_tokenizer_path: from transformers import AutoTokenizer @@ -136,6 +139,36 @@ def render( ) return self._fallback(messages) + def generation_preamble( + self, template_kwargs: Optional[dict[str, Any]] = None + ) -> str: + """The deterministic text the generation prompt appends after the final + ``<|im_start|>assistant\\n`` for this mode (Qwen3 no-think: + ``\\n\\n\\n\\n``; think: ``\\n``; ``""`` for + templates that add no scaffold). The worker prefills this into resident + KV, so warm-resume splicing must reproduce it ahead of a turn's generated + ids. Computed by rendering a trivial prompt with the same mode resolution + as :meth:`render` and taking the text after the final assistant header. + Returns ``""`` for the fallback / no-scaffold templates (fix is a no-op). + """ + if self._hf is None: + return "" + merged = {**self._defaults, **(template_kwargs or {})} + key = tuple(sorted((k, repr(v)) for k, v in merged.items())) + cached = self._preamble_cache.get(key) + if cached is not None: + return cached + rendered = self.render( + [ChatMessage(role="user", content="")], + tools=None, + template_kwargs=template_kwargs, + ) + marker = "<|im_start|>assistant\n" + idx = rendered.rfind(marker) + preamble = rendered[idx + len(marker) :] if idx != -1 else "" + self._preamble_cache[key] = preamble + return preamble + def chat_template_str(self) -> Optional[str]: """Raw chat-template string (for tool-format auto-detection), if available.""" return ( diff --git a/extension/llm/server/python/openai_transcript.py b/extension/llm/server/python/openai_transcript.py index 10e305fdcae..2eaff5fd7f0 100644 --- a/extension/llm/server/python/openai_transcript.py +++ b/extension/llm/server/python/openai_transcript.py @@ -34,6 +34,30 @@ from .protocol import ChatMessage from .session_runtime import PromptInput +# The assistant header that precedes a turn's generation scaffold + content. +_ASSIST_HDR = "<|im_start|>assistant\n" +# A scaffold region is exactly empty (history strips it before the last user) or +# one of the Qwen3 think scaffolds (history preserves the empty block after the +# last user; the open form is the think-mode generation preamble). Anything else +# in that region is unrecognized -> the splice falls back to plain text. +_THINK_SCAFFOLD_RE = re.compile(r"\A(?:\n\n\n\n|\n)?\Z") + + +def _normalize_tool_args(args): + """OpenAI tool-call ``arguments`` are JSON strings a client may reserialize + with different whitespace or key order while preserving the same value (e.g. + a server-emitted ``{"command": "x"}`` echoed back compact as + ``{"command":"x"}``). Parse to an object so the fingerprint compares the + semantic payload, not bytes -- the outer sort_keys dump then canonicalizes + it. A non-JSON string (or already-structured args) is returned unchanged, so + it stays byte-sensitive.""" + if isinstance(args, str): + try: + return json.loads(args) + except (ValueError, TypeError): + return args + return args + class OpenAITranscriptState: def __init__(self, template: ChatTemplate): @@ -52,24 +76,64 @@ def _assistant_fingerprint(content, tool_calls) -> str: for tc in tool_calls or []: fn = getattr(tc, "function", None) if fn is not None: - norm.append([getattr(fn, "name", None), getattr(fn, "arguments", None)]) + name, args = getattr(fn, "name", None), getattr(fn, "arguments", None) elif isinstance(tc, dict): f = tc.get("function", {}) - norm.append([f.get("name"), f.get("arguments")]) + name, args = f.get("name"), f.get("arguments") + else: + continue + norm.append([name, _normalize_tool_args(args)]) blob = json.dumps([content or "", norm], sort_keys=True, ensure_ascii=False) return hashlib.sha1(blob.encode("utf-8")).hexdigest() @staticmethod - def _split_on_sentinels(rendered: str, sub: dict[str, list[int]]) -> list[dict]: + def _normalize_scaffold(text_chunk: str, preamble: str) -> Optional[str]: + """Force the scaffold region -- the text between the last assistant header + in `text_chunk` and its end -- to equal `preamble`, so the worker + re-tokenizes the exact generation scaffold it made resident for this turn. + The region (the content was replaced by a sentinel) is empty when history + stripped the scaffold (insert) or a think scaffold when history preserved + it (replace, possibly with a different form than `preamble`). Returns the + adjusted text, or None if the region is not a recognized scaffold + (ambiguous -> caller falls back to plain text).""" + # No scaffold for this turn's mode/template: nothing to reproduce, so + # leave the chunk untouched -- and don't require the Qwen/ChatML header, + # so token-id splicing still works for templates with a different + # assistant header (the fix stays a true no-op for non-think models). + if not preamble: + return text_chunk + h = text_chunk.rfind(_ASSIST_HDR) + if h == -1: + return None + base = h + len(_ASSIST_HDR) + region = text_chunk[base:] + if region == preamble: + return text_chunk + if not _THINK_SCAFFOLD_RE.match(region): + return None + return text_chunk[:base] + preamble + + @staticmethod + def _split_on_sentinels( + rendered: str, sub: dict[str, dict] + ) -> Optional[list[dict]]: """Split `rendered` on the sentinels into alternating {"text"} chunks and - {"ids"} runs (each sentinel -> its stored id list).""" + {"ids"} runs (each sentinel -> sub[sentinel] = {"ids", "preamble"}). The + {text} chunk before each {ids} run has its assistant scaffold normalized + to that turn's stored preamble. Returns None if any pre-sentinel scaffold + region is ambiguous (caller falls back to plain text).""" pattern = re.compile("|".join(re.escape(s) for s in sub)) segments: list[dict] = [] pos = 0 for mobj in pattern.finditer(rendered): - if mobj.start() > pos: - segments.append({"text": rendered[pos : mobj.start()]}) - segments.append({"ids": sub[mobj.group()]}) + norm = OpenAITranscriptState._normalize_scaffold( + rendered[pos : mobj.start()], sub[mobj.group()]["preamble"] + ) + if norm is None: + return None + if norm: + segments.append({"text": norm}) + segments.append({"ids": sub[mobj.group()]["ids"]}) pos = mobj.end() if pos < len(rendered): segments.append({"text": rendered[pos:]}) @@ -104,7 +168,7 @@ def build_prompt_input( # worker exact-prefix backstop); it only lowers the warm-resume hit rate, # silently, for such conversations. positions = [i for i, m in enumerate(messages) if m.role == "assistant"] - splice: dict[int, list[int]] = {} # message index -> exact ids + splice: dict[int, dict] = {} # message index -> {"ids", "preamble"} diverged_at = None for k, pos in enumerate(positions): if k >= len(stored): @@ -114,7 +178,10 @@ def build_prompt_input( diverged_at = k # this stored turn and every later one are stale break if stored[k]["ids"] is not None: - splice[pos] = stored[k]["ids"] + splice[pos] = { + "ids": stored[k]["ids"], + "preamble": stored[k].get("preamble", ""), + } if diverged_at is not None: # Drop the stale tail from the first mismatch so an edited/branched # earlier turn can't keep shadowing future requests; the matched @@ -129,7 +196,7 @@ def build_prompt_input( return PromptInput(text=rendered_prompt) token = uuid.uuid4().hex sentinel_at = {pos: f"<>" for j, pos in enumerate(splice)} - sub = {sentinel_at[pos]: ids for pos, ids in splice.items()} + sub = {sentinel_at[pos]: splice[pos] for pos in splice} # A sentinel must not already occur in the rendered output. if any(s in rendered_prompt for s in sub): return PromptInput(text=rendered_prompt) @@ -147,7 +214,11 @@ def build_prompt_input( # Each sentinel must survive templating exactly once, else fall back. if any(rendered.count(s) != 1 for s in sub): return PromptInput(text=rendered_prompt) - return PromptInput(segments=self._split_on_sentinels(rendered, sub)) + # Splice ids and normalize each turn's scaffold; None => ambiguous region. + segments = self._split_on_sentinels(rendered, sub) + if segments is None: + return PromptInput(text=rendered_prompt) + return PromptInput(segments=segments) def record_assistant_turn( self, @@ -157,14 +228,18 @@ def record_assistant_turn( tool_calls, generated_token_ids: list, prior_turns: int, + preamble: str = "", ) -> None: - """Record this turn's {fingerprint, exact generated ids} at position - `prior_turns` -- the count of assistant turns in the request this - response answers. Stored records at/after that index are dropped first, so - a regenerated or branched turn under the same session_id replaces stale - records instead of leaving them to shadow future warm-resume hits with a - stale fingerprint. ids is None when the worker omitted them (stop-trimmed - turn) -- recorded as non-resumable but kept for positional alignment.""" + """Record this turn's {fingerprint, exact generated ids, generation + preamble} at position `prior_turns` -- the count of assistant turns in the + request this response answers. Stored records at/after that index are + dropped first, so a regenerated or branched turn under the same session_id + replaces stale records instead of leaving them to shadow future + warm-resume hits with a stale fingerprint. ids is None when the worker + omitted them (stop-trimmed turn) -- recorded as non-resumable but kept for + positional alignment. `preamble` is the generation scaffold resident ahead + of these ids (mode-specific, e.g. Qwen3 `` block), reproduced ahead + of the spliced ids on the next request so the prefix stays exact.""" if not session_id: return turns = self._turns.setdefault(session_id, []) @@ -173,6 +248,7 @@ def record_assistant_turn( { "fp": self._assistant_fingerprint(content, tool_calls), "ids": list(generated_token_ids) if generated_token_ids else None, + "preamble": preamble, } ) diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py index 5e0ed49f51a..1017a35dce1 100644 --- a/extension/llm/server/python/serving_chat.py +++ b/extension/llm/server/python/serving_chat.py @@ -399,19 +399,25 @@ async def create(self, req: ChatCompletionRequest): if requested > 0 and count + requested > self._max_context: raise ContextLengthExceeded(count, self._max_context, requested) options = self._options(req) + # The generation scaffold the worker will prefill ahead of this turn's + # tokens (e.g. Qwen3 block), resolved with the same per-request + # mode as the render; recorded per turn so warm-resume splicing reproduces + # the exact resident scaffold even if the mode changes between requests. + preamble = self._template.generation_preamble(req.chat_template_kwargs) # Admit the session up front (before the stream's first chunk) so a # capacity refusal is an HTTP status, not a mid-stream error event. if req.session_id is not None: await self._preflight_session(req.session_id) if req.stream: - return self._stream(req, prompt_input, options) - return await self._complete(req, prompt_input, options) + return self._stream(req, prompt_input, options, preamble) + return await self._complete(req, prompt_input, options, preamble) async def _complete( self, req: ChatCompletionRequest, prompt: PromptInput, options: GenerationOptions, + preamble: str = "", ) -> ChatCompletionResponse: stats = GenStats() try: @@ -435,6 +441,7 @@ async def _complete( tool_calls=tool_calls, generated_token_ids=stats.generated_token_ids, prior_turns=sum(1 for m in req.messages if m.role == "assistant"), + preamble=preamble, ) finish = self._finish_reason( req, stats.completion_tokens, tool_calls, stopped, stats.finish_reason @@ -459,6 +466,7 @@ async def _stream( req: ChatCompletionRequest, prompt: PromptInput, options: GenerationOptions, + preamble: str = "", ) -> AsyncIterator[str]: cid = _new_id("chatcmpl") @@ -531,6 +539,7 @@ def on_stop(): tool_calls=tool_calls, generated_token_ids=stats.generated_token_ids, prior_turns=sum(1 for m in req.messages if m.role == "assistant"), + preamble=preamble, ) if use_tools: diff --git a/extension/llm/server/python/tests/test_warm_resume_scaffold.py b/extension/llm/server/python/tests/test_warm_resume_scaffold.py new file mode 100644 index 00000000000..f58b481842e --- /dev/null +++ b/extension/llm/server/python/tests/test_warm_resume_scaffold.py @@ -0,0 +1,509 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Warm-resume generation-scaffold reproduction (V2b.1.5). + +Qwen3's template prefills a deterministic ```` scaffold into the +generation prompt (so it lands in resident KV) but strips it when re-rendering a +turn as history *before* the last user message, while *preserving* it (as the +empty block) for turns after. The token-ID splice must reproduce each turn's +exact resident scaffold ahead of its generated ids, normalizing whatever the +history render put there -- inserting when stripped, replacing when a different +form was preserved -- so the worker's exact-token prefix check lands. +""" + +import os + +import pytest + +from executorch.extension.llm.server.python.openai_transcript import ( + OpenAITranscriptState, +) +from executorch.extension.llm.server.python.protocol import ( + ChatMessage, + FunctionCall, + ToolCall, +) + +HDR = "<|im_start|>assistant\n" +NOTHINK = "\n\n\n\n" # no-think generation preamble / preserved block +THINK = "\n" # think-mode generation preamble + + +def _msgs(*pairs): + return [ChatMessage(role=r, content=c) for r, c in pairs] + + +class _FakeQwen: + """Mimics Qwen3 scaffold behavior in render(): the generation prompt appends + the mode scaffold after the assistant header; history strips the scaffold for + assistant turns before the last user message and preserves the empty block + for turns after it (true in both modes -- the case that needs normalize).""" + + def __init__(self, default_thinking=False): + self._default_thinking = default_thinking + + def _gen(self, kw): + thinking = (kw or {}).get("enable_thinking", self._default_thinking) + return THINK if thinking else NOTHINK + + def render(self, messages, tools=None, template_kwargs=None): + last_user = max( + (i for i, m in enumerate(messages) if m.role == "user"), default=-1 + ) + out = [] + for i, m in enumerate(messages): + c = m.content if isinstance(m.content, str) else "" + if m.role == "assistant" and i > last_user: + out.append(f"{HDR}{NOTHINK}{c}<|im_end|>\n") # preserved empty block + else: + out.append(f"<|im_start|>{m.role}\n{c}<|im_end|>\n") + out.append(HDR + self._gen(template_kwargs)) + return "".join(out) + + +class _FakePlain: + """No-scaffold ChatML template (preamble '').""" + + def render(self, messages, tools=None, template_kwargs=None): + out = [ + f"<|im_start|>{m.role}\n" + f"{m.content if isinstance(m.content, str) else ''}<|im_end|>\n" + for m in messages + ] + out.append(HDR) + return "".join(out) + + +class _FakeOtherHeader: + """No-scaffold template whose assistant header is NOT the Qwen/ChatML one + (Llama-style), to prove token-id splicing isn't disabled for templates that + don't use ``<|im_start|>assistant\\n`` when the preamble is ''.""" + + OHDR = "<|start_header_id|>assistant<|end_header_id|>\n\n" + + def render(self, messages, tools=None, template_kwargs=None): + out = [] + for m in messages: + c = m.content if isinstance(m.content, str) else "" + if m.role == "assistant": + out.append(f"{self.OHDR}{c}<|eot_id|>") + else: + out.append( + f"<|start_header_id|>{m.role}<|end_header_id|>\n\n{c}<|eot_id|>" + ) + out.append(self.OHDR) + return "".join(out) + + +def _ids_index(segs, ids): + for i, s in enumerate(segs): + if s.get("ids") == ids: + return i + return -1 + + +def _text_before_ids(segs, ids): + i = _ids_index(segs, ids) + assert i > 0 and "text" in segs[i - 1], "expected a {text} segment before {ids}" + return segs[i - 1]["text"] + + +def _scaffold_before(segs, ids): + """The scaffold region: text after the last assistant header preceding ids.""" + return _text_before_ids(segs, ids).rsplit(HDR, 1)[-1] + + +# --- 5a. Hermetic unit tests (no model) ------------------------------------- + + +def test_nothink_ordinary_append_inserts_scaffold(): + st = OpenAITranscriptState(_FakeQwen(default_thinking=False)) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[10, 11, 12], + prior_turns=0, + preamble=NOTHINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + kw = {"enable_thinking": False} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + # History stripped the scaffold; the fix inserts exactly one copy. + assert _scaffold_before(pi.segments, [10, 11, 12]) == NOTHINK + assert _text_before_ids(pi.segments, [10, 11, 12]).count(NOTHINK) == 1 + + +def test_think_ordinary_append_inserts_open_scaffold(): + st = OpenAITranscriptState(_FakeQwen(default_thinking=True)) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[1, 2], + prior_turns=0, + preamble=THINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + kw = {"enable_thinking": True} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assert _scaffold_before(pi.segments, [1, 2]) == THINK + + +def test_think_toolloop_normalizes_preserved_scaffold(): + # Turn generated in THINK mode (preamble open-think) but rendered as a + # post-last-user turn, where history preserves the *empty* block. The fix + # must REPLACE that block with the stored open-think preamble -- not keep it + # (wrong scaffold) and not append a second one (double-insert). + st = OpenAITranscriptState(_FakeQwen(default_thinking=True)) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[7, 8, 9], + prior_turns=0, + preamble=THINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1")) # a1 AFTER last user + kw = {"enable_thinking": True} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assert _scaffold_before(pi.segments, [7, 8, 9]) == THINK + # the preserved empty block was replaced, not kept and not doubled + assert NOTHINK not in _text_before_ids(pi.segments, [7, 8, 9]) + + +def test_no_scaffold_template_is_unchanged(): + st = OpenAITranscriptState(_FakePlain()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[5], + prior_turns=0, + preamble="", + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs), + tools=None, + template_kwargs=None, + ) + assert pi.segments is not None + assert _scaffold_before(pi.segments, [5]) == "" # nothing inserted, no regression + + +def test_non_qwen_header_no_scaffold_still_splices(): + # Regression: a no-scaffold template whose assistant header isn't the + # Qwen/ChatML one must still get token-id splicing (the normalization is a + # no-op when preamble == "", not a hard requirement for the Qwen header). + st = OpenAITranscriptState(_FakeOtherHeader()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[9, 9], + prior_turns=0, + preamble="", + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs), + tools=None, + template_kwargs=None, + ) + assert pi.segments is not None # splicing NOT disabled by the missing header + assert any(s.get("ids") == [9, 9] for s in pi.segments) # ids actually spliced + + +def test_stop_trimmed_turn_falls_back_to_text(): + st = OpenAITranscriptState(_FakeQwen()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[], # stop-trimmed -> ids None -> not resumable + prior_turns=0, + preamble=NOTHINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + kw = {"enable_thinking": False} + rendered = st._template.render(msgs, template_kwargs=kw) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=rendered, + tools=None, + template_kwargs=kw, + ) + assert pi.segments is None and pi.text == rendered + + +def test_fingerprint_mismatch_falls_back_to_text(): + st = OpenAITranscriptState(_FakeQwen()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[10], + prior_turns=0, + preamble=NOTHINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "EDITED"), ("user", "u2")) + kw = {"enable_thinking": False} + rendered = st._template.render(msgs, template_kwargs=kw) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=rendered, + tools=None, + template_kwargs=kw, + ) + assert pi.segments is None and pi.text == rendered + + +def test_mode_switch_uses_per_turn_scaffold(): + st = OpenAITranscriptState(_FakeQwen()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[1], + prior_turns=0, + preamble=NOTHINK, # turn 1 generated no-think + ) + st.record_assistant_turn( + session_id="s", + content="a2", + tool_calls=None, + generated_token_ids=[2], + prior_turns=1, + preamble=THINK, # turn 2 generated think + ) + msgs = _msgs( + ("user", "u1"), + ("assistant", "a1"), + ("user", "u2"), + ("assistant", "a2"), + ("user", "u3"), + ) + kw = {"enable_thinking": True} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assert _scaffold_before(pi.segments, [1]) == NOTHINK + assert _scaffold_before(pi.segments, [2]) == THINK + + +# --- Tool-call argument fingerprint canonicalization ------------------------ + + +def _fp(content, tool_calls): + return OpenAITranscriptState._assistant_fingerprint(content, tool_calls) + + +def _dtc(name, args): + return {"function": {"name": name, "arguments": args}} + + +def test_fingerprint_ignores_tool_arg_whitespace(): + assert _fp(None, [_dtc("bash", '{"command": "echo hi"}')]) == _fp( + None, [_dtc("bash", '{"command":"echo hi"}')] + ) + + +def test_fingerprint_ignores_tool_arg_key_order(): + assert _fp(None, [_dtc("f", '{"x": 1, "y": 2}')]) == _fp( + None, [_dtc("f", '{"y": 2, "x": 1}')] + ) + + +def test_fingerprint_invalid_json_args_stay_byte_sensitive(): + # Non-JSON arguments can't be canonicalized, so they stay literal: a + # genuinely different string remains a different turn. + assert _fp(None, [_dtc("f", "not json {")]) != _fp( + None, [_dtc("f", "not json { ")] + ) + + +def test_fingerprint_non_string_args_match_equivalent_json_string(): + # Already-structured args hash stably and match the equivalent JSON string. + assert _fp(None, [_dtc("f", {"x": 1})]) == _fp(None, [_dtc("f", '{"x": 1}')]) + + +def test_tool_turn_splices_despite_reserialized_args(): + # End-to-end: the server recorded a spaced arguments string; the client echoes + # the same call back compact (the real pi behavior). The turn must still + # fingerprint-match and splice -- not prune to a text fallback. + st = OpenAITranscriptState(_FakeQwen()) + st.record_assistant_turn( + session_id="s", + content=None, + tool_calls=[ + ToolCall( + index=0, + id="c1", + type="function", + function=FunctionCall(name="bash", arguments='{"command": "echo hi"}'), + ) + ], + generated_token_ids=[1, 2, 3], + prior_turns=0, + preamble=NOTHINK, + ) + echoed = ChatMessage( + role="assistant", + content=None, + tool_calls=[ + ToolCall( + index=0, + id="c1", + type="function", + function=FunctionCall(name="bash", arguments='{"command":"echo hi"}'), + ) + ], + ) + msgs = [ + ChatMessage(role="user", content="u1"), + echoed, + ChatMessage(role="user", content="u2"), + ] + kw = {"enable_thinking": False} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None # matched + spliced, not pruned to text + assert any(s.get("ids") == [1, 2, 3] for s in pi.segments) + + +# --- 5b. Token-level fidelity against the real tokenizer (gated/skipped) ----- + +_MODEL = os.environ.get( + "QWEN_HF_DIR", "/home/mnachin/local/scripts/models/Qwen3.5-35B-A3B-HQQ-INT4" +) +_HAVE_MODEL = os.path.isdir(_MODEL) +_skip = pytest.mark.skipif( + not _HAVE_MODEL, reason=f"real Qwen tokenizer dir not present: {_MODEL}" +) + + +def _real_template_and_enc(): + pytest.importorskip("transformers") + from executorch.extension.llm.server.python.chat_template import ChatTemplate + from transformers import AutoTokenizer + + tmpl = ChatTemplate(hf_tokenizer_path=_MODEL) + tok = AutoTokenizer.from_pretrained(_MODEL) + # Encode the way the worker does: no extra special tokens (the rendered text + # already contains the literal <|im_*|> / control strings). + return tmpl, (lambda s: tok.encode(s, add_special_tokens=False)) + + +def _assemble(segs, enc): + out = [] + for seg in segs: + out += seg["ids"] if "ids" in seg else enc(seg["text"]) + return out + + +@_skip +@pytest.mark.parametrize("thinking", [False, True]) +def test_token_level_exact_prefix_ordinary(thinking): + tmpl, enc = _real_template_and_enc() + kw = {"enable_thinking": thinking} + st = OpenAITranscriptState(tmpl) + content = "Mercury, Venus, Earth." + gen_ids = enc(content) # stand-in for the worker's generated_token_ids + gen_prompt1 = tmpl.render(_msgs(("user", "u1")), template_kwargs=kw) + resident = enc(gen_prompt1) + gen_ids + st.record_assistant_turn( + session_id="s", + content=content, + tool_calls=None, + generated_token_ids=gen_ids, + prior_turns=0, + preamble=tmpl.generation_preamble(kw), + ) + msgs = _msgs(("user", "u1"), ("assistant", content), ("user", "u2")) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=tmpl.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assembled = _assemble(pi.segments, enc) + # resident is an exact token prefix => plan_prefill returns exact_prefix and + # reuses exactly len(resident) tokens. + assert assembled[: len(resident)] == resident + + +@_skip +def test_token_level_exact_prefix_toolloop_think(): + # Mandatory: post-last-user turn where the template preserves a think block + # before the sentinel; the fix must normalize it to the stored open-think + # preamble so the token prefix still lands. + tmpl, enc = _real_template_and_enc() + kw = {"enable_thinking": True} + st = OpenAITranscriptState(tmpl) + content = "result is 42" + gen_ids = enc(content) + gen_prompt1 = tmpl.render(_msgs(("user", "u1")), template_kwargs=kw) + resident = enc(gen_prompt1) + gen_ids + st.record_assistant_turn( + session_id="s", + content=content, + tool_calls=None, + generated_token_ids=gen_ids, + prior_turns=0, + preamble=tmpl.generation_preamble(kw), + ) + msgs = _msgs(("user", "u1"), ("assistant", content)) # a1 AFTER last user + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=tmpl.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assembled = _assemble(pi.segments, enc) + assert assembled[: len(resident)] == resident