From b930b9f991cd3bab7c7412989e3a2f42a561f27d Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Tue, 9 Jun 2026 09:23:09 -0700 Subject: [PATCH] [INITIAL] Update [ghstack-poisoned] --- examples/models/qwen3_5_moe/README.md | 28 ++- .../models/qwen3_5_moe/qwen35_moe_worker.cpp | 19 +- examples/models/qwen3_5_moe/serve.py | 15 +- examples/models/qwen3_5_moe/test_serve.py | 3 + extension/llm/server/cpp/CMakeLists.txt | 9 + .../server/cpp/test_worker_prefill_plan.cpp | 120 ++++++++++ extension/llm/server/cpp/worker_loop.h | 213 +++++++++++++----- .../llm/server/cpp/worker_prefill_plan.h | 72 ++++++ extension/llm/server/python/server.py | 11 + extension/llm/server/python/serving_chat.py | 7 + .../llm/server/python/tests/test_sessions.py | 19 ++ .../server/python/tests/test_worker_client.py | 31 +++ 12 files changed, 480 insertions(+), 67 deletions(-) create mode 100644 extension/llm/server/cpp/test_worker_prefill_plan.cpp create mode 100644 extension/llm/server/cpp/worker_prefill_plan.h diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index d13961637d5..0583765cb77 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -197,6 +197,7 @@ is safe under asyncio. | `--max-context` | (none) | Reject prompts that exceed it with 400 | | `--no-think` | off | Default reasoning off (`enable_thinking=False`) | | `--max-sessions` | `1` | Isolated sessions on one weight load (see Sessions) | +| `--warm-resume` / `--no-warm-resume` | on | Reuse a session's KV across turns (see Sessions) | ### Sessions @@ -211,24 +212,41 @@ aliases, the `X-ExecuTorch-Session-ID` / `session_id` / `x-session-affinity` headers (body wins, then that header order). The header aliases let a client that already emits a stable per-conversation affinity id (e.g. pi's `sendSessionAffinityHeaders`) route with no extra config. Requests without any -share a transient scratch session. Free a session with `DELETE /v1/sessions/{id}`. +share a transient scratch session. ```bash curl http://127.0.0.1:8000/v1/chat/completions \ -H 'Content-Type: application/json' \ -d '{"model":"qwen3.5-moe","session_id":"alice", "messages":[{"role":"user","content":"hi"}]}' + +curl -X POST http://127.0.0.1:8000/v1/sessions/alice/reset # clear context, keep the slot +curl -X DELETE http://127.0.0.1:8000/v1/sessions/alice # free context + slot (VRAM) ``` Admission is up front: an explicit `session_id` on a single-session server returns **400** (`unsupported_session`); past capacity it returns **429** (`capacity_exhausted`) before any response bytes. -This is **isolation, not concurrency or warm resume**: execution is still +**Warm append-only resume** (on by default): when a named session's next request +is an exact-token extension of its resident context (e.g. the same conversation +plus a new turn), the worker prefills **only the new suffix** instead of +re-prefilling the whole prompt — continuing the KV/recurrent state in place. The +check is exact-token (never re-tokenized text), so it is always correct: anything +that can't be proven an exact extension (token mismatch, a stop-string trim, a +prior error) falls back to a full reset + prefill. This is **per-session** warm +append-only resume, **not** global prefix caching: there is no cross-session +prefix sharing, so a system prompt common to two different `session_id`s is +prefilled independently for each (unlike vLLM/llama.cpp global prefix reuse). +Each `done` event reports +`reused_prompt_tokens`, `prefilled_prompt_tokens`, and `session_reset_reason` +(`new`/`exact_prefix`/`dirty`/`mismatch`/`equal`) for measuring the hit rate. +`--no-warm-resume` forces a full prefill every request (for A/B comparison). + +This is **isolation + warm resume, not concurrency**: execution is still synchronous (one in-flight request; `--num-runners > 1` is rejected since more -workers would duplicate the weights), and each request resets its session — the -recurrent/conv state cannot be rewound by position (`seek()` is NotSupported), so -turn-to-turn KV reuse (append-only warm resume) is a follow-up. +workers would duplicate the weights). Fair interleaving across in-flight requests +is a follow-up. ### Other limitations diff --git a/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp b/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp index c5018031716..ac2e3536a14 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp @@ -18,11 +18,11 @@ // process segfaults in the int4 matmul (validated). Here the model runs in a // plain synchronous loop in its own process, which is reliable. // -// Multi-session (isolation): the engine loads weights once and hosts multiple -// isolated sessions on that one ~18GB allocation; the shared worker loop -// (worker_loop.h) routes requests to per-session_id state, up to -// --max_sessions. Execution is still synchronous (one in-flight request); warm -// context reuse across requests is a follow-up. +// Multi-session: the engine loads weights once and hosts multiple isolated +// sessions on that one ~18GB allocation; the shared worker loop (worker_loop.h) +// routes requests to per-session_id state (up to --max_sessions) and warm- +// resumes each session's context across requests (append-only suffix prefill). +// Execution is synchronous (one in-flight request). #include @@ -41,6 +41,12 @@ DEFINE_int32( "Max physical sessions to host on the one weight allocation (CUDA " "per-session mutable rebinding). Clamped to 1 if the backend cannot " "rebind."); +DEFINE_bool( + warm_resume, + true, + "Warm append-only resume for named sessions: prefill only the suffix when a " + "request's tokens extend the session's resident context. Off resets every " + "request (useful for A/B measurement)."); namespace { namespace llm = ::executorch::extension::llm; @@ -73,5 +79,6 @@ int main(int argc, char** argv) { // ids back to text internally. The shared loop owns per-session_id state. ::tokenizers::Tokenizer* tokenizer = engine->tokenizer(); - return llm::run_worker_stdio_loop(*engine, *tokenizer, engine->metadata()); + return llm::run_worker_stdio_loop( + *engine, *tokenizer, engine->metadata(), FLAGS_warm_resume); } diff --git a/examples/models/qwen3_5_moe/serve.py b/examples/models/qwen3_5_moe/serve.py index e58ab23516b..9075ef8fe17 100644 --- a/examples/models/qwen3_5_moe/serve.py +++ b/examples/models/qwen3_5_moe/serve.py @@ -23,8 +23,10 @@ requests share a scratch session). See --max-sessions. * Execution is synchronous: one in-flight request at a time, concurrent HTTP requests queue. Sessions provide isolation, not concurrent throughput. - * No warm context reuse yet: each request resets its session (Qwen seek() is - NotSupported; append-only reuse is a follow-up). + * Warm append-only resume is on by default (--warm-resume): a named session + reuses its resident context across turns when the prompt is an exact-token + extension, including tool-call turns via token-ID prompt segments. Anonymous + (scratch) requests always reset. * The control plane only does blocking pipe I/O on its executor thread (no CUDA), which is safe under asyncio. @@ -83,6 +85,7 @@ def _spawn(args): if args.data_path: cmd += ["--data_path", args.data_path] cmd += ["--max_sessions", str(args.max_sessions)] + cmd += [f"--warm_resume={'true' if args.warm_resume else 'false'}"] logger.info("Starting Qwen worker subprocess (loads the model once)...") return spawn_worker(cmd, env=env) @@ -162,6 +165,14 @@ def main() -> None: "cannot rebind. One slot is reserved for anonymous requests, so the " "number of addressable session_ids is max-sessions - 1.", ) + p.add_argument( + "--warm-resume", + action=argparse.BooleanOptionalAction, + default=True, + help="Warm append-only resume for named sessions: a request whose tokens " + "extend the session's resident context prefills only the suffix. " + "--no-warm-resume resets every request (for A/B measurement).", + ) p.add_argument( "--worker-bin", default=None, diff --git a/examples/models/qwen3_5_moe/test_serve.py b/examples/models/qwen3_5_moe/test_serve.py index f8768ef39ce..4445e5867a1 100644 --- a/examples/models/qwen3_5_moe/test_serve.py +++ b/examples/models/qwen3_5_moe/test_serve.py @@ -76,6 +76,7 @@ def fake_spawn(cmd, env=None): tokenizer_path="t.json", data_path="d.ptd", max_sessions=4, + warm_resume=True, ) ) assert captured["cmd"] == [ @@ -88,6 +89,7 @@ def fake_spawn(cmd, env=None): "d.ptd", "--max_sessions", "4", + "--warm_resume=true", ] @@ -103,6 +105,7 @@ def test_spawn_defaults_worker_bin_and_omits_empty_data_path(monkeypatch): tokenizer_path="t.json", data_path=None, max_sessions=4, + warm_resume=True, ) ) cmd = captured["cmd"] diff --git a/extension/llm/server/cpp/CMakeLists.txt b/extension/llm/server/cpp/CMakeLists.txt index 653cf61bea8..18f62cfcd5f 100644 --- a/extension/llm/server/cpp/CMakeLists.txt +++ b/extension/llm/server/cpp/CMakeLists.txt @@ -86,3 +86,12 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options_gc_sections(text_llm_worker) target_link_options(text_llm_worker PRIVATE "LINKER:-s") endif() + +# Pure unit test for the warm-resume prefill planner (worker_prefill_plan.h). No +# ET/model/tokenizer dependency, so it builds and runs standalone via ctest. +enable_testing() +add_executable(test_worker_prefill_plan test_worker_prefill_plan.cpp) +target_include_directories( + test_worker_prefill_plan PUBLIC ${_common_include_directories} +) +add_test(NAME worker_prefill_plan COMMAND test_worker_prefill_plan) diff --git a/extension/llm/server/cpp/test_worker_prefill_plan.cpp b/extension/llm/server/cpp/test_worker_prefill_plan.cpp new file mode 100644 index 00000000000..93ca4b00b6f --- /dev/null +++ b/extension/llm/server/cpp/test_worker_prefill_plan.cpp @@ -0,0 +1,120 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Unit tests for plan_prefill() (warm-resume decision). No model/session/ET +// runtime dependency -- the header is pure, so this compiles and runs +// standalone. Self-contained assertions (no gtest) so it has no build deps. + +#include + +#include +#include +#include +#include + +using executorch::extension::llm::plan_prefill; +using executorch::extension::llm::PrefillPlan; + +namespace { +int g_failures = 0; + +void expect( + const char* name, + const PrefillPlan& p, + PrefillPlan::Action action, + size_t suffix_start, + const char* reason) { + bool ok = p.action == action && p.suffix_start == suffix_start && + std::strcmp(p.reason, reason) == 0; + if (!ok) { + ++g_failures; + printf( + " [FAIL] %s: got action=%d suffix_start=%zu reason=%s\n", + name, + (int)p.action, + p.suffix_start, + p.reason); + } else { + printf(" [PASS] %s\n", name); + } +} +} // namespace + +int main() { + using V = std::vector; + + // First request: nothing resident -> full prefill, "new". + expect( + "new (resident empty)", + plan_prefill(V{}, V{1, 2, 3}, false), + PrefillPlan::kFull, + 0, + "new"); + + // Exact token extension -> prefill only the suffix. + expect( + "exact_prefix (suffix reuse)", + plan_prefill(V{1, 2, 3}, V{1, 2, 3, 4, 5}, false), + PrefillPlan::kSuffix, + 3, + "exact_prefix"); + + // Single-token extension still reuses. + expect( + "exact_prefix (one-token suffix)", + plan_prefill(V{1, 2, 3}, V{1, 2, 3, 4}, false), + PrefillPlan::kSuffix, + 3, + "exact_prefix"); + + // Divergent token -> mismatch, full reset. + expect( + "mismatch (divergent token)", + plan_prefill(V{1, 2, 3}, V{1, 2, 9, 4}, false), + PrefillPlan::kFull, + 0, + "mismatch"); + + // Prompt shorter than resident (rewind) -> mismatch, full reset. + expect( + "mismatch (prompt shorter)", + plan_prefill(V{1, 2, 3}, V{1, 2}, false), + PrefillPlan::kFull, + 0, + "mismatch"); + + // Dirty wins even over an otherwise-exact extension. + expect( + "dirty (overrides exact prefix)", + plan_prefill(V{1, 2, 3}, V{1, 2, 3, 4}, true), + PrefillPlan::kFull, + 0, + "dirty"); + + // Prompt identical to resident -> reset + full (no empty-suffix prefill). + expect( + "equal (prompt == resident)", + plan_prefill(V{1, 2, 3}, V{1, 2, 3}, false), + PrefillPlan::kFull, + 0, + "equal"); + + // Dirty + empty resident still resets as dirty (dirty checked first). + expect( + "dirty (empty resident)", + plan_prefill(V{}, V{1, 2}, true), + PrefillPlan::kFull, + 0, + "dirty"); + + printf( + "\n%s (%d failure(s))\n", + g_failures == 0 ? "ALL PASS" : "FAILED", + g_failures); + return g_failures == 0 ? 0 : 1; +} diff --git a/extension/llm/server/cpp/worker_loop.h b/extension/llm/server/cpp/worker_loop.h index a7ec92b08f2..7f92e60371e 100644 --- a/extension/llm/server/cpp/worker_loop.h +++ b/extension/llm/server/cpp/worker_loop.h @@ -17,18 +17,26 @@ // V2a (isolation): the worker owns one LLMEngine (weights loaded once) and // hands out multiple isolated LLMSessions keyed by session_id, each with its // own KV/recurrent state, up to the engine's serving capacity. Execution is -// still synchronous -- one in-flight request at a time, the control plane -// serializes -- so this proves "one model, many isolated contexts without -// duplicating weights", NOT concurrent streaming. It also does NOT yet reuse -// context across requests: worker_handle_request() resets the session at the -// top of every request (warm append-only resume is a follow-up). +// synchronous -- one in-flight request at a time, the control plane serializes. +// +// V2b.1 (warm append-only resume): a named session keeps its decoded context +// across requests. On the next request the worker compares the new prompt's +// token ids against the session's resident token ids; if the resident ids are +// an exact prefix, it prefills ONLY the suffix (continuing the KV/recurrent +// state at pos>0) instead of resetting and re-prefilling the whole prompt. The +// check is exact-token (never string/retokenized text) and falls back to a full +// reset+prefill whenever exact reuse can't be proven, so it is always correct; +// the win is when the prompt is a genuine token extension of the prior turn. +// See plan_prefill(). // // Sessions: -// - Named: an explicit session_id -> LLMSession, created on first use (or via -// an `open` op), capped at max_named_sessions = capacity - 1 (the scratch -// slot is reserved). 0 when the backend can host only one session. -// - Scratch: one session for anonymous requests (no session_id), reset each -// request -- preserves the original single-session behavior. +// - Named: an explicit session_id -> session + resident token ids, created on +// first use (or via an `open` op), capped at max_named_sessions = capacity +// - 1 (the scratch slot is reserved). 0 when the backend hosts one session. +// Warm resume applies to named sessions (unless disabled). +// - Scratch: one session for anonymous requests (no session_id), reset every +// request -- distinct anonymous callers must never reuse each other's +// state. // // Protocol (one JSON object per line; matches worker_client.py): // worker -> stdout, once: {"ready": true, "max_sessions": int, @@ -38,12 +46,19 @@ // "stop": [str, ...], "session_id"?: str} // open: {"op": "open", "session_id": str} // close: {"op": "close", "session_id": str} +// reset: {"op": "reset", "session_id": str} // clear context, keep +// slot // worker -> stdout: // generate: {"token": str} * (streamed) -// {"done": true, "prompt_tokens": int, -// "completion_tokens": int, "finish_reason": "stop"|"length"} +// {"done": true, "prompt_tokens": int, "completion_tokens": +// int, +// "finish_reason": "stop"|"length", +// "reused_prompt_tokens": int, "prefilled_prompt_tokens": int, +// "session_reset_reason": "new"|"exact_prefix"|"dirty"| +// "mismatch"|"equal"} // open: {"opened": true, "session_id": str} // close: {"closed": true, "session_id": str} +// reset: {"reset": true, "session_id": str} // error: {"error": str, "code"?: str} // code: "capacity_exhausted", // // "unsupported_session" // @@ -55,6 +70,7 @@ #include #include #include +#include #include #include @@ -80,15 +96,32 @@ inline void worker_emit(const nlohmann::json& obj) { std::cout.flush(); } -// One generation request: reset the session, encode the prompt, prefill, then -// loop decode_one() streaming complete-UTF-8 text pieces. A terminal step (EOS -// or cooperative stop) ends generation and is not emitted or counted. Throws -// std::runtime_error on failure; the caller reports it as {"error": ...}. +// A named session plus the warm-resume bookkeeping the worker maintains for it. +// Invariant (while not mid-mutation): resident_token_ids.size() == +// session->position() -- the resident ids are exactly the tokens currently in +// the session's KV/recurrent state, in order. +struct WorkerSessionState { + std::unique_ptr session; + std::vector resident_token_ids; + // Set when the resident state can no longer be trusted as an exact token + // prefix (e.g. a stop-string trimmed the emitted text mid-token, or a + // prefill/decode failed after mutating state). Forces a reset next request. + bool dirty = false; +}; + +// One generation request against a session. Encodes the prompt, chooses a +// prefill plan (warm suffix reuse for named sessions, or a full reset+prefill), +// then streams complete-UTF-8 text pieces from decode_one(). A terminal step +// (EOS or cooperative stop) ends generation and is not emitted or counted. +// Maintains st.resident_token_ids / st.dirty. Throws std::runtime_error on +// failure; the caller reports it as {"error": ...}. inline void worker_handle_request( - LLMSession& session, + WorkerSessionState& st, + bool warm, ::tokenizers::Tokenizer& tokenizer, const std::unordered_map& metadata, const nlohmann::json& req) { + LLMSession& session = *st.session; const std::string prompt = req.at("prompt").get(); int64_t max_new = req.value("max_new_tokens", static_cast(-1)); const float temperature = req.value("temperature", 0.0f); @@ -98,9 +131,6 @@ inline void worker_handle_request( const std::vector stops = req.value("stop", std::vector{}); - if (session.reset() != ::executorch::runtime::Error::Ok) { - throw std::runtime_error("session reset failed"); - } // No special tokens: the prompt is already rendered (the control plane // applied the chat template), matching the runner's own encode path. auto encode_result = tokenizer.encode(prompt, /*bos=*/0, /*eos=*/0); @@ -115,7 +145,9 @@ inline void worker_handle_request( // Bound generation to the context window: default to filling the remaining // room, and clamp an explicit max_new_tokens too, so decode never steps past - // the window (which would error mid-generation after partial output). + // the window (which would error mid-generation after partial output). The + // bound is on the FULL prompt length (= pos after prefill), regardless of how + // much is reused. const auto ctx_it = metadata.find(kMaxContextLen); if (ctx_it != metadata.end()) { const int64_t room = ctx_it->second - num_prompt; @@ -130,12 +162,37 @@ inline void worker_handle_request( max_new = 2048; } + // Decide full vs warm-suffix prefill. Anonymous (scratch) and warm-disabled + // sessions always full-prefill from a clean state. + PrefillPlan plan = warm ? plan_prefill(st.resident_token_ids, ids, st.dirty) + : PrefillPlan{PrefillPlan::kFull, 0, "new"}; + int64_t reused = 0; + std::vector to_prefill; + if (plan.action == PrefillPlan::kSuffix) { + reused = static_cast(plan.suffix_start); + to_prefill.assign(ids.begin() + plan.suffix_start, ids.end()); + } else { + if (session.reset() != ::executorch::runtime::Error::Ok) { + st.dirty = true; + throw std::runtime_error("session reset failed"); + } + st.resident_token_ids.clear(); + st.dirty = false; + to_prefill = ids; + } + const int64_t prefilled = static_cast(to_prefill.size()); + SamplingConfig sampling; sampling.temperature = temperature; - if (session.prefill_tokens(std::move(ids), &sampling) != + if (session.prefill_tokens(std::move(to_prefill), &sampling) != ::executorch::runtime::Error::Ok) { + st.dirty = true; // state may be partially mutated; force a reset next time throw std::runtime_error("prefill failed"); } + // The resident state now equals the full prompt (resident prefix + prefilled + // suffix, or the whole prompt). Keep the invariant + // resident.size()==position(). + st.resident_token_ids = ids; std::string buf; // bytes not yet forming a complete UTF-8 prefix std::string pending; // complete-UTF-8 text held back for stop-string matching @@ -145,6 +202,7 @@ inline void worker_handle_request( for (int64_t step = 0; step < max_new; ++step) { auto step_result = session.decode_one(sampling); if (step_result.error() != ::executorch::runtime::Error::Ok) { + st.dirty = true; throw std::runtime_error("decode failed"); } const auto& d = step_result.get(); @@ -152,6 +210,10 @@ inline void worker_handle_request( finish = "stop"; break; // terminal step (EOS / cooperative stop): not emitted or counted } + // The token was forwarded into the cache (pos advanced); track it so the + // resident-ids/position invariant holds. EOS/terminal tokens are not + // forwarded, so they are not appended (above). + st.resident_token_ids.push_back(d.token_id); ++num_generated; buf += d.text_piece; const size_t cut = utf8_complete_prefix_len(buf); @@ -168,6 +230,10 @@ inline void worker_handle_request( if (stop_hit) { finish = "stop"; // reached a stop string: drop it and everything after stop_string = true; + // The emitted text was trimmed at the stop string, so the next turn's + // rendered prompt won't be an exact token extension of resident: force a + // reset rather than risk a false prefix match. + st.dirty = true; break; } } @@ -181,11 +247,16 @@ inline void worker_handle_request( } // finish_reason: "stop" if the model emitted EOS or hit a stop string, else // "length" -- it ran to max_new (possibly clamped to the context window). + // reused/prefilled sum to prompt_tokens; session_reset_reason explains the + // prefill plan (for measuring warm-resume hit rate). worker_emit( {{"done", true}, {"prompt_tokens", num_prompt}, {"completion_tokens", num_generated}, - {"finish_reason", finish}}); + {"finish_reason", finish}, + {"reused_prompt_tokens", reused}, + {"prefilled_prompt_tokens", prefilled}, + {"session_reset_reason", plan.reason}}); } // Owns the engine's sessions for one worker: named sessions keyed by id plus a @@ -211,10 +282,10 @@ class WorkerSessions { // Resolve (and admit, creating on first use) a named session. Returns nullptr // and sets code on failure: "unsupported_session" when the backend hosts no // named sessions, "capacity_exhausted" when all named slots are taken. - LLMSession* open_named(const std::string& id, std::string& code) { + WorkerSessionState* open_named(const std::string& id, std::string& code) { auto it = named_.find(id); if (it != named_.end()) { - return it->second.get(); // idempotent open / reuse across requests + return &it->second; // idempotent open / reuse across requests } if (max_named_ == 0) { code = "unsupported_session"; @@ -229,9 +300,9 @@ class WorkerSessions { code = "capacity_exhausted"; // engine-side capacity backstop return nullptr; } - auto* session = result.get().get(); - named_.emplace(id, std::move(result.get())); - return session; + WorkerSessionState& st = named_[id]; + st.session = std::move(result.get()); + return &st; } // Destroy a named session (freeing its per-session state); idempotent. @@ -239,33 +310,56 @@ class WorkerSessions { named_.erase(id); } + // Clear a named session's context (reset KV/recurrent + resident ids) while + // keeping its capacity slot allocated. No-op if the session doesn't exist. + // Returns Ok (including the absent no-op); on a failed reset returns the + // session's error and leaves resident state intact, so the control plane + // keeps its transcript in lockstep instead of clearing it after a failed + // reset. + ::executorch::runtime::Error reset_named(const std::string& id) { + auto it = named_.find(id); + if (it == named_.end()) { + return ::executorch::runtime::Error::Ok; + } + auto err = it->second.session->reset(); + if (err != ::executorch::runtime::Error::Ok) { + return err; + } + it->second.resident_token_ids.clear(); + it->second.dirty = false; + return ::executorch::runtime::Error::Ok; + } + // The scratch session for anonymous requests, created on first use. Throws if // the engine cannot create it. - LLMSession* scratch() { - if (!scratch_) { + WorkerSessionState* scratch() { + if (!scratch_.session) { auto result = engine_.create_session(); if (result.error() != ::executorch::runtime::Error::Ok) { throw std::runtime_error("failed to create scratch session"); } - scratch_ = std::move(result.get()); + scratch_.session = std::move(result.get()); } - return scratch_.get(); + return &scratch_; } private: LLMEngine& engine_; int32_t max_named_; - std::unordered_map> named_; - std::unique_ptr scratch_; + std::unordered_map named_; + WorkerSessionState scratch_; }; // Emit {"ready": true, ...}, then read JSONL requests from stdin and dispatch -// each (generate / open / close), reporting exceptions as {"error": ...} and -// continuing to serve. Returns 0 when stdin closes. +// each (generate / open / close / reset), reporting exceptions as +// {"error": ...} and continuing to serve. Returns 0 when stdin closes. +// enable_warm_resume gates V2b.1 warm suffix reuse for named sessions (off -> +// every request resets, the V2a behavior; useful for A/B measurement). inline int run_worker_stdio_loop( LLMEngine& engine, ::tokenizers::Tokenizer& tokenizer, - const std::unordered_map& metadata) { + const std::unordered_map& metadata, + bool enable_warm_resume = true) { WorkerSessions sessions(engine); worker_emit( {{"ready", true}, @@ -283,7 +377,7 @@ inline int run_worker_stdio_loop( const nlohmann::json req = nlohmann::json::parse(line); const std::string op = req.value("op", std::string{}); - if (op == "open" || op == "close") { + if (op == "open" || op == "close" || op == "reset") { const std::string id = req.at("session_id").get(); if (id.empty()) { throw std::runtime_error("session_id required for op"); @@ -291,38 +385,49 @@ inline int run_worker_stdio_loop( if (op == "close") { sessions.close_named(id); worker_emit({{"closed", true}, {"session_id", id}}); - continue; - } - std::string code; - if (sessions.open_named(id, code) == nullptr) { - worker_emit( - {{"error", "cannot open session"}, - {"code", code}, - {"session_id", id}}); - } else { - worker_emit({{"opened", true}, {"session_id", id}}); + } else if (op == "reset") { + // idempotent (no-op if absent); only acks success if the reset took + if (sessions.reset_named(id) != ::executorch::runtime::Error::Ok) { + worker_emit( + {{"error", "session reset failed"}, {"session_id", id}}); + } else { + worker_emit({{"reset", true}, {"session_id", id}}); + } + } else { // open + std::string code; + if (sessions.open_named(id, code) == nullptr) { + worker_emit( + {{"error", "cannot open session"}, + {"code", code}, + {"session_id", id}}); + } else { + worker_emit({{"opened", true}, {"session_id", id}}); + } } continue; } // Generation. A session_id routes to its named session (admitted on first - // use); its absence uses the shared scratch session. + // use, warm-resumable); its absence uses the shared scratch session, + // which is always reset per request. const std::string id = req.value("session_id", std::string{}); - LLMSession* session = nullptr; + WorkerSessionState* st = nullptr; + bool warm = false; if (id.empty()) { - session = sessions.scratch(); + st = sessions.scratch(); } else { std::string code; - session = sessions.open_named(id, code); - if (session == nullptr) { + st = sessions.open_named(id, code); + if (st == nullptr) { worker_emit( {{"error", "cannot open session"}, {"code", code}, {"session_id", id}}); continue; } + warm = enable_warm_resume; } - worker_handle_request(*session, tokenizer, metadata, req); + worker_handle_request(*st, warm, tokenizer, metadata, req); } catch (const std::exception& e) { // report and keep serving worker_emit({{"error", std::string(e.what())}}); } diff --git a/extension/llm/server/cpp/worker_prefill_plan.h b/extension/llm/server/cpp/worker_prefill_plan.h new file mode 100644 index 00000000000..e5985bccbce --- /dev/null +++ b/extension/llm/server/cpp/worker_prefill_plan.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// Pure warm-resume prefill decision for the model worker (no model/session/ET +// dependency, so it is unit-testable in isolation). Given a named session's +// resident token ids (exactly the tokens currently in its KV/recurrent state), +// the new request's prompt token ids, and whether the session is dirty, decide +// whether to reset + full-prefill or to keep the state and prefill only the +// suffix. The decision is exact-token (never string / retokenized text), so a +// kSuffix plan is always a correct continuation; anything uncertain falls back +// to kFull. See worker_loop.h for how the plan is executed. + +#include +#include +#include + +namespace executorch { +namespace extension { +namespace llm { + +struct PrefillPlan { + enum Action { + kFull, // reset + prefill the whole prompt + kSuffix // keep state, prefill prompt_ids[suffix_start:] at pos>0 + } action; + size_t suffix_start; // index in prompt_ids where prefill begins (0 for kFull) + // Reported as session_reset_reason: "new" (no resident), "exact_prefix" + // (suffix reuse), "dirty", "mismatch", "equal" (prompt == resident). + const char* reason; +}; + +inline PrefillPlan plan_prefill( + const std::vector& resident, + const std::vector& prompt, + bool dirty) { + if (dirty) { + return {PrefillPlan::kFull, 0, "dirty"}; + } + if (resident.empty()) { + return {PrefillPlan::kFull, 0, "new"}; + } + if (prompt.size() < resident.size()) { + return {PrefillPlan::kFull, 0, "mismatch"}; + } + for (size_t i = 0; i < resident.size(); ++i) { + if (prompt[i] != resident[i]) { + return {PrefillPlan::kFull, 0, "mismatch"}; + } + } + if (prompt.size() == resident.size()) { + // Prompt is exactly the resident state (no new tokens). The ideal would be + // to skip prefill and decode straight from the session's pending token, but + // the LLMSession API exposes no "is there a valid pending token?" query, so + // we conservatively reset + full prefill rather than risk + // prefill_tokens([]) or decoding from a stale/absent pending. Rare in + // practice (a new turn adds tokens); a session pending-state query is a + // possible later optimization. + return {PrefillPlan::kFull, 0, "equal"}; + } + return {PrefillPlan::kSuffix, resident.size(), "exact_prefix"}; +} + +} // namespace llm +} // namespace extension +} // namespace executorch diff --git a/extension/llm/server/python/server.py b/extension/llm/server/python/server.py index 4ba2539c247..05066758deb 100644 --- a/extension/llm/server/python/server.py +++ b/extension/llm/server/python/server.py @@ -127,6 +127,17 @@ async def close_session(session_id: str): return JSONResponse(e.body(), status_code=e.status) return JSONResponse({"closed": True, "session_id": session_id}) + @app.post("/v1/sessions/{session_id}/reset") + async def reset_session(session_id: str): + # Clear a named session's context but keep its slot (vendor extension; + # idempotent). Lets an agent reuse a slot for a new conversation without + # freeing/reopening it. + try: + await serving.reset_session(session_id) + except APIError as e: + return JSONResponse(e.body(), status_code=e.status) + return JSONResponse({"reset": True, "session_id": session_id}) + return app diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py index 69e38c59b19..3b552228980 100644 --- a/extension/llm/server/python/serving_chat.py +++ b/extension/llm/server/python/serving_chat.py @@ -223,6 +223,13 @@ async def close_session(self, session_id: str) -> None: except WorkerError as e: raise GenerationError(str(e)) + async def reset_session(self, session_id: str) -> None: + self._validate_session_id(session_id) + try: + await self._runtime.reset(session_id) + except WorkerError as e: + raise GenerationError(str(e)) + def _finish_reason( self, req: ChatCompletionRequest, diff --git a/extension/llm/server/python/tests/test_sessions.py b/extension/llm/server/python/tests/test_sessions.py index ea5a6554d6a..c8a9cda3c57 100644 --- a/extension/llm/server/python/tests/test_sessions.py +++ b/extension/llm/server/python/tests/test_sessions.py @@ -137,3 +137,22 @@ def test_session_header_precedence(make_client): ) assert resp.status_code == 200 assert fake.opened_log == ["xet"] + + +def test_reset_endpoint_clears_context_but_keeps_slot(make_client): + # max_named=1: open "a", reset it, then a *different* id must still 429 — + # proving reset cleared context without freeing the slot (unlike DELETE). + client, fake = make_client(max_named_sessions=1) + assert _chat(client, session_id="a").status_code == 200 + r = client.post("/v1/sessions/a/reset") + assert r.status_code == 200 + assert r.json() == {"reset": True, "session_id": "a"} + assert fake.reset_log == ["a"] + assert _chat(client, session_id="b").status_code == 429 # slot still held + + +def test_reset_invalid_session_id_rejected(make_client): + client, _ = make_client(max_named_sessions=2) + r = client.post("/v1/sessions/has%20space/reset") + assert r.status_code == 400 + assert r.json()["error"]["code"] == "invalid_session_id" diff --git a/extension/llm/server/python/tests/test_worker_client.py b/extension/llm/server/python/tests/test_worker_client.py index b461785036f..59ef848f41a 100644 --- a/extension/llm/server/python/tests/test_worker_client.py +++ b/extension/llm/server/python/tests/test_worker_client.py @@ -175,6 +175,37 @@ def test_close_session_sends_op_and_acks(): assert json.loads(proc.stdin.written[0]) == {"op": "close", "session_id": "abc"} +def test_reset_session_sends_op_and_acks(): + proc = _FakeProc(_lines({"reset": True, "session_id": "abc"})) + WorkerClient(proc).reset_session("abc") + assert json.loads(proc.stdin.written[0]) == {"op": "reset", "session_id": "abc"} + + +def test_generate_parses_warm_resume_metrics(): + proc = _FakeProc( + _lines( + {"token": "hi"}, + { + "done": True, + "prompt_tokens": 100, + "completion_tokens": 1, + "finish_reason": "stop", + "reused_prompt_tokens": 90, + "prefilled_prompt_tokens": 10, + "session_reset_reason": "exact_prefix", + }, + ) + ) + seen = {} + WorkerClient(proc).generate( + "hi", _Cfg(session_id="s"), stats_callback=lambda s: seen.update(s=s) + ) + st = seen["s"] + assert st.reused_prompt_tokens == 90 + assert st.prefilled_prompt_tokens == 10 + assert st.session_reset_reason == "exact_prefix" + + def test_spawn_worker_waits_for_ready(): proc = _FakeProc(_lines({"ready": True, "max_named_sessions": 3})) client = spawn_worker(