From b9daa77b91af5cdcdf23a6740b1c76417d738f72 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 4 Jun 2026 11:48:04 -0700 Subject: [PATCH 1/2] [INITIAL] Update [ghstack-poisoned] --- Makefile | 14 +- examples/models/qwen3_5_moe/CMakeLists.txt | 23 +- examples/models/qwen3_5_moe/CMakePresets.json | 28 ++ examples/models/qwen3_5_moe/README.md | 84 ++++ examples/models/qwen3_5_moe/main.cpp | 433 ++++++------------ examples/models/qwen3_5_moe/model.md | 29 ++ .../models/qwen3_5_moe/qwen35_moe_engine.cpp | 379 +++++++++++++++ .../models/qwen3_5_moe/qwen35_moe_engine.h | 96 ++++ .../qwen3_5_moe/qwen35_moe_pybindings.cpp | 62 +++ examples/models/qwen3_5_moe/serve.py | 259 +++++++++++ examples/models/qwen3_5_moe/test_serve.py | 149 ++++++ examples/models/qwen3_5_moe/worker.py | 127 +++++ 12 files changed, 1381 insertions(+), 302 deletions(-) create mode 100644 examples/models/qwen3_5_moe/qwen35_moe_engine.cpp create mode 100644 examples/models/qwen3_5_moe/qwen35_moe_engine.h create mode 100644 examples/models/qwen3_5_moe/qwen35_moe_pybindings.cpp create mode 100644 examples/models/qwen3_5_moe/serve.py create mode 100644 examples/models/qwen3_5_moe/test_serve.py create mode 100644 examples/models/qwen3_5_moe/worker.py diff --git a/Makefile b/Makefile index 9c8476d30ed..4724f629969 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean 
help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-cuda-serve qwen3_5_moe-metal clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @@ -130,6 +130,7 @@ help: @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" @echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" + @echo " qwen3_5_moe-cuda-serve - Build Qwen3.5 MoE runner + OpenAI serving module (CUDA)" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @echo " clean - Clean build artifacts" @@ -455,6 +456,17 @@ gemma4_31b-mlx: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" +qwen3_5_moe-cuda-serve: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Qwen3.5 MoE runner + serving module with CUDA..." + cd examples/models/qwen3_5_moe && cmake --workflow --preset qwen3-5-moe-cuda-serve + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner" + @echo " Serving module: cmake-out/examples/models/qwen3_5_moe/_qwen35_moe*.so" + @echo " Launch: see examples/models/qwen3_5_moe/README.md (Serving)" + qwen3_5_moe-metal: @echo "==> Building and installing ExecuTorch with Metal..." 
cmake --workflow --preset llm-release-metal diff --git a/examples/models/qwen3_5_moe/CMakeLists.txt b/examples/models/qwen3_5_moe/CMakeLists.txt index d1cfe54a56f..16fe1fc02b7 100644 --- a/examples/models/qwen3_5_moe/CMakeLists.txt +++ b/examples/models/qwen3_5_moe/CMakeLists.txt @@ -60,7 +60,7 @@ endif() # Tokenizer list(APPEND link_libraries tokenizers::tokenizers) -add_executable(qwen3_5_moe_runner main.cpp) +add_executable(qwen3_5_moe_runner main.cpp qwen35_moe_engine.cpp) target_include_directories( qwen3_5_moe_runner PUBLIC ${_common_include_directories} ) @@ -70,3 +70,24 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options_gc_sections(qwen3_5_moe_runner) target_link_options(qwen3_5_moe_runner PRIVATE "LINKER:-s") endif() + +# Example-local serving module (_qwen35_moe) used by the worker subprocess +# (worker.py). Off by default; enable with -DEXECUTORCH_BUILD_PYBIND=ON. It +# builds a Qwen35MoEEngine via create_engine() and re-exposes the generic Engine +# / Session surface (no Qwen-specific Python class). The generic +# extension/llm/runner pybind is not touched, so the model-agnostic +# server/runner never depend on Qwen. 
+if(EXECUTORCH_BUILD_PYBIND) + add_subdirectory( + ${EXECUTORCH_ROOT}/third-party/pybind11 + ${CMAKE_CURRENT_BINARY_DIR}/pybind11 + ) + pybind11_add_module( + _qwen35_moe SHARED qwen35_moe_pybindings.cpp qwen35_moe_engine.cpp + ) + target_include_directories(_qwen35_moe PRIVATE ${_common_include_directories}) + target_link_libraries(_qwen35_moe PRIVATE ${link_libraries}) + set_target_properties( + _qwen35_moe PROPERTIES CXX_STANDARD 17 POSITION_INDEPENDENT_CODE ON + ) +endif() diff --git a/examples/models/qwen3_5_moe/CMakePresets.json b/examples/models/qwen3_5_moe/CMakePresets.json index 34ebc938280..28a4b3efa13 100644 --- a/examples/models/qwen3_5_moe/CMakePresets.json +++ b/examples/models/qwen3_5_moe/CMakePresets.json @@ -24,6 +24,14 @@ "list": ["Linux", "Windows"] } }, + { + "name": "qwen3-5-moe-cuda-serve", + "displayName": "Qwen3.5 MoE runner + serving module (CUDA)", + "inherits": ["qwen3-5-moe-cuda"], + "cacheVariables": { + "EXECUTORCH_BUILD_PYBIND": "ON" + } + }, { "name": "qwen3-5-moe-metal", "displayName": "Qwen3.5 MoE runner (Metal)", @@ -45,6 +53,12 @@ "configurePreset": "qwen3-5-moe-cuda", "targets": ["qwen3_5_moe_runner"] }, + { + "name": "qwen3-5-moe-cuda-serve", + "displayName": "Build Qwen3.5 MoE runner + serving module (CUDA)", + "configurePreset": "qwen3-5-moe-cuda-serve", + "targets": ["qwen3_5_moe_runner", "_qwen35_moe"] + }, { "name": "qwen3-5-moe-metal", "displayName": "Build Qwen3.5 MoE runner (Metal)", @@ -67,6 +81,20 @@ } ] }, + { + "name": "qwen3-5-moe-cuda-serve", + "displayName": "Configure and build Qwen3.5 MoE runner + serving module (CUDA)", + "steps": [ + { + "type": "configure", + "name": "qwen3-5-moe-cuda-serve" + }, + { + "type": "build", + "name": "qwen3-5-moe-cuda-serve" + } + ] + }, { "name": "qwen3-5-moe-metal", "displayName": "Configure and build Qwen3.5 MoE runner (Metal)", diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index 83373a804f4..50e8440aea4 100644 --- 
a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -133,11 +133,95 @@ cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner \ | `--data_path` | (none) | Path to `.ptd` delegate data file (required for CUDA) | | `--tokenizer_path` | (required) | Path to HuggingFace `tokenizer.json` | | `--prompt` | `"Hello"` | Input prompt text | +| `--prompt_file` | (none) | Path to a file with the prompt (overrides `--prompt`) | | `--temperature` | `0.8` | Sampling temperature (0 = greedy) | | `--max_new_tokens` | `128` | Maximum tokens to generate | +| `--cuda_graph` | off | Capture/replay the decode method as a CUDA graph (CUDA only). See the caveat below. | +| `--warmup` | `0` | Warmup iterations to discard before timing (one model load; the session is reset between iterations) | +| `--num_iters` | `1` | Timed iterations to average, after warmup | + +## Serving (OpenAI-compatible) + +Run an OpenAI-compatible HTTP server so an agent harness (pi, opencode, …) can +use the model for local tool-use. Point your client at `http://<host>:<port>/v1`. 
+ +Build the runner **and** the serving module: + +```bash +make qwen3_5_moe-cuda-serve +``` + +Launch (the `LD_LIBRARY_PATH` shim is forwarded to the worker for the CUDA blob): + +```bash +LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH \ + python -m executorch.examples.models.qwen3_5_moe.serve \ + --model-path qwen35_moe_exports/model.pte \ + --data-path qwen35_moe_exports/aoti_cuda_blob.ptd \ + --tokenizer-path ~/models/Qwen3.5-35B-A3B/tokenizer.json \ + --hf-tokenizer ~/models/Qwen3.5-35B-A3B \ + --model-id qwen3.5-moe --no-think +``` + +### Architecture (process isolation) + +Two processes, one model load: + +``` +serve.py (control plane: FastAPI/asyncio, OpenAI protocol, chat templating, + tool parsing, validation — NO CUDA) + │ JSONL over stdin/stdout + ▼ +worker.py (one Qwen35MoEEngine + one session, synchronous loop — the CUDA model; + NO asyncio server) +``` + +The model runs in a **separate worker process** because executing the AOTI CUDA +model inside a live asyncio server process segfaults in the int4 matmul +(reproducible, and isolated by elimination to the asyncio-loop × CUDA +interaction). The worker runs the model like the CLI — a plain synchronous loop — +which is reliable. The control plane only does blocking pipe I/O (no CUDA), which +is safe under asyncio. 
+ +### Serve Options + +| Flag | Default | Description | +|------|---------|-------------| +| `--model-path` | (required) | Path to exported `.pte` model | +| `--data-path` | (none) | Path to `.ptd` delegate data file (required for CUDA) | +| `--tokenizer-path` | (required) | Path to HuggingFace `tokenizer.json` | +| `--hf-tokenizer` | (required) | HF tokenizer id/dir for the chat template + encoding | +| `--model-id` | `qwen3.5-moe` | Model id reported on `/v1/models` | +| `--host` / `--port` | `127.0.0.1` / `8000` | Bind address | +| `--max-context` | (none) | Reject prompts that exceed it with 400 | +| `--no-think` | off | Default reasoning off (`enable_thinking=False`) | + +### V1 limitations + +- **Single-slot** (`serving_capacity=1`): one worker, one session, one model + load. `--num-runners > 1` is rejected; concurrent requests queue on the worker. +- **No prefix cache**: the recurrent/conv state cannot be rewound by position + (`seek()` is NotSupported), so turn-to-turn KV reuse is off. +- Supports the chat-completions contract of the generic server; `top_p != 1`, + `seed`, `top_k`, `logprobs`, etc. are rejected (only temperature is plumbed). ## Troubleshooting +- **Runner exits silently right after `Loading methods...`**: the AOTI CUDA blob + is compiled with the conda toolchain's `libstdc++`, which is newer than the + system one (it needs e.g. `GLIBCXX_3.4.34`). Prepend the conda lib dir so the + runner loads the matching `libstdc++`: + + ```bash + LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH \ + cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner ... + ``` +- **`aoti_torch_cuda_sort_stable ... API call failed` when re-running prefill + with `--cuda_graph`**: capturing the decode CUDA graph and then running another + prefill in the same process currently fails (allocator interaction). Use + `--cuda_graph` for single prefill+decode runs; omit it when looping with + `--warmup`/`--num_iters`. 
+ - **OOM during export**: The model requires significant GPU memory even with int4 quantization. Try reducing `--max-seq-len` or using a GPU with more VRAM. diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 19d93af0d58..1bd7b495e2b 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -6,17 +6,17 @@ * LICENSE file in the root directory of this source tree. */ +// Thin CLI over Qwen35MoEEngine / Qwen35MoESession: parse flags, build the +// engine + a session, encode the prompt, prefill_tokens(), then loop +// decode_one() printing pieces and timing/stats. All model execution lives in +// qwen35_moe_engine.{h,cpp}. + #include -#include +#include #include #include -#include -#include -#include -#include #include -#include #include #include @@ -26,8 +26,6 @@ #ifdef EXECUTORCH_BUILD_CUDA #include -#else -#include #endif DEFINE_string(model_path, "", "Model .pte file path."); @@ -44,57 +42,16 @@ DEFINE_bool( cuda_graph, false, "Enable CUDA graph for decode method. CUDA only."); +DEFINE_int32( + warmup, + 0, + "Warmup iterations to discard before timing. One model load; the session is " + "reset between iterations. Warmup captures the CUDA graph and ramps GPU " + "clocks so the timed iterations reflect steady state."); +DEFINE_int32(num_iters, 1, "Timed iterations to average (after warmup)."); namespace llm = ::executorch::extension::llm; -using ::executorch::extension::from_blob; -using ::executorch::extension::Module; -using ::executorch::extension::TensorPtr; using ::executorch::runtime::Error; -using ::executorch::runtime::EValue; - -using SizesType = executorch::aten::SizesType; - -// Convert a model output tensor to the next sampled token id. -// -// On the CUDA build, the model fuses the sampler in (see sampler.py / -// Qwen35MoE.forward) and returns a single sampled token id as a [B, 1] -// float tensor; we just copy that scalar back from device. 
-// -// On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits -// of shape [B, T, V] in the model dtype (typically bf16). We sample on -// CPU via the shared `llm::logits_to_token` helper, which accepts a -// temperature (0 = greedy / argmax). -static uint64_t read_token(const executorch::aten::Tensor& output) { -#ifdef EXECUTORCH_BUILD_CUDA - const void* ptr = output.const_data_ptr(); - - cudaPointerAttributes attrs; - bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && - attrs.type == cudaMemoryTypeDevice; - - float val; - if (on_device) { - cudaError_t err = - cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost); - if (err != cudaSuccess) { - ET_LOG( - Error, - "read_token: cudaMemcpy D2H failed: %s", - cudaGetErrorString(err)); - return 0; - } - } else { - memcpy(&val, ptr, sizeof(float)); - } - return static_cast(val); -#else - // logits_to_token handles 2D / 3D logits and Float / Half / BFloat16 / - // UInt16 dtypes. Negative temperatures are clamped to 0 (greedy). - const float temp = - FLAGS_temperature <= 0.0 ? 
0.0f : static_cast(FLAGS_temperature); - return static_cast(llm::logits_to_token(output, temp)); -#endif -} int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -111,28 +68,6 @@ int main(int argc, char** argv) { llm::Stats stats; #ifdef EXECUTORCH_BUILD_CUDA - // GPU memory before load - size_t gpu_free_bytes = 0, gpu_total_bytes = 0; - cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); - stats.gpu_total_bytes = gpu_total_bytes; - stats.gpu_free_before_load_bytes = gpu_free_bytes; -#endif - - stats.model_load_start_ms = llm::time_in_ms(); - - // Load tokenizer - auto tokenizer = std::make_unique(); - auto tok_status = tokenizer->load(FLAGS_tokenizer_path); - if (tok_status != tokenizers::Error::Ok) { - ET_LOG( - Error, - "Failed to load tokenizer from %s", - FLAGS_tokenizer_path.c_str()); - return 1; - } - -#ifdef EXECUTORCH_BUILD_CUDA - // GPU memory: before load { size_t free = 0, total = 0; if (cudaMemGetInfo(&free, &total) == cudaSuccess) { @@ -144,90 +79,32 @@ int main(int argc, char** argv) { stats.model_load_start_ms = llm::time_in_ms(); - // Create Module with share_memory_arenas=true so prefill and decode - // share mutable buffers (KV cache, conv_state, recurrent_state). 
- std::vector data_files; - if (!FLAGS_data_path.empty()) { - data_files.push_back(FLAGS_data_path); - } - auto module = std::make_unique( - FLAGS_model_path, - data_files, - Module::LoadMode::File, - /*event_tracer=*/nullptr, - /*memory_allocator=*/nullptr, - /*temp_allocator=*/nullptr, - /*share_memory_arenas=*/true); - - // Get metadata - auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); - if (metadata_result.error() != Error::Ok) { - ET_LOG(Error, "Failed to get metadata from model"); - return 1; - } - auto metadata = metadata_result.get(); - -#ifdef EXECUTORCH_BUILD_CUDA - // Set CUDA graph option if requested (must be before load_method) - if (FLAGS_cuda_graph) { - executorch::runtime::BackendOptions<2> cuda_opts; - cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); - executorch::runtime::set_option("CudaBackend", cuda_opts.view()); - printf("CUDA graph enabled for decode method\n"); - } -#else - if (FLAGS_cuda_graph) { - ET_LOG(Info, "--cuda_graph ignored on non-CUDA build"); - } -#endif + // Build engine (reads tokenizer + metadata) and a session (loads weights and + // the prefill/decode methods). + llm::Qwen35MoEConfig config; + config.model_path = FLAGS_model_path; + config.data_path = FLAGS_data_path; + config.tokenizer_path = FLAGS_tokenizer_path; + config.cuda_graph = FLAGS_cuda_graph; printf("Loading methods...\n"); - -#ifdef EXECUTORCH_BUILD_CUDA - // Enable cross-method per-FQN weight sharing in the CUDA backend so that - // prefill and decode (which share KV cache and other mutable buffers / - // weights) avoid duplicate GPU allocations. This is critical for fitting - // Qwen 3.5 MoE on a single GPU. MUST be set BEFORE load_method, since the - // backend reads this flag during init() to decide between the per-weight - // cache path and the legacy per-method blob load. 
- { - executorch::runtime::BackendOptions<1> backend_options; - auto set_err = - backend_options.set_option("weight_sharing_across_methods", true); - if (set_err != Error::Ok) { - ET_LOG( - Error, - "Failed to construct weight_sharing_across_methods option: %d", - static_cast(set_err)); - return 1; - } - const auto opt_err = - executorch::runtime::set_option("CudaBackend", backend_options.view()); - if (opt_err != Error::Ok) { - ET_LOG( - Error, - "Failed to enable weight_sharing_across_methods: %d", - static_cast(opt_err)); - return 1; - } - } -#endif - - auto err = module->load_method("prefill"); - if (err != Error::Ok) { - ET_LOG(Error, "Failed to load prefill method"); + auto engine_result = llm::Qwen35MoEEngine::create(config); + if (engine_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to create Qwen3.5 MoE engine"); return 1; } - err = module->load_method("decode"); - if (err != Error::Ok) { - ET_LOG(Error, "Failed to load decode method"); + auto engine = std::move(engine_result.get()); + + auto session_result = engine->create_session(); + if (session_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to create session"); return 1; } + auto session = std::move(session_result.get()); stats.model_load_end_ms = llm::time_in_ms(); #ifdef EXECUTORCH_BUILD_CUDA - // GPU memory: after load { size_t free = 0, total = 0; if (cudaMemGetInfo(&free, &total) == cudaSuccess) { @@ -236,10 +113,7 @@ int main(int argc, char** argv) { } #endif - // Get EOS ids - auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); - - // Read prompt from file or flag + // Read prompt from file or flag. std::string prompt_text = FLAGS_prompt; if (!FLAGS_prompt_file.empty()) { std::ifstream f(FLAGS_prompt_file); @@ -252,157 +126,100 @@ int main(int argc, char** argv) { (std::istreambuf_iterator(f)), std::istreambuf_iterator()); } - // Encode prompt - auto encode_result = tokenizer->encode(prompt_text); + // Encode prompt via the engine's tokenizer. 
+ auto encode_result = engine->tokenizer()->encode(prompt_text); if (!encode_result.ok()) { ET_LOG(Error, "Failed to encode prompt"); return 1; } - auto prompt_tokens = std::move(*encode_result); - int64_t num_prompt_tokens = prompt_tokens.size(); + std::vector prompt_tokens = std::move(*encode_result); + const int64_t num_prompt_tokens = static_cast(prompt_tokens.size()); printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); stats.num_prompt_tokens = num_prompt_tokens; - stats.inference_start_ms = llm::time_in_ms(); - // --------------------------------------------------------------- - // Sampling tensors (shared between prefill and decode) - // --------------------------------------------------------------- - auto S = [](int64_t v) -> SizesType { return static_cast(v); }; - -#ifdef EXECUTORCH_BUILD_CUDA - // CUDA build: model fuses the sampler in. Pass a temperature tensor as - // a third input. Use a very small temperature for greedy to avoid - // division by zero while keeping the Gumbel noise negligible relative - // to logit differences. - float temp_val = - FLAGS_temperature <= 0.0 ? 
1e-6f : static_cast(FLAGS_temperature); - auto temp_tensor = - from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); -#endif - - stats.inference_start_ms = llm::time_in_ms(); - stats.num_prompt_tokens = num_prompt_tokens; - - // --------------------------------------------------------------- - // Prefill - // --------------------------------------------------------------- - uint64_t cur_token = 0; - - // Use prefill method for T>=2, decode method for T=1 - // (prefill was exported with min seq_len=2) - std::string run_method = "prefill"; - if (num_prompt_tokens == 1) { - run_method = "decode"; - } - - std::vector pos_data(num_prompt_tokens); - for (int64_t i = 0; i < num_prompt_tokens; i++) { - pos_data[i] = i; - } - std::vector token_data(prompt_tokens.begin(), prompt_tokens.end()); - auto tokens_tensor = from_blob( - token_data.data(), - {1, S(num_prompt_tokens)}, - executorch::aten::ScalarType::Long); - auto pos_tensor = from_blob( - pos_data.data(), - {S(num_prompt_tokens)}, - executorch::aten::ScalarType::Long); - - std::vector prefill_inputs; - prefill_inputs.push_back(tokens_tensor); - prefill_inputs.push_back(pos_tensor); -#ifdef EXECUTORCH_BUILD_CUDA - prefill_inputs.push_back(temp_tensor); -#endif - - auto prefill_result = module->execute(run_method, prefill_inputs); - if (prefill_result.error() != Error::Ok) { - ET_LOG(Error, "Prefill failed"); - return 1; - } - auto& prefill_outputs = prefill_result.get(); - - cur_token = read_token(prefill_outputs[0].toTensor()); - - stats.prompt_eval_end_ms = llm::time_in_ms(); - stats.first_token_ms = stats.prompt_eval_end_ms; - double prefill_ms = - (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); - printf( - "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", - num_prompt_tokens, - prefill_ms, - num_prompt_tokens / prefill_ms * stats.SCALING_FACTOR_UNITS_PER_SECOND); - -#ifdef EXECUTORCH_BUILD_CUDA - // Synchronize CUDA device to ensure prefill's writes to shared mutable - // buffers 
(KV cache, conv_state, recurrent_state) are visible to the - // decode method, which may run on a different CUDA stream. - cudaDeviceSynchronize(); -#endif - - // --------------------------------------------------------------- - // Decode — generate tokens one at a time - // --------------------------------------------------------------- - int64_t pos = num_prompt_tokens; - uint64_t prev_token; - - std::vector decode_token_data = {static_cast(cur_token)}; - std::vector decode_pos_data = {pos}; - auto decode_tokens = from_blob( - decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long); - auto decode_pos = from_blob( - decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); - - for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { - decode_token_data[0] = static_cast(cur_token); - decode_pos_data[0] = pos; - - std::vector decode_inputs; - decode_inputs.push_back(EValue(decode_tokens)); - decode_inputs.push_back(EValue(decode_pos)); -#ifdef EXECUTORCH_BUILD_CUDA - decode_inputs.push_back(EValue(temp_tensor)); -#endif - - auto decode_result = module->execute("decode", decode_inputs); - if (decode_result.error() != Error::Ok) { - ET_LOG(Error, "Decode step %d failed", step); + // Warmup + timed iterations on one loaded session (reset between). The first + // FLAGS_warmup iterations are discarded; they trigger CUDA-graph capture, + // allocator growth, and GPU clock ramp so the timed iterations reflect steady + // state. Text is printed only on the first iteration (coherence check). 
+ llm::SamplingConfig sampling; + sampling.temperature = static_cast(FLAGS_temperature); + const int total_iters = FLAGS_warmup + std::max(1, FLAGS_num_iters); + std::vector prefill_tps_samples; + std::vector decode_tps_samples; + double prefill_ms = 0.0; + int64_t num_generated = 0; + + for (int iter = 0; iter < total_iters; ++iter) { + if (iter > 0 && session->reset() != Error::Ok) { + ET_LOG(Error, "Session reset failed before iteration %d", iter); return 1; } - auto& decode_outputs = decode_result.get(); + const bool measured = iter >= FLAGS_warmup; + const bool print_text = (iter == 0); - prev_token = cur_token; - cur_token = read_token(decode_outputs[0].toTensor()); - - if (step == 0) { - stats.first_token_ms = llm::time_in_ms(); + stats.inference_start_ms = llm::time_in_ms(); + if (session->prefill_tokens(prompt_tokens, &sampling) != Error::Ok) { + ET_LOG(Error, "Prefill failed"); + return 1; } - - pos++; - - auto decode_str = tokenizer->decode(prev_token, cur_token); - if (decode_str.ok()) { - printf("%s", decode_str->c_str()); - fflush(stdout); + stats.prompt_eval_end_ms = llm::time_in_ms(); + stats.first_token_ms = stats.prompt_eval_end_ms; + + num_generated = 0; + for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { + auto step_result = session->decode_one(sampling); + if (step_result.error() != Error::Ok) { + ET_LOG(Error, "Decode step %d failed", step); + return 1; + } + const auto& d = step_result.get(); + num_generated++; + if (step == 0) { + stats.first_token_ms = llm::time_in_ms(); + } + if (print_text && !d.text_piece.empty()) { + fwrite(d.text_piece.data(), 1, d.text_piece.size(), stdout); + fflush(stdout); + } + if (d.is_eos) { + if (print_text) { + printf("\n"); + } + break; + } } - - if (eos_ids.find(cur_token) != eos_ids.end()) { - printf("\n"); - break; + stats.inference_end_ms = llm::time_in_ms(); + stats.num_generated_tokens = num_generated; + + prefill_ms = (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); + const 
double decode_ms_iter = + (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); + const double pf_tps = + num_prompt_tokens / prefill_ms * stats.SCALING_FACTOR_UNITS_PER_SECOND; + const double dc_tps = + num_generated / decode_ms_iter * stats.SCALING_FACTOR_UNITS_PER_SECOND; + printf( + "[iter %d%s] prefill %.1f tok/s (%" PRId64 + " tok, %.1f ms) | " + "decode %.1f tok/s (%" PRId64 " tok, %.1f ms)\n", + iter, + measured ? "" : " warmup", + pf_tps, + num_prompt_tokens, + prefill_ms, + dc_tps, + num_generated, + decode_ms_iter); + if (measured) { + prefill_tps_samples.push_back(pf_tps); + decode_tps_samples.push_back(dc_tps); } } - stats.inference_end_ms = llm::time_in_ms(); - printf("\n"); - int64_t num_generated = pos - num_prompt_tokens; - stats.num_generated_tokens = num_generated; #ifdef EXECUTORCH_BUILD_CUDA - // GPU memory: after generate + peak usage { size_t free = 0, total = 0; if (cudaMemGetInfo(&free, &total) == cudaSuccess) { @@ -420,8 +237,7 @@ int main(int argc, char** argv) { #endif printf("\n"); - - double decode_ms = + const double decode_ms = (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); printf( "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", @@ -435,21 +251,20 @@ int main(int argc, char** argv) { num_generated / decode_ms * stats.SCALING_FACTOR_UNITS_PER_SECOND); printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); - // Structured stats report (matches stats.h print_report) printf("PyTorchObserver %s\n", llm::stats_to_json_string(stats).c_str()); - double ms_per_s = stats.SCALING_FACTOR_UNITS_PER_SECOND; - - double model_load_s = + const double ms_per_s = stats.SCALING_FACTOR_UNITS_PER_SECOND; + const double model_load_s = (double)(stats.model_load_end_ms - stats.model_load_start_ms) / ms_per_s; - double inference_time_ms = + const double inference_time_ms = (double)(stats.inference_end_ms - stats.inference_start_ms); - double prompt_eval_ms = + const double prompt_eval_ms = (double)(stats.prompt_eval_end_ms - 
stats.inference_start_ms); - double eval_ms = (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); - double ttft_s = + const double eval_ms = + (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); + const double ttft_s = (double)(stats.first_token_ms - stats.inference_start_ms) / ms_per_s; - double sampling_s = (double)stats.aggregate_sampling_time_ms / ms_per_s; + const double sampling_s = (double)stats.aggregate_sampling_time_ms / ms_per_s; printf("\n"); printf( @@ -477,7 +292,6 @@ int main(int argc, char** argv) { stats.num_prompt_tokens + stats.num_generated_tokens, sampling_s); - // GPU memory reporting if (stats.gpu_total_bytes != static_cast(-1)) { printf( "\tGPU total memory: %.2f MB\n", @@ -502,5 +316,24 @@ int main(int argc, char** argv) { } } + if (!prefill_tps_samples.empty()) { + auto mean = [](const std::vector& v) { + double s = 0.0; + for (double x : v) { + s += x; + } + return s / v.size(); + }; + printf( + "\n=== mean over %zu timed iter(s) (warmup %d) | prompt %" PRId64 + ", gen %" PRId64 " ===\n", + prefill_tps_samples.size(), + FLAGS_warmup, + num_prompt_tokens, + num_generated); + printf("\tPrefill: %.1f tok/s\n", mean(prefill_tps_samples)); + printf("\tDecode: %.1f tok/s\n", mean(decode_tps_samples)); + } + return 0; } diff --git a/examples/models/qwen3_5_moe/model.md b/examples/models/qwen3_5_moe/model.md index 32510859b28..d29177c4c87 100644 --- a/examples/models/qwen3_5_moe/model.md +++ b/examples/models/qwen3_5_moe/model.md @@ -136,6 +136,35 @@ matmul). Visual and MTP keys are skipped. `lm_head.weight` is cloned from `embed_tokens.weight` if not present in checkpoint (tied embeddings). +## Serving (Engine/Session adapter) + +`main.cpp` is a thin CLI over `Qwen35MoEEngine` / `Qwen35MoESession` +(`qwen35_moe_engine.{h,cpp}`), which implement the model-agnostic +`LLMEngine` / `LLMSession` serving contract in +`extension/llm/runner/llm_session.h`. 
This lets an OpenAI-compatible server (or +any harness) drive the model without knowing it is Qwen-MoE or CUDA. + +- **`Qwen35MoEEngine`** owns immutable resources (tokenizer, metadata, EOS ids, + config). `create_session()` builds a `Module` with `share_memory_arenas=true` + and, on CUDA, sets the backend options that must precede `load_method` + (`weight_sharing_across_methods`, optional `enable_cuda_graph_for_method`), + then loads the `prefill`/`decode` methods. `serving_capacity()` reports a + single physical session — cross-session weight sharing is not yet proven, so + it fails closed to 1. +- **`Qwen35MoESession`** owns the mutable conversation state (KV / conv / + recurrent arenas via the Module, position cursor, pending token). + `prefill_tokens` dispatches to `prefill` (T≥2) or `decode` (T==1); + `decode_one` emits the pending token and forwards it, stopping at EOS without + forwarding it (EOS is not made resident and `position()` does not advance). + `seek()` returns `NotSupported` — the recurrent/conv state cannot be rewound + by logical position. `reset()` is a logical rewind to position 0; the model + zeroes `conv_state`/`recurrent_state` whenever prefill runs at + `input_pos[0]==0`, so no Module rebuild is needed. +- Backend-specific execution (CUDA in-graph sampling via a temperature input, + device sync, backend options) is isolated behind `EXECUTORCH_BUILD_CUDA` — the + extension point where an MLX runtime would slot in. The public + `LLMEngine`/`LLMSession` surface stays backend-agnostic. + ## References - [HF Transformers Qwen3.5 MoE](https://github.com/huggingface/transformers) — `transformers/models/qwen3_5_moe/` diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp new file mode 100644 index 00000000000..3d2d3955cbf --- /dev/null +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp @@ -0,0 +1,379 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifdef EXECUTORCH_BUILD_CUDA +#include +#else +#include +#endif + +namespace executorch::extension::llm { + +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::extension::TensorPtr; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Result; +using SizesType = executorch::aten::SizesType; + +namespace { + +// --------------------------------------------------------------------------- +// Backend-specific helpers (the MLX extension points live here). On CUDA the +// model fuses the sampler in and returns the sampled token id as a [B,1] float; +// non-CUDA returns logits and we sample on host. Keep these isolated so the +// session logic below stays backend-agnostic. +// --------------------------------------------------------------------------- + +uint64_t read_sampled_token( + const executorch::aten::Tensor& output, + float temperature) { +#ifdef EXECUTORCH_BUILD_CUDA + (void)temperature; + const void* ptr = output.const_data_ptr(); + cudaPointerAttributes attrs; + const bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && + attrs.type == cudaMemoryTypeDevice; + float val = 0.0f; + if (on_device) { + if (cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost) != + cudaSuccess) { + ET_LOG(Error, "read_sampled_token: cudaMemcpy D2H failed"); + return 0; + } + } else { + std::memcpy(&val, ptr, sizeof(float)); + } + return static_cast(val); +#else + return static_cast( + logits_to_token(output, temperature < 0.0f ? 
0.0f : temperature)); +#endif +} + +// Build a Qwen Module with shared mutable arenas (so prefill and decode share +// KV/conv/recurrent state) and, on CUDA, the weight-sharing/cuda-graph backend +// options that MUST be set before load_method. Loads the prefill+decode methods +// (this is the heavy ~weights load). Called once per create_session(). +Result> build_qwen_module( + const Qwen35MoEConfig& config) { + std::vector data_files; + if (!config.data_path.empty()) { + data_files.push_back(config.data_path); + } + auto module = std::make_unique( + config.model_path, + data_files, + Module::LoadMode::File, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + +#ifdef EXECUTORCH_BUILD_CUDA + // Backend options are read during backend init(), so they must be set before + // load_method. + if (config.cuda_graph) { + executorch::runtime::BackendOptions<1> cuda_opts; + ET_CHECK_OK_OR_RETURN_ERROR(cuda_opts.set_option("enable_cuda_graph_for_method", "decode")); + ET_CHECK_OK_OR_RETURN_ERROR( + executorch::runtime::set_option("CudaBackend", cuda_opts.view())); + } + { + // Cross-method per-FQN weight sharing: prefill and decode reuse one weight + // allocation instead of duplicating it (critical to fit on one GPU). + executorch::runtime::BackendOptions<1> backend_options; + ET_CHECK_OK_OR_RETURN_ERROR( + backend_options.set_option("weight_sharing_across_methods", true)); + ET_CHECK_OK_OR_RETURN_ERROR( + executorch::runtime::set_option("CudaBackend", backend_options.view())); + } +#endif + + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("prefill")); + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("decode")); + return module; +} + +// LLMSession over the Qwen3.5 MoE prefill/decode methods. Owns one physical +// Module (one weight allocation + its KV/recurrent/conv state). Internal: the +// server depends only on the LLMSession base.
+class Qwen35MoESession : public LLMSession { + public: + Qwen35MoESession( + std::unique_ptr module, + ::tokenizers::Tokenizer* tokenizer, + std::unordered_map metadata, + std::unordered_set eos_ids) + : module_(std::move(module)), + tokenizer_(tokenizer), + metadata_(std::move(metadata)), + eos_ids_(std::move(eos_ids)) { + // Persistent single-step decode buffers: stable addresses are required so + // CUDA-graph capture (which records buffer pointers) can replay each step. + decode_tokens_ = from_blob( + decode_token_data_, {1, 1}, executorch::aten::ScalarType::Long); + decode_pos_ = + from_blob(decode_pos_data_, {1}, executorch::aten::ScalarType::Long); +#ifdef EXECUTORCH_BUILD_CUDA + temp_tensor_ = + from_blob(&temp_val_, {1}, executorch::aten::ScalarType::Float); +#endif + } + + Error prefill_tokens( + std::vector tokens, + const SamplingConfig* initial_sampling) override { + if (tokens.empty()) { + ET_LOG(Error, "prefill_tokens: empty token list"); + return Error::InvalidArgument; + } + // The model samples the FIRST generated token in-graph during prefill, so + // it must use the request's sampling, not a stale session default. Only + // temperature is plumbed; reject non-default top_p/top_k/seed (parity with + // decode_one). 
+ float first_token_temp = temperature_; + if (initial_sampling != nullptr) { + if (initial_sampling->top_p != 1.0f || initial_sampling->top_k != 0 || + initial_sampling->seed != 0) { + ET_LOG( + Error, + "prefill_tokens: only temperature is supported; top_p/top_k/seed " + "are not yet implemented"); + return Error::NotSupported; + } + first_token_temp = initial_sampling->temperature; + } + const int64_t T = static_cast(tokens.size()); + const auto ctx_it = metadata_.find(kMaxContextLen); + if (ctx_it != metadata_.end() && pos_ + T > ctx_it->second) { + ET_LOG( + Error, + "prefill_tokens would exceed context capacity (pos %" PRId64 + " + %" PRId64 " > %" PRId64 ")", + pos_, + T, + ctx_it->second); + return Error::InvalidArgument; + } + + std::vector token_data(tokens.begin(), tokens.end()); + std::vector pos_data(T); + for (int64_t i = 0; i < T; ++i) { + pos_data[i] = pos_ + i; + } + auto tokens_tensor = from_blob( + token_data.data(), + {1, static_cast(T)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), + {static_cast(T)}, + executorch::aten::ScalarType::Long); + + // prefill method handles T>=2; the model exports decode for the T==1 case. + const char* method = (T >= 2) ? "prefill" : "decode"; + std::vector inputs; + inputs.push_back(tokens_tensor); + inputs.push_back(pos_tensor); +#ifdef EXECUTORCH_BUILD_CUDA + set_temp(first_token_temp); + inputs.push_back(EValue(temp_tensor_)); +#endif + auto res = module_->execute(method, inputs); + ET_CHECK_OK_OR_RETURN_ERROR(res.error()); + pending_ = read_sampled_token(res.get()[0].toTensor(), first_token_temp); + prev_decode_token_.reset(); + pos_ += T; // the prompt tokens are now resident in KV/state +#ifdef EXECUTORCH_BUILD_CUDA + // Make prefill's writes to the shared mutable arenas visible to decode + // (which may run on a different stream). 
+ cudaDeviceSynchronize(); +#endif + return Error::Ok; + } + + Result decode_one(const SamplingConfig& sampling) override { + // Only temperature is plumbed; reject the rest rather than silently ignore + // (callers must not assume top_p/top_k/seed are applied). + if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) { + ET_LOG( + Error, + "Qwen35MoESession: only temperature is supported; top_p/top_k/seed " + "are not implemented"); + return Error::NotSupported; + } + ET_CHECK_OR_RETURN_ERROR( + pending_.has_value(), + InvalidState, + "decode_one requires a pending token; call prefill_tokens() first"); + temperature_ = sampling.temperature; + + const uint64_t token = pending_.value(); + const bool is_eos = eos_ids_.find(token) != eos_ids_.end(); + + // Decode the text piece with BPE context (previous token); surface + // tokenizer errors instead of hiding them as empty text. + const uint64_t prev = prev_decode_token_.value_or(token); + auto dec = tokenizer_->decode(prev, token); + if (!dec.ok()) { + ET_LOG( + Error, + "Tokenizers error code %d", + static_cast(dec.error())); + return Error::InvalidArgument; + } + std::string text_piece = std::move(*dec); + + // Stop at EOS WITHOUT forwarding it: like the reference runner, EOS is not + // made resident and position() does not advance. No pending token remains. + if (is_eos) { + pending_.reset(); + return DecodeResult{token, std::move(text_piece), true}; + } + + // Forward `token` at pos_ through the decode method to get the next pending + // token. Update the persistent buffers in place (stable addresses). 
+ decode_token_data_[0] = static_cast(token); + decode_pos_data_[0] = pos_; + std::vector inputs; + inputs.push_back(EValue(decode_tokens_)); + inputs.push_back(EValue(decode_pos_)); +#ifdef EXECUTORCH_BUILD_CUDA + set_temp(temperature_); + inputs.push_back(EValue(temp_tensor_)); +#endif + auto res = module_->execute("decode", inputs); + ET_CHECK_OK_OR_RETURN_ERROR(res.error()); + pending_ = read_sampled_token(res.get()[0].toTensor(), temperature_); + prev_decode_token_ = token; + pos_ += 1; + return DecodeResult{token, std::move(text_piece), false}; + } + + Error seek(int64_t pos) override { + // The hybrid model carries recurrent/conv state that cannot be safely + // rewound by logical position the way contiguous KV can. Fail closed so the + // prefix cache falls back to reset + full prefill (V1). + (void)pos; + return Error::NotSupported; + } + + int64_t position() const override { + return pos_; + } + + Error reset() override { + // Logical reset is sufficient: the model zeroes conv_state/recurrent_state + // whenever prefill runs at input_pos[0]==0 (model.py), and a fresh prefill + // overwrites the KV cache at [0, T). So rewinding to position 0 and + // clearing the pending token gives a clean conversation without a Module + // rebuild. + pos_ = 0; + pending_.reset(); + prev_decode_token_.reset(); + stop_.store(false, std::memory_order_relaxed); + return Error::Ok; + } + + void stop() override { + // Cooperative, token-boundary: the driving loop checks between decode_one() + // calls. A single decode_one() forward is not interruptible. + stop_.store(true, std::memory_order_relaxed); + } + + private: +#ifdef EXECUTORCH_BUILD_CUDA + // Greedy (temperature <= 0) maps to a tiny temperature so the in-graph + // sampler avoids division by zero while staying effectively argmax. + void set_temp(float t) { + temp_val_ = (t <= 0.0f) ? 
1e-6f : t; + } +#endif + + std::unique_ptr module_; + ::tokenizers::Tokenizer* tokenizer_; // non-owning; owned by the engine + std::unordered_map metadata_; + std::unordered_set eos_ids_; + + int64_t pos_ = 0; + std::optional pending_; + std::optional prev_decode_token_; + float temperature_ = -1.0f; + std::atomic stop_{false}; + + // Persistent single-step decode buffers (stable addresses for CUDA graph). + int64_t decode_token_data_[1] = {0}; + int64_t decode_pos_data_[1] = {0}; + TensorPtr decode_tokens_; + TensorPtr decode_pos_; +#ifdef EXECUTORCH_BUILD_CUDA + float temp_val_ = 1e-6f; + TensorPtr temp_tensor_; +#endif +}; + +} // namespace + +Result> Qwen35MoEEngine::create( + const Qwen35MoEConfig& config) { + if (config.model_path.empty() || config.tokenizer_path.empty()) { + ET_LOG( + Error, "Qwen35MoEEngine: model_path and tokenizer_path are required"); + return Error::InvalidArgument; + } + + auto tokenizer = std::make_unique<::tokenizers::HFTokenizer>(); + if (tokenizer->load(config.tokenizer_path) != ::tokenizers::Error::Ok) { + ET_LOG( + Error, + "Qwen35MoEEngine: failed to load tokenizer from %s", + config.tokenizer_path.c_str()); + return Error::InvalidArgument; + } + + // Read metadata + eos from a lightweight Module (program + tiny metadata + // methods only; the heavy prefill/decode weights are NOT loaded here). 
+ std::vector data_files; + if (!config.data_path.empty()) { + data_files.push_back(config.data_path); + } + auto meta_module = std::make_unique( + config.model_path, data_files, Module::LoadMode::File); + auto metadata_result = get_llm_metadata(tokenizer.get(), meta_module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Qwen35MoEEngine: failed to read metadata"); + return metadata_result.error(); + } + auto eos_ids = get_eos_ids(tokenizer.get(), meta_module.get()); + + return std::unique_ptr(new Qwen35MoEEngine( + config, std::move(tokenizer), metadata_result.get(), std::move(eos_ids))); +} + +Result> Qwen35MoEEngine::create_session() { + auto module = build_qwen_module(config_); + ET_CHECK_OK_OR_RETURN_ERROR(module.error()); + return std::unique_ptr(new Qwen35MoESession( + std::move(module.get()), tokenizer_.get(), metadata_, eos_ids_)); +} + +} // namespace executorch::extension::llm diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.h b/examples/models/qwen3_5_moe/qwen35_moe_engine.h new file mode 100644 index 00000000000..9fb9e99d71e --- /dev/null +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Engine/Session adapter for the Qwen3.5 MoE model, implementing the +// model-agnostic LLMEngine/LLMSession serving contract (llm_session.h) over the +// model's exported prefill/decode methods. +// +// The public surface is backend-agnostic: the server receives an LLMEngine and +// never branches on CUDA vs MLX. Backend-specific execution (CUDA in-graph +// sampling, weight-sharing/cuda-graph backend options, device sync) is isolated +// behind EXECUTORCH_BUILD_CUDA inside the .cpp; those isolated points are where +// an MLX runtime would slot in. MLX is NOT implemented or validated here. 
+// +// V1: serving_capacity() reports a single physical session (one Module = one +// weight allocation). Multiple weight-sharing sessions are a measured V2 step. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace executorch::extension::llm { + +/// Immutable configuration for a Qwen3.5 MoE engine. +struct Qwen35MoEConfig { + std::string model_path; // .pte + std::string data_path; // .ptd (CUDA delegate blob); empty if none + std::string tokenizer_path; // HuggingFace tokenizer.json + bool cuda_graph = false; // enable CUDA graph capture for the decode method +}; + +/// Engine over one loaded Qwen3.5 MoE Program. Owns immutable resources +/// (tokenizer, metadata, eos ids, config) and creates sessions that each own a +/// physical Module with its own KV/recurrent/conv state. +class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { + public: + static ::executorch::runtime::Result> create( + const Qwen35MoEConfig& config); + + ::executorch::runtime::Result> create_session() + override; + + // V1: one physical session; weight sharing across sessions is unproven, so we + // fail closed to 1 (the server queues concurrent requests on the resident + // session rather than duplicating ~18GB of weights). + LLMServingCapacity serving_capacity() const override { + return LLMServingCapacity{}; + } + + const std::unordered_map& metadata() const override { + return metadata_; + } + + // Non-owning; valid for the engine's lifetime (the engine must outlive any + // session and any caller using this). Used by the runner to encode prompts; + // not part of the model-agnostic LLMEngine surface the server depends on. 
+ ::tokenizers::Tokenizer* tokenizer() const { + return tokenizer_.get(); + } + + Qwen35MoEEngine(const Qwen35MoEEngine&) = delete; + Qwen35MoEEngine& operator=(const Qwen35MoEEngine&) = delete; + + private: + Qwen35MoEEngine( + Qwen35MoEConfig config, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + std::unordered_map metadata, + std::unordered_set eos_ids) + : config_(std::move(config)), + tokenizer_(std::move(tokenizer)), + metadata_(std::move(metadata)), + eos_ids_(std::move(eos_ids)) {} + + Qwen35MoEConfig config_; + std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; + std::unordered_map metadata_; + std::unordered_set eos_ids_; +}; + +} // namespace executorch::extension::llm diff --git a/examples/models/qwen3_5_moe/qwen35_moe_pybindings.cpp b/examples/models/qwen3_5_moe/qwen35_moe_pybindings.cpp new file mode 100644 index 00000000000..5f68c777110 --- /dev/null +++ b/examples/models/qwen3_5_moe/qwen35_moe_pybindings.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Example-local serving module for Qwen3.5 MoE. It does NOT define any +// Qwen-specific Python class: it constructs a Qwen35MoEEngine and hands it to +// the generic PyLLMEngine wrapper (llm_pybind_wrappers.h), so the Python +// surface is the same generic LLMEngine / LLMSession the text model exposes. +// The generic extension/llm/runner pybind is untouched. 
+ +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace py = pybind11; +using namespace executorch::extension::llm; + +PYBIND11_MODULE(_qwen35_moe, m) { + m.doc() = + "Example-local Qwen3.5 MoE serving module: create_engine() builds a " + "Qwen35MoEEngine and returns the generic LLMEngine / LLMSession surface."; + + ::executorch::runtime::runtime_init(); + + // Same generic Engine/Session surface as the _llm_runner module. + pybind_wrappers::bind_engine_session_api(m); + + m.def( + "create_engine", + [](const std::string& model_path, + const std::string& tokenizer_path, + std::optional data_path, + bool cuda_graph) { + Qwen35MoEConfig config; + config.model_path = model_path; + config.tokenizer_path = tokenizer_path; + config.data_path = data_path.value_or(""); + config.cuda_graph = cuda_graph; + auto res = Qwen35MoEEngine::create(config); + if (!res.ok()) { + throw std::runtime_error("Failed to create Qwen35MoEEngine"); + } + return std::make_unique( + std::unique_ptr(std::move(res.get()))); + }, + py::arg("model_path"), + py::arg("tokenizer_path"), + py::arg("data_path") = py::none(), + py::arg("cuda_graph") = false, + "Load the Qwen3.5 MoE program once and return an LLMEngine."); +} diff --git a/examples/models/qwen3_5_moe/serve.py b/examples/models/qwen3_5_moe/serve.py new file mode 100644 index 00000000000..15f8df51fb5 --- /dev/null +++ b/examples/models/qwen3_5_moe/serve.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-compatible HTTP server for Qwen3.5 MoE (process-isolated). + +This is the CONTROL PLANE only: FastAPI/uvicorn + OpenAI protocol, chat +templating, tool parsing, request validation. It runs NO CUDA model code. 
Model +execution lives in a separate worker subprocess (worker.py) that this process +talks to over JSONL on stdin/stdout. + +Why two processes: executing the AOTI CUDA model inside a live asyncio server +process segfaults in the int4 matmul (validated by elimination — not thread +affinity, GIL, signals, or executor offload; the trigger is CUDA execution while +a live asyncio loop is resident). Isolating CUDA in a plain (no-asyncio) worker +process is the reliable shape, and it still loads weights once. + +V1 constraints: + * serving_capacity == 1: one worker, one session; concurrent HTTP requests + queue (RunnerPool num_runners=1). + * prefix cache off (Qwen seek() is NotSupported). + * The control plane only does blocking pipe I/O on its executor thread (no + CUDA), which is safe under asyncio. + +Launch (LD_LIBRARY_PATH shim is forwarded to the worker for the CUDA blob): + + LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH \\ + python -m executorch.examples.models.qwen3_5_moe.serve \\ + --model-path qwen35_moe_exports/model.pte \\ + --data-path qwen35_moe_exports/aoti_cuda_blob.ptd \\ + --tokenizer-path ~/models/Qwen3.5-35B-A3B/tokenizer.json \\ + --hf-tokenizer ~/models/Qwen3.5-35B-A3B \\ + --model-id qwen3.5-moe --no-think +""" + +import argparse +import json +import logging +import os +import subprocess +import sys +import threading + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.runner_pool import RunnerPool +from executorch.extension.llm.server.python.serving_chat import ServingChat +from executorch.extension.llm.server.python.tool_parsers import QwenFunctionCallDetector + +logger = logging.getLogger(__name__) + +_CAPACITY = {"max_physical_sessions_without_weight_duplication": 1} + + +class _WorkerStats: + __slots__ = ("num_prompt_tokens", "num_generated_tokens") + + def __init__(self, prompt_tokens: int, generated_tokens: int): + self.num_prompt_tokens = prompt_tokens + 
self.num_generated_tokens = generated_tokens + + +class WorkerRunner: + """Drives the model worker subprocess over JSONL, exposing the RunnerPool + generate() surface. One worker = one session; calls are serialized by a lock + (and by RunnerPool's single slot). The control plane never executes CUDA.""" + + def __init__(self, proc: subprocess.Popen): + self._proc = proc + self._lock = threading.Lock() + + def reset(self) -> None: + # The worker resets its session per request; nothing to do here. + pass + + def stop(self) -> None: + # Best-effort only: the worker request is synchronous and not + # interruptible mid-generation in V1. + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + request = { + "prompt": prompt, + "max_new_tokens": getattr(config, "max_new_tokens", -1), + "temperature": getattr(config, "temperature", 0.0), + } + with self._lock: + if self._proc.poll() is not None: + raise RuntimeError( + f"Qwen worker exited (code {self._proc.returncode}); restart the server" + ) + try: + self._proc.stdin.write(json.dumps(request) + "\n") + self._proc.stdin.flush() + except (BrokenPipeError, ValueError) as e: + raise RuntimeError("Qwen worker stdin is closed") from e + + while True: + line = self._proc.stdout.readline() + if not line: + raise RuntimeError("Qwen worker exited mid-request") + msg = json.loads(line) + if "token" in msg: + if token_callback is not None: + token_callback(msg["token"]) + elif msg.get("done"): + if stats_callback is not None: + stats_callback( + _WorkerStats( + msg.get("prompt_tokens", 0), + msg.get("completion_tokens", 0), + ) + ) + return + elif "error" in msg: + raise RuntimeError(f"Qwen worker error: {msg['error']}") + + +def _spawn_worker(args) -> subprocess.Popen: + """Start the model worker subprocess and block until it reports ready.""" + env = dict(os.environ) + conda = os.environ.get("CONDA_PREFIX") + if conda: + # The AOTI CUDA blob needs the conda libstdc++; forward it to the worker. 
+ env["LD_LIBRARY_PATH"] = (f"{conda}/lib:" + env.get("LD_LIBRARY_PATH", "")).rstrip(":") + cmd = [ + sys.executable, + "-m", + "executorch.examples.models.qwen3_5_moe.worker", + "--model-path", + args.model_path, + "--tokenizer-path", + args.tokenizer_path, + "--hf-tokenizer", + args.hf_tokenizer, + ] + if args.data_path: + cmd += ["--data-path", args.data_path] + if args.ext_dir: + cmd += ["--ext-dir", args.ext_dir] + + logger.info("Starting Qwen worker subprocess (loads the model once)...") + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + bufsize=1, + env=env, + ) + line = proc.stdout.readline() + if not line: + raise SystemExit("Qwen worker failed to start (no output; check stderr).") + msg = json.loads(line) + if not msg.get("ready"): + raise SystemExit(f"Qwen worker did not report ready: {msg}") + logger.info("Qwen worker ready; serving single-slot, concurrent requests queue.") + return proc + + +def build_app_from_args(args): + """Construct the FastAPI app + the model worker. Returns (app, model_id).""" + default_template_kwargs = {"enable_thinking": False} if args.no_think else None + template = ChatTemplate( + args.hf_tokenizer, default_template_kwargs=default_template_kwargs + ) + + proc = _spawn_worker(args) + worker_runner = WorkerRunner(proc) + + # tokenizer=None -> prefix cache disabled (Qwen seek() is NotSupported). + # serving_capacity passed so the factory path is clamped to 1. + pool = RunnerPool( + runner_factory=lambda: worker_runner, + num_runners=1, + tokenizer=None, + serving_capacity=_CAPACITY, + ) + serving = ServingChat( + pool, + template, + args.model_id, + max_context=args.max_context, + # Qwen3.5-MoE emits the XML tool format.
+ tool_detector_cls=QwenFunctionCallDetector, + ) + + from executorch.extension.llm.server.python.server import build_app + + app = build_app(serving, args.model_id) + + @app.on_event("shutdown") + def _stop_worker(): + if proc.poll() is None: + proc.terminate() + + return app, args.model_id + + +def main() -> None: + p = argparse.ArgumentParser( + description="OpenAI-compatible LLM server for Qwen3.5 MoE (process-isolated, V1)" + ) + p.add_argument("--model-path", required=True, help="Path to the .pte model") + p.add_argument( + "--data-path", default=None, help="Path to the .ptd CUDA delegate blob" + ) + p.add_argument( + "--tokenizer-path", required=True, help="Path to the HuggingFace tokenizer.json" + ) + p.add_argument( + "--hf-tokenizer", + required=True, + help="HF tokenizer id/dir for the model's chat template + encoding", + ) + p.add_argument("--model-id", default="qwen3.5-moe") + p.add_argument("--host", default="127.0.0.1") + p.add_argument("--port", type=int, default=8000) + p.add_argument( + "--max-context", + type=int, + default=None, + help="Context window; prompts exceeding it are rejected with 400.", + ) + p.add_argument( + "--no-think", + action="store_true", + help="Default reasoning off (enable_thinking=False).", + ) + p.add_argument( + "--num-runners", + type=int, + default=1, + help="V1 supports 1 only (serving_capacity=1).", + ) + p.add_argument( + "--ext-dir", + default=None, + help="Directory with the built _qwen35_moe module (for the worker).", + ) + args = p.parse_args() + logging.basicConfig(level=logging.INFO) + + if args.num_runners != 1: + p.error( + "Qwen3.5 MoE V1 is single-slot: serving_capacity=1. One worker serves " + "one session; concurrent requests queue." 
+ ) + + app, _ = build_app_from_args(args) + + import uvicorn + + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3_5_moe/test_serve.py b/examples/models/qwen3_5_moe/test_serve.py new file mode 100644 index 00000000000..610f4f45a52 --- /dev/null +++ b/examples/models/qwen3_5_moe/test_serve.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the Qwen3.5 MoE process-isolated OpenAI launcher (serve.py). + +Hermetic: no model, GPU, or worker subprocess. Covers layering (Qwen stays an +example; the control plane runs no CUDA), the single-slot CLI guard, and the +WorkerRunner JSONL protocol against a fake worker. The live HTTP smoke test is +documented in README.md and run on a CUDA box. +""" + +import json +import pathlib + +import pytest + +from executorch.examples.models.qwen3_5_moe import serve + +_HERE = pathlib.Path(serve.__file__).resolve().parent +_REPO_ROOT = _HERE.parents[2] # qwen3_5_moe -> models -> examples -> repo root + + +# --- Layering --------------------------------------------------------------- + + +def test_generic_runner_pybind_has_no_qwen_include(): + src = (_REPO_ROOT / "extension/llm/runner/pybindings.cpp").read_text() + assert "qwen3_5_moe" not in src and "qwen35_moe" not in src + + +def test_generic_server_does_not_import_qwen(): + server_dir = _REPO_ROOT / "extension/llm/server" + offenders = [ + p + for p in server_dir.rglob("*.py") + if "qwen3_5_moe" in p.read_text() or "_qwen35_moe" in p.read_text() + ] + assert offenders == [], f"generic server must not reference Qwen: {offenders}" + + +def test_control_plane_runs_no_cuda_model(): + # serve.py is the control plane: it must NOT construct the CUDA engine; only + # the worker (worker.py) calls create_engine on the model 
module. + assert "create_engine" not in (_HERE / "serve.py").read_text() + assert "create_engine" in (_HERE / "worker.py").read_text() + + +# --- WorkerRunner JSONL protocol (fake worker) ------------------------------ + + +class _FakeStdin: + def __init__(self): + self.written = [] + + def write(self, s): + self.written.append(s) + + def flush(self): + pass + + +class _FakeStdout: + def __init__(self, lines): + self._lines = list(lines) + + def readline(self): + return self._lines.pop(0) if self._lines else "" + + +class _FakeProc: + def __init__(self, lines): + self.stdin = _FakeStdin() + self.stdout = _FakeStdout(lines) + self.returncode = None + + def poll(self): + return None + + +class _Cfg: + __slots__ = ("max_new_tokens", "temperature") + + def __init__(self, max_new_tokens=16, temperature=0.0): + self.max_new_tokens = max_new_tokens + self.temperature = temperature + + +def test_worker_runner_streams_tokens_and_stats(): + proc = _FakeProc( + [ + '{"token": "Hello"}\n', + '{"token": " world"}\n', + '{"done": true, "prompt_tokens": 5, "completion_tokens": 2}\n', + ] + ) + wr = serve.WorkerRunner(proc) + out, stats = [], {} + wr.generate( + "p", + _Cfg(temperature=0.7), + token_callback=out.append, + stats_callback=lambda s: stats.update( + p=s.num_prompt_tokens, g=s.num_generated_tokens + ), + ) + assert out == ["Hello", " world"] + assert stats == {"p": 5, "g": 2} + sent = json.loads(proc.stdin.written[0]) + assert sent["prompt"] == "p" and sent["temperature"] == 0.7 + + +def test_worker_runner_error_raises(): + proc = _FakeProc(['{"error": "boom"}\n']) + with pytest.raises(RuntimeError, match="boom"): + serve.WorkerRunner(proc).generate("p", _Cfg(), token_callback=lambda t: None) + + +def test_worker_runner_exit_midrequest_raises(): + proc = _FakeProc([]) # readline() -> "" means the worker exited + with pytest.raises(RuntimeError, match="exited"): + serve.WorkerRunner(proc).generate("p", _Cfg()) + + +# --- CLI guard 
-------------------------------------------------------------- + + +def test_rejects_multiple_runners(monkeypatch): + import sys + + monkeypatch.setattr( + sys, + "argv", + [ + "serve.py", + "--model-path", + "m.pte", + "--tokenizer-path", + "t.json", + "--hf-tokenizer", + "hf", + "--num-runners", + "2", + ], + ) + with pytest.raises(SystemExit): + serve.main() diff --git a/examples/models/qwen3_5_moe/worker.py b/examples/models/qwen3_5_moe/worker.py new file mode 100644 index 00000000000..21f802c3bf7 --- /dev/null +++ b/examples/models/qwen3_5_moe/worker.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Process-isolated Qwen3.5 MoE generation worker. + +Runs the CUDA/AOTI model in a dedicated process with NO asyncio HTTP server. +The OpenAI control plane (serve.py) talks to this worker over JSONL on +stdin/stdout. This isolation is required: executing the AOTI CUDA model inside a +live asyncio server process segfaults in the int4 matmul (validated). Here the +model runs like the CLI — a plain synchronous loop — which is reliable. + +Protocol (one JSON object per line): + worker -> stdout, once at startup: {"ready": true} + serve -> stdin, per request: {"prompt": str, "max_new_tokens": int, + "temperature": float} + worker -> stdout, per request: {"token": str} * (streamed) + {"done": true, "prompt_tokens": int, + "completion_tokens": int} + or {"error": str} + +stdout carries ONLY protocol JSON; all logs go to stderr. One request at a time. 
+""" + +import argparse +import importlib +import json +import sys +from pathlib import Path + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.session_generate import ( + SessionGenerateAdapter, +) + + +def _load_ext(explicit_dir): + candidates = [] + if explicit_dir: + candidates.append(Path(explicit_dir)) + repo_root = Path(__file__).resolve().parents[3] + candidates.append(repo_root / "cmake-out" / "examples" / "models" / "qwen3_5_moe") + for d in candidates: + if d.is_dir() and str(d) not in sys.path: + sys.path.insert(0, str(d)) + return importlib.import_module("_qwen35_moe") + + +class _Config: + __slots__ = ("max_new_tokens", "temperature") + + def __init__(self, max_new_tokens, temperature): + self.max_new_tokens = max_new_tokens + self.temperature = temperature + + +def _emit(obj): + sys.stdout.write(json.dumps(obj)) + sys.stdout.write("\n") + sys.stdout.flush() + + +def main() -> None: + p = argparse.ArgumentParser(description="Qwen3.5 MoE generation worker") + p.add_argument("--model-path", required=True) + p.add_argument("--data-path", default=None) + p.add_argument("--tokenizer-path", required=True) + p.add_argument("--hf-tokenizer", required=True) + p.add_argument("--ext-dir", default=None) + args = p.parse_args() + + ext = _load_ext(args.ext_dir) + # HF tokenizer for prompt encoding (the model's own template tokenizer). 
+ hf_tokenizer = ChatTemplate(args.hf_tokenizer).tokenizer() + engine = ext.create_engine( + model_path=args.model_path, + tokenizer_path=args.tokenizer_path, + data_path=args.data_path, + cuda_graph=False, + ) + adapter = SessionGenerateAdapter(engine.create_session(), hf_tokenizer) + + _emit({"ready": True}) + + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + _handle_request(adapter, json.loads(line)) + except Exception as e: # noqa: BLE001 - report to the control plane + _emit({"error": repr(e)}) + + +def _handle_request(adapter, req) -> None: + """Run one generation request and stream the JSONL result. In a function (not + the read loop) so the callbacks don't close over loop variables.""" + config = _Config( + int(req.get("max_new_tokens") or -1), + float(req.get("temperature", 0.0)), + ) + stats = {"prompt": 0, "gen": 0} + + def stats_cb(s): + stats["prompt"] = s.num_prompt_tokens + stats["gen"] = s.num_generated_tokens + + adapter.reset() + adapter.generate( + req["prompt"], + config, + token_callback=lambda t: _emit({"token": t}), + stats_callback=stats_cb, + ) + _emit( + { + "done": True, + "prompt_tokens": stats["prompt"], + "completion_tokens": stats["gen"], + } + ) + + +if __name__ == "__main__": + main() From cf4f9955eca48ca2c64a7f859c7bb2b48c3d8937 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 4 Jun 2026 15:24:43 -0700 Subject: [PATCH 2/2] [UPDATE] Update [ghstack-poisoned] --- .../models/qwen3_5_moe/qwen35_moe_engine.cpp | 39 +++++++++++++++---- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp index 5778d193c4f..2ebc62b40e4 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp @@ -43,7 +43,7 @@ namespace { // session logic below stays backend-agnostic. 
// --------------------------------------------------------------------------- -uint64_t read_sampled_token( +Result read_sampled_token( const executorch::aten::Tensor& output, float temperature) { #ifdef EXECUTORCH_BUILD_CUDA @@ -56,8 +56,10 @@ uint64_t read_sampled_token( if (on_device) { if (cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost) != cudaSuccess) { + // Don't fabricate token id 0 (a valid token) on a copy failure — that is + // silent corruption. Surface it so the caller aborts the request. ET_LOG(Error, "read_sampled_token: cudaMemcpy D2H failed"); - return 0; + return Error::Internal; } } else { std::memcpy(&val, ptr, sizeof(float)); @@ -164,11 +166,15 @@ class Qwen35MoESession : public LLMSession { } const int64_t T = static_cast(tokens.size()); const auto ctx_it = metadata_.find(kMaxContextLen); - if (ctx_it != metadata_.end() && pos_ + T > ctx_it->second) { + // Require room for at least one generated token: after prefill, pos_ == T + // and decode_one() forwards the first token at pos_, which must be < the + // context length. Rejecting pos_ + T == max_context (not just > it) keeps a + // full prompt from reaching decode_one with no room to step. 
+ if (ctx_it != metadata_.end() && pos_ + T >= ctx_it->second) { ET_LOG( Error, - "prefill_tokens would exceed context capacity (pos %" PRId64 - " + %" PRId64 " > %" PRId64 ")", + "prefill_tokens would leave no room to generate (pos %" PRId64 + " + %" PRId64 " >= max_context %" PRId64 ")", pos_, T, ctx_it->second); @@ -202,7 +208,9 @@ class Qwen35MoESession : public LLMSession { #endif auto res = module_->execute(method, inputs); ET_CHECK_OK_OR_RETURN_ERROR(res.error()); - pending_ = read_sampled_token(res.get()[0].toTensor(), first_token_temp); + auto sampled = read_sampled_token(res.get()[0].toTensor(), first_token_temp); + ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); + pending_ = sampled.get(); prev_decode_token_.reset(); pos_ += T; // the prompt tokens are now resident in KV/state #ifdef EXECUTORCH_BUILD_CUDA @@ -255,6 +263,21 @@ class Qwen35MoESession : public LLMSession { token, std::move(text_piece), is_eos, /*is_terminal=*/true}; } + // Only a NON-EOS, non-stopped token is forwarded (made resident at pos_), so + // the capacity check belongs here — after the short-circuit, so a final EOS + // is still emitted when state is exactly full. Without it, decode would + // write KV/recurrent state past the context window. + const auto ctx_it = metadata_.find(kMaxContextLen); + if (ctx_it != metadata_.end()) { + ET_CHECK_OR_RETURN_ERROR( + pos_ < ctx_it->second, + InvalidArgument, + "decode_one would exceed context capacity: pos_ %" PRId64 + " >= max_context %" PRId64, + pos_, + ctx_it->second); + } + // Forward `token` at pos_ through the decode method to get the next pending // token. Update the persistent buffers in place (stable addresses). 
decode_token_data_[0] = static_cast(token); @@ -268,7 +291,9 @@ class Qwen35MoESession : public LLMSession { #endif auto res = module_->execute("decode", inputs); ET_CHECK_OK_OR_RETURN_ERROR(res.error()); - pending_ = read_sampled_token(res.get()[0].toTensor(), temperature_); + auto sampled = read_sampled_token(res.get()[0].toTensor(), temperature_); + ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); + pending_ = sampled.get(); prev_decode_token_ = token; pos_ += 1; return DecodeResult{