diff --git a/docs/plan_chat_template_input_workarounds.md b/docs/plan_chat_template_input_workarounds.md new file mode 100644 index 0000000000..8476d4bd01 --- /dev/null +++ b/docs/plan_chat_template_input_workarounds.md @@ -0,0 +1,351 @@ +# Plan: Chat Template Input Workarounds & Auto-Detection + +## Problem Statement + +OVMS currently requires manual configuration of `tool_parser` and `reasoning_parser` via the MediaPipe graph proto. There is no: +- Automatic detection of model/template capabilities +- Input transformation before template application (e.g. Gemma requiring object arguments) +- Auto-detection of which tool/reasoning parser to use based on template content + +Both **llama.cpp** and **minja** solve these problems via "dry run" probing and model-specific workarounds. This plan adapts those techniques for OVMS. + +--- + +## Background: How llama.cpp & minja Do It + +### llama.cpp approach +1. **Needle-based dry runs** — renders template with probe data, tracks which fields are accessed (via `stats.used`) +2. **String pattern matching** — searches template source for unique markers (`<|tool_call>call:'`, `[TOOL_CALLS]`, `<|channel|>`, etc.) to identify model family +3. **Workarounds applied pre-render** — `func_args_not_string`, `requires_non_null_content`, `system_message_not_supported`, `map_developer_role_to_system`, `convert_tool_responses_gemma4` +4. **Autoparser** — differential analysis (render twice, diff output) to detect reasoning/tool-call format + +### minja approach +1. **`try_raw_render` + needles** — renders template with sentinel strings, checks if they appear in output +2. **``** — detects coder-style XML parameter templates (Qwen3-Coder) +3. **Capability struct** — `chat_template_caps` populated at construction time: `supports_tools`, `requires_object_arguments`, `requires_typed_content`, etc. +4. **Polyfills** — automatic fallbacks when template lacks native support (inject tool definitions into system prompt, merge system into user, etc.) + +--- + +## Current OVMS Architecture (relevant parts) + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ Request Flow │ +│ │ +│ loadRequest → parseRequest → prepareInputs → scheduleExecution │ +│ │ │ │ +│ parseMessages() applyChatTemplate() │ +│ parseTools() (Jinja or GenAI) │ +│ │ │ │ +│ ensureArguments ─────┼──── NO workarounds today │ +│ InToolCalls() │ │ +│ ▼ │ +│ GenerationConfigBuilder │ +│ OutputParser (tool + reasoning) │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +| Servable Type | Template Applicator | Notes | +|---|---|---| +| LM (legacy) | Python Jinja (primary) / GenAI C++ (fallback) | | +| LM_CB | Python Jinja (primary) / GenAI C++ (fallback) | Main production path | +| VLM (legacy) | GenAI C++ only | | +| VLM_CB | GenAI C++ only | | + +Key files: +- `src/llm/servable.cpp` — `prepareInputs()` calls template applicator +- `src/llm/py_jinja_template_processor.cpp` — Python Jinja path +- `src/llm/io_processing/output_parser.hpp` — tool/reasoning parsing +- `src/llm/io_processing/generation_config_builder.hpp` — stop strings & guided generation +- `src/llm/language_model/continuous_batching/servable_initializer.cpp` — reads `tool_parser`/`reasoning_parser` from proto + +--- + +## Proposed Design + +### New Component: `ChatTemplateAnalyzer` + +A singleton-per-servable object created at initialization time that: +1. Reads the chat template source +2. Detects template capabilities (what the template supports/requires) +3. Determines which tool/reasoning parser matches the template +4. Provides workaround functions to transform inputs before template application + +``` +┌───────────────────────────────────────────────────────────────────────────┐ +│ Initialization (servable_initializer) │ +│ │ +│ Load chat template source ──► ChatTemplateAnalyzer │ +│ │ │ +│ ├── detectCaps() (dry-run probing) │ +│ ├── detectParsers() (pattern matching) │ +│ └── store in GenAiServableProperties │ +└───────────────────────────────────────────────────────────────────────────┘ + +┌───────────────────────────────────────────────────────────────────────────┐ +│ Request time (prepareInputs) │ +│ │ +│ messages/tools ──► InputWorkarounds::apply(caps, messages, tools) ───► │ +│ │ │ +│ ├── func_args_to_object() │ +│ ├── ensure_non_null_content() │ +│ ├── convert_tool_responses_gemma4() │ +│ ├── convert_typed_content() │ +│ └── (future workarounds) │ +│ │ +│ ──► applyChatTemplate(modified messages) │ +└───────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Detailed Design + +### Phase 1: Template Capability Detection (`ChatTemplateCaps`) + +#### 1.1 Data Structure + +```cpp +// src/llm/chat_template_caps.hpp + +struct ChatTemplateCaps { + bool supports_system_role = true; + bool supports_tools = false; + bool supports_tool_calls = false; + bool supports_tool_responses = false; + bool requires_object_arguments = false; // Gemma: args as dict not string + bool requires_non_null_content = false; // tool_call messages need content="" + bool requires_typed_content = false; // content must be [{type:"text",...}] + bool supports_parallel_tool_calls = false; + bool supports_tool_call_id = false; +}; +``` + +#### 1.2 Detection Strategy (two-tier) + +**Tier 1: Pattern matching on template source** (fast, no execution needed) +- Search for known unique strings to identify model family +- Maps directly to a known `ChatTemplateCaps` preset + parser name + +| Pattern in template source | Model family | Tool parser | Reasoning parser | +|---|---|---|---| +| `<\|python_tag\|>` | Llama 3.x | `llama3` | — | +| `` + `` | Hermes/Qwen | `hermes3` | — | +| `<\|tool_call\|>` + no `` | Mistral | `mistral` | — | +| `<\|tool_call\|>` + `function_calls` | DeepSeek | (future) | — | +| `<\|channel\|>` | GPT-OSS | `gptoss` | `gptoss` | +| `'<\|tool_call>call:'` | Gemma 4 | `gemma4` | `gemma4` | +| `` | Phi-4 | `phi4` | — | +| `content.split('')` OR `` in generation prompt | Qwen3/DeepSeek-R1-distill | — | `qwen3` | +| `` + `[TOOL_RESULTS]` | Devstral | `devstral` | — | +| `<\|assistant_tool_call\|>` | LFM2 | `lfm2` | — | + +**Tier 2: Dry-run probing** (when pattern matching is inconclusive) +- Only for the Python Jinja path (we control execution) +- Render template with needle-containing test messages +- Check output for presence/absence of needles + +> **Decision**: For initial implementation, Tier 1 (pattern matching) covers all currently supported parsers. Tier 2 can be added later for unknown templates. + +#### 1.3 Integration Points + +- **Initialization**: `ChatTemplateAnalyzer` runs in `servable_initializer` after loading the template source +- **Storage**: Results stored in `GenAiServableProperties` as `ChatTemplateCaps caps` + auto-detected `toolParserName`/`reasoningParserName` +- **Override**: If user explicitly sets `tool_parser`/`reasoning_parser` in proto, those take precedence over auto-detection + +--- + +### Phase 2: Input Workarounds (`InputWorkarounds`) + +#### 2.1 Workaround Functions + +```cpp +// src/llm/input_workarounds.hpp + +namespace ovms { +namespace input_workarounds { + +// Convert tool_call arguments from JSON string to parsed object +// Triggered by: caps.requires_object_arguments +void funcArgsToObject(rapidjson::Document& doc); + +// Ensure tool_call messages have non-null content field +// Triggered by: caps.requires_non_null_content +void ensureNonNullContent(rapidjson::Document& doc); + +// Restructure tool response messages for Gemma4 format +// Triggered by: detected model == gemma4 +void convertToolResponsesGemma4(rapidjson::Document& doc); + +// Convert string content to typed content array [{type:"text", text:"..."}] +// Triggered by: caps.requires_typed_content +void convertToTypedContent(rapidjson::Document& doc); + +// Apply all relevant workarounds based on caps +void applyAll(const ChatTemplateCaps& caps, const std::string& modelFamily, + rapidjson::Document& doc); + +} // namespace input_workarounds +} // namespace ovms +``` + +#### 2.2 Where Workarounds Are Applied + +Two integration points (mirroring the two applicator paths): + +**Python Jinja path** (`PyJinjaTemplateProcessor::applyChatTemplate`): +- Workarounds modify the `processedJson` / `requestBody` JSON **before** it's passed to the Python template renderer +- The JSON already contains `messages` and `tools` arrays + +**GenAI C++ path** (`tokenizer.apply_chat_template`): +- Workarounds modify the `ChatHistory` and/or re-serialize tool_calls arguments **before** calling the tokenizer +- May need a helper to serialize/deserialize between `ChatHistory` and a mutable JSON representation + +#### 2.3 Call Site + +In `GenAiServable::prepareInputs()` (and VLM variants), **after** `parseRequest()` but **before** template application: + +```cpp +// After parseRequest populates chatHistory/processedJson: +auto& caps = getProperties()->chatTemplateCaps; +auto& modelFamily = getProperties()->detectedModelFamily; + +#if (PYTHON_DISABLE == 0) + // Modify the JSON document that will be passed to Python Jinja + input_workarounds::applyAll(caps, modelFamily, executionContext->apiHandler->getMutableProcessedJson()); +#else + // Modify ChatHistory in-place for GenAI path + input_workarounds::applyAllToHistory(caps, modelFamily, executionContext->apiHandler->getChatHistory()); +#endif +``` + +--- + +### Phase 3: Auto-Detection of Tool/Reasoning Parsers + +#### 3.1 Goal + +Eliminate the need for users to manually specify `tool_parser` and `reasoning_parser` in the graph proto for common models. + +#### 3.2 Implementation + +`ChatTemplateAnalyzer::detectParsers()` returns: +- `std::optional detectedToolParser` +- `std::optional detectedReasoningParser` + +In `servable_initializer`: +```cpp +if (!nodeOptions.has_tool_parser()) { + // Auto-detect from template + auto detected = analyzer.detectParsers(templateSource); + if (detected.toolParser.has_value()) { + properties->toolParserName = detected.toolParser.value(); + SPDLOG_LOGGER_INFO(logger, "Auto-detected tool_parser: {}", properties->toolParserName); + } +} +// Same for reasoning_parser +``` + +#### 3.3 Logging & Transparency + +- Log at INFO level when auto-detection fires and what was detected +- Log at WARNING when auto-detection fails (unknown template) and no parser was configured +- Model card / graph node description should still document the manual override + +--- + +### Phase 4: Future — Advanced Probing (Tier 2 Dry-Run) + +For templates that don't match any known pattern: + +1. **Python path**: Execute the template with needle messages via `PyJinjaTemplateProcessor`, check output +2. **GenAI path**: Use `tokenizer.apply_chat_template()` with probe `ChatHistory`, check output +3. Populate `ChatTemplateCaps` from the probe results +4. Optionally: generate tool-call example by differential rendering (like minja) + +This is deferred because: +- All currently supported models have recognizable template patterns +- Dry-run adds initialization latency +- GenAI C++ tokenizer doesn't expose field-access tracking (unlike minja's stat-based approach) + +--- + +## File Structure + +``` +src/llm/ +├── chat_template_caps.hpp # ChatTemplateCaps struct +├── chat_template_analyzer.hpp # ChatTemplateAnalyzer class declaration +├── chat_template_analyzer.cpp # Pattern matching + detection logic +├── input_workarounds.hpp # Input transformation functions +├── input_workarounds.cpp # Implementations +└── io_processing/ + └── (existing parsers unchanged) +``` + +--- + +## Implementation Order + +| Step | Description | Effort | Dependencies | +|------|-------------|--------|--------------| +| 1 | Define `ChatTemplateCaps` struct | S | None | +| 2 | Implement `ChatTemplateAnalyzer` with Tier 1 pattern matching | M | Step 1 | +| 3 | Integrate auto-detection into all 4 servable initializers | M | Step 2 | +| 4 | Implement `input_workarounds::funcArgsToObject` (Gemma case) | S | Step 1 | +| 5 | Implement `input_workarounds::ensureNonNullContent` | S | Step 1 | +| 6 | Wire workarounds into `prepareInputs()` for both Jinja & GenAI paths | M | Steps 4-5 | +| 7 | Add unit tests for analyzer + workarounds | M | Steps 2-6 | +| 8 | (Future) Tier 2 dry-run probing | L | Step 6 | +| 9 | (Future) Auto-generate tool-call example for unsupported templates | L | Step 8 | + +--- + +## Interaction with Existing Code + +### What changes + +| Component | Change | +|---|---| +| `GenAiServableProperties` | Add `ChatTemplateCaps caps`, `std::string detectedModelFamily` | +| `servable_initializer` (all 4 variants) | Call `ChatTemplateAnalyzer` after loading template; use auto-detected parser if none configured | +| `GenAiServable::prepareInputs()` | Call `input_workarounds::applyAll()` before template application | +| `VLM servable::prepareInputs()` | Same workaround call | +| `PyJinjaTemplateProcessor::applyChatTemplate()` | Receives already-transformed JSON (no change to Python code) | + +### What does NOT change + +- Output parsers (`OutputParser`, all model-specific parsers) +- `GenerationConfigBuilder` (stop strings, guided generation) +- Python Jinja template rendering logic +- GenAI tokenizer API usage +- Existing proto fields (`tool_parser`, `reasoning_parser`) — they become optional overrides + +--- + +## Risks & Mitigations + +| Risk | Mitigation | +|---|---| +| Pattern matching gives false positive for custom/fine-tuned templates | Manual proto override always takes precedence; log detection result | +| Template source not available at initialization (e.g. embedded in tokenizer binary) | GenAI tokenizer exposes chat template string via `get_chat_template()`; use that | +| Workarounds break valid inputs | Apply workarounds only when `caps` indicates the template requires them; add tests for each workaround with real template examples | +| Performance overhead from JSON manipulation | Workarounds operate on already-parsed `rapidjson::Document`; negligible vs. LLM inference time | +| Two code paths (Jinja vs GenAI) need consistent workarounds | Shared `input_workarounds` module with path-specific entry points; test both paths | + +--- + +## Open Questions + +1. **Should auto-detection be opt-in or opt-out?** Proposed: opt-out (auto-detect by default, explicit proto value overrides). If a user sets `tool_parser: ""` (empty string), disable tool parsing entirely. + +2. **Where to get template source for GenAI-only path (VLM)?** Use `tokenizer.get_chat_template()` or read `chat_template.jinja` / `tokenizer_config.json` directly from the model directory. + +3. **Should workarounds also apply to the Responses API path?** Yes — same template, same requirements. + +4. **Should we log a warning if auto-detection finds a parser but the user configured a different one?** Yes, at DEBUG level — the user's choice is intentional. + +5. **Gemma4 `convert_tool_responses_gemma4` — is this needed for OVMS?** Depends on whether the Gemma4 template in OpenVINO GenAI handles tool responses natively. Needs testing. diff --git a/src/BUILD b/src/BUILD index 47510cd54d..fefd7f4b18 100644 --- a/src/BUILD +++ b/src/BUILD @@ -2510,6 +2510,7 @@ cc_test( "//src/llm:genai_servables", "//src/llm:output_parsers", ":test_llm_output_parser_tests", + ":test_chat_template_workarounds", "//src/test/mediapipe/calculators:mediapipe_test_calculators", "//src/test/mediapipe/calculators:dependency_free_http_test_calculators", "@mediapipe//mediapipe/calculators/ovms:ovms_calculator", @@ -3035,6 +3036,27 @@ cc_library( local_defines = COMMON_LOCAL_DEFINES, ) +cc_library( + name = "test_chat_template_workarounds", + linkstatic = 1, + alwayslink = True, + srcs = [ + "test/llm/chat_template_analyzer_test.cpp", + "test/llm/input_workarounds_test.cpp", + "test/llm/chat_template_end_to_end_test.cpp", + ], + deps = [ + "@com_google_googletest//:gtest", + "//src/llm:chat_template_analyzer", + "//src/llm:chat_template_probe", + "//src/llm:input_workarounds", + "//third_party:genai", + ":test_platform_utils", + ], + copts = COPTS_TESTS, + local_defines = COMMON_LOCAL_DEFINES, +) + ovms_cc_library( name = "capimodule", hdrs = ["capi_frontend/capimodule.hpp"], diff --git a/src/llm/BUILD b/src/llm/BUILD index 397069b9de..0ac3d4a270 100644 --- a/src/llm/BUILD +++ b/src/llm/BUILD @@ -122,6 +122,41 @@ ovms_cc_library( ], visibility = ["//visibility:public"], ) + +ovms_cc_library( + name = "chat_template_analyzer", + hdrs = ["chat_template_caps.hpp", + "chat_template_analyzer.hpp"], + srcs = ["chat_template_analyzer.cpp"], + deps = [], + visibility = ["//visibility:public"], +) + +ovms_cc_library( + name = "chat_template_probe", + hdrs = ["chat_template_probe.hpp", + "chat_template_caps.hpp"], + srcs = ["chat_template_probe.cpp"], + deps = [ + "//third_party:genai", + "//src:libovmslogging", + ], + visibility = ["//visibility:public"], +) + +ovms_cc_library( + name = "input_workarounds", + hdrs = ["input_workarounds.hpp", + "chat_template_caps.hpp"], + srcs = ["input_workarounds.cpp"], + deps = [ + "@com_github_tencent_rapidjson//:rapidjson", + "//third_party:genai", + "//src:libovmslogging", + ], + visibility = ["//visibility:public"], +) + ovms_cc_library( name = "partial_json_builder", hdrs = ["io_processing/partial_json_builder.hpp"], @@ -344,6 +379,9 @@ ovms_cc_library( ":openai_completions_api_handler", ":openai_responses_handler", ":generation_config_builders", + ":chat_template_analyzer", + ":chat_template_probe", + ":input_workarounds", "//src:httppayload", "//src:libhttpclientconnection", "//src:sse_utils", diff --git a/src/llm/chat_template_analyzer.cpp b/src/llm/chat_template_analyzer.cpp new file mode 100644 index 0000000000..c914a5eada --- /dev/null +++ b/src/llm/chat_template_analyzer.cpp @@ -0,0 +1,145 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#include "chat_template_analyzer.hpp" + +#include + +namespace ovms { + +static bool contains(const std::string& haystack, const std::string& needle) { + return haystack.find(needle) != std::string::npos; +} + + +// TODO: remove comments before analysis +// TODO: expect GenAI to fix bug + dry-runs on separate threads? +ChatTemplateAnalysisResult ChatTemplateAnalyzer::analyze(const std::string& templateSource) { + ChatTemplateAnalysisResult result; + if (templateSource.empty()) { + return result; + } + + // GPT-OSS detection — must be before other checks as it has a unique marker + if (contains(templateSource, "<|channel|>")) { + result.detectedModelFamily = "gptoss"; + result.detectedToolParser = "gptoss"; + result.detectedReasoningParser = "gptoss"; + result.caps.supportsToolCalls = true; + result.caps.supportsTools = true; + result.caps.supportsToolResponses = true; + return result; + } + + // Gemma4 detection + if (contains(templateSource, "'<|tool_call>call:'") || contains(templateSource, "<|tool_call>call:")) { + result.detectedModelFamily = "gemma4"; + result.detectedToolParser = "gemma4"; + result.detectedReasoningParser = "gemma4"; + result.caps.supportsToolCalls = true; + result.caps.supportsTools = true; + result.caps.supportsToolResponses = true; + result.caps.requiresObjectArguments = true; + return result; + } + + // Qwen3-Coder detection — uses ") && contains(templateSource, "") || contains(templateSource, "")) { + result.detectedReasoningParser = "qwen3"; + } + return result; + } + + // LFM2 detection + if (contains(templateSource, "<|assistant_tool_call|>") || contains(templateSource, "<|tool_call_start|>")) { + result.detectedModelFamily = "lfm2"; + result.detectedToolParser = "lfm2"; + result.caps.supportsToolCalls = true; + result.caps.supportsTools = true; + result.caps.supportsToolResponses = true; + return result; + } + + // Phi-4 detection + if (contains(templateSource, "<|tool\xe2\x96\x81" "call\xe2\x96\x81" "begin|>")) { // <|tool▁call▁begin|> with Unicode ▁ (U+2581) + result.detectedModelFamily = "phi4"; + result.detectedToolParser = "phi4"; + result.caps.supportsToolCalls = true; + result.caps.supportsTools = true; + result.caps.supportsToolResponses = true; + return result; + } + + // Devstral detection — uses [TOOL_CALLS] with [TOOL_RESULTS] + if (contains(templateSource, "[TOOL_CALLS]") && contains(templateSource, "[TOOL_RESULTS]")) { + result.detectedModelFamily = "devstral"; + result.detectedToolParser = "devstral"; + result.caps.supportsToolCalls = true; + result.caps.supportsTools = true; + result.caps.supportsToolResponses = true; + return result; + } + + // Mistral detection — uses [TOOL_CALLS] without [TOOL_RESULTS] or uses [AVAILABLE_TOOLS] + if (contains(templateSource, "[TOOL_CALLS]") || (contains(templateSource, "[AVAILABLE_TOOLS]") && contains(templateSource, "[/AVAILABLE_TOOLS]"))) { + result.detectedModelFamily = "mistral"; + result.detectedToolParser = "mistral"; + result.caps.supportsToolCalls = true; + result.caps.supportsTools = true; + result.caps.supportsToolResponses = true; + return result; + } + + // Llama3 detection — <|python_tag|> + if (contains(templateSource, "<|python_tag|>")) { + result.detectedModelFamily = "llama3"; + result.detectedToolParser = "llama3"; + result.caps.supportsToolCalls = true; + result.caps.supportsTools = true; + result.caps.supportsToolResponses = true; + result.caps.requiresNonNullContent = true; + return result; + } + + // Hermes3/Qwen detection — / (without ") && contains(templateSource, "")) { + result.detectedModelFamily = "hermes3"; + result.detectedToolParser = "hermes3"; + result.caps.supportsToolCalls = true; + result.caps.supportsTools = true; + result.caps.supportsToolResponses = true; + // Check for reasoning support (think tags in Qwen3) + if (contains(templateSource, "") || contains(templateSource, "content.split('')")) { + result.detectedReasoningParser = "qwen3"; + } + return result; + } + + // Reasoning-only detection (no tool parser matched but template has reasoning tags) + if (contains(templateSource, "") || contains(templateSource, "content.split('')")) { + result.detectedReasoningParser = "qwen3"; + } + + return result; +} + +} // namespace ovms diff --git a/src/llm/chat_template_analyzer.hpp b/src/llm/chat_template_analyzer.hpp new file mode 100644 index 0000000000..ddcdefc0b0 --- /dev/null +++ b/src/llm/chat_template_analyzer.hpp @@ -0,0 +1,39 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#pragma once + +#include +#include + +#include "chat_template_caps.hpp" + +namespace ovms { + +struct ChatTemplateAnalysisResult { + ChatTemplateCaps caps; + std::string detectedModelFamily; + std::optional detectedToolParser; + std::optional detectedReasoningParser; +}; + +class ChatTemplateAnalyzer { +public: + // Analyze the chat template source and return detected capabilities and parser names. + // Uses pattern matching on template source text (Tier 1 detection). + static ChatTemplateAnalysisResult analyze(const std::string& templateSource); +}; + +} // namespace ovms diff --git a/src/llm/chat_template_caps.hpp b/src/llm/chat_template_caps.hpp new file mode 100644 index 0000000000..13e8f427f9 --- /dev/null +++ b/src/llm/chat_template_caps.hpp @@ -0,0 +1,34 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#pragma once + +#include + +namespace ovms { + +struct ChatTemplateCaps { + bool supportsSystemRole = true; + bool supportsTools = false; + bool supportsToolCalls = false; + bool supportsToolResponses = false; + bool requiresObjectArguments = false; + bool requiresNonNullContent = false; + bool requiresTypedContent = false; + bool supportsParallelToolCalls = false; + bool supportsToolCallId = false; +}; + +} // namespace ovms diff --git a/src/llm/chat_template_probe.cpp b/src/llm/chat_template_probe.cpp new file mode 100644 index 0000000000..8cf889f6ac --- /dev/null +++ b/src/llm/chat_template_probe.cpp @@ -0,0 +1,115 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#include "chat_template_probe.hpp" + +#include +#include +#include +#include + +#include + +#include "../logging.hpp" + +namespace ovms { + +void probeChatTemplateCaps(ov::genai::Tokenizer& tokenizer, ChatTemplateCaps& caps) { + if (tokenizer.get_chat_template().empty()) { + return; + } + if (!caps.supportsToolCalls) { + return; + } + + auto probeStart = std::chrono::steady_clock::now(); + const std::string argNeedle = "probe_needle_xK9m"; + + auto strArgsFuture = std::async(std::launch::async, [&tokenizer, &argNeedle]() -> std::pair { + try { + ov::genai::ChatHistory history; + history.push_back(ov::genai::JsonContainer::from_json_string(R"({"role":"user","content":"Hello"})")); + history.push_back(ov::genai::JsonContainer::from_json_string( + R"({"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"probe_fn","arguments":"{\")" + argNeedle + R"(\":\"val\"}"}}]})")); + auto t0 = std::chrono::steady_clock::now(); + std::string output = tokenizer.apply_chat_template(history, false); + auto t1 = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe minja (string args): {} us", + std::chrono::duration_cast(t1 - t0).count()); + return {true, std::move(output)}; + } catch (const std::exception& e) { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe minja (string args): exception: {}", e.what()); + return {false, ""}; + } catch (...) { + return {false, ""}; + } + }); + + auto objArgsFuture = std::async(std::launch::async, [&tokenizer, &argNeedle]() -> std::pair { + try { + ov::genai::ChatHistory history; + history.push_back(ov::genai::JsonContainer::from_json_string(R"({"role":"user","content":"Hello"})")); + history.push_back(ov::genai::JsonContainer::from_json_string( + R"({"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"probe_fn","arguments":{")" + argNeedle + R"(":"val"}}}]})")); + auto t0 = std::chrono::steady_clock::now(); + std::string output = tokenizer.apply_chat_template(history, false); + auto t1 = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe minja (object args): {} us", + std::chrono::duration_cast(t1 - t0).count()); + return {true, std::move(output)}; + } catch (const std::exception& e) { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe minja (object args): exception: {}", e.what()); + return {false, ""}; + } catch (...) { + return {false, ""}; + } + }); + + auto [strOk, strOut] = strArgsFuture.get(); + auto [objOk, objOut] = objArgsFuture.get(); + + auto rendersNativeArgs = [&argNeedle](const std::string& output) -> bool { + return output.find("\"" + argNeedle + "\": ") != std::string::npos || // JSON key: "needle": + output.find("'" + argNeedle + "': ") != std::string::npos || // Python dict: 'needle': + output.find("") != std::string::npos || // Qwen3-Coder XML + output.find(argNeedle + ":<|") != std::string::npos; // Gemma4: needle:<| + }; + + bool strArgsRendersNative = strOk && rendersNativeArgs(strOut); + bool objArgsRendersNative = objOk && rendersNativeArgs(objOut); + + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe requiresObjectArguments: strRendersNative={}, objRendersNative={}", + strArgsRendersNative, objArgsRendersNative); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Dry-run probe strArgs output: {}", strOut); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Dry-run probe objArgs output: {}", objOut); + + if (strArgsRendersNative || objArgsRendersNative) { + bool probeResult = objArgsRendersNative; + if (probeResult != caps.requiresObjectArguments) { + SPDLOG_LOGGER_INFO(llm_calculator_logger, "Dry-run probe overrides requiresObjectArguments: {} -> {}", + caps.requiresObjectArguments, probeResult); + } + caps.requiresObjectArguments = probeResult; + } else { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe: template does not render tool_call arguments in native format, keeping string-matching result for requiresObjectArguments"); + } + + auto probeEnd = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe completed in {} us. Final result: requiresObjectArguments={}", + std::chrono::duration_cast(probeEnd - probeStart).count(), + caps.requiresObjectArguments); +} + +} // namespace ovms diff --git a/src/llm/chat_template_probe.hpp b/src/llm/chat_template_probe.hpp new file mode 100644 index 0000000000..a2ee4c52ff --- /dev/null +++ b/src/llm/chat_template_probe.hpp @@ -0,0 +1,33 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#pragma once + +#include "chat_template_caps.hpp" + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#include +#pragma GCC diagnostic pop + +namespace ovms { + +// Probes a chat template by dry-running it with synthetic tool_call inputs +// to empirically detect whether the template requires object arguments. +// Updates caps.requiresObjectArguments based on probe results. +// Only performs probing if caps.supportsToolCalls is true. +void probeChatTemplateCaps(ov::genai::Tokenizer& tokenizer, ChatTemplateCaps& caps); + +} // namespace ovms diff --git a/src/llm/input_workarounds.cpp b/src/llm/input_workarounds.cpp new file mode 100644 index 0000000000..cd41274344 --- /dev/null +++ b/src/llm/input_workarounds.cpp @@ -0,0 +1,152 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#include "input_workarounds.hpp" + +#include + +#include + +#include "../logging.hpp" + +namespace ovms { +namespace input_workarounds { + +// --- JSON path implementations --- + +void funcArgsToObjectJson(rapidjson::Document& doc) { + if (!doc.HasMember("messages") || !doc["messages"].IsArray()) { + return; + } + auto& allocator = doc.GetAllocator(); + for (auto& message : doc["messages"].GetArray()) { + if (!message.IsObject() || !message.HasMember("tool_calls") || !message["tool_calls"].IsArray()) { + continue; + } + for (auto& toolCall : message["tool_calls"].GetArray()) { + if (!toolCall.IsObject() || !toolCall.HasMember("function") || !toolCall["function"].IsObject()) { + continue; + } + auto& function = toolCall["function"]; + if (!function.HasMember("arguments") || !function["arguments"].IsString()) { + continue; + } + const char* argsStr = function["arguments"].GetString(); + rapidjson::Document argsDoc; + argsDoc.Parse(argsStr); + if (argsDoc.HasParseError()) { + continue; + } + function["arguments"].CopyFrom(argsDoc, allocator); + } + } +} + +void ensureNonNullContentJson(rapidjson::Document& doc) { + if (!doc.HasMember("messages") || !doc["messages"].IsArray()) { + return; + } + auto& allocator = doc.GetAllocator(); + for (auto& message : doc["messages"].GetArray()) { + if (!message.IsObject() || !message.HasMember("tool_calls")) { + continue; + } + if (!message.HasMember("content")) { + message.AddMember("content", rapidjson::Value().SetString("", allocator), allocator); + } else if (message["content"].IsNull()) { + message["content"].SetString("", allocator); + } + } +} + +void applyToJson(const ChatTemplateCaps& caps, const std::string& modelFamily, rapidjson::Document& doc) { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Applying input workarounds (JSON path) for model family '{}': " + "requiresObjectArguments={}, requiresNonNullContent={}", + modelFamily.empty() ? "(none)" : modelFamily, + caps.requiresObjectArguments, caps.requiresNonNullContent); + if (caps.requiresObjectArguments) { + funcArgsToObjectJson(doc); + } + if (caps.requiresNonNullContent) { + ensureNonNullContentJson(doc); + } +} + +// --- ChatHistory path implementations --- + +void funcArgsToObjectHistory(ov::genai::ChatHistory& chatHistory) { + for (size_t msgIdx = 0; msgIdx < chatHistory.size(); ++msgIdx) { + auto message = chatHistory[msgIdx]; + if (!message.contains("tool_calls")) { + continue; + } + auto toolCalls = message["tool_calls"]; + if (!toolCalls.is_array()) { + continue; + } + for (size_t i = 0; i < toolCalls.size(); ++i) { + auto toolCall = toolCalls[i]; + if (!toolCall.is_object() || !toolCall.contains("function")) { + continue; + } + auto function = toolCall["function"]; + if (!function.is_object() || !function.contains("arguments")) { + continue; + } + auto args = function["arguments"]; + if (!args.is_string()) { + continue; + } + std::string argsStr = args.get_string(); + // Parse and replace string arguments with the parsed JSON object + try { + function["arguments"] = ov::genai::JsonContainer::from_json_string(argsStr); + } catch (...) { + // If parsing fails, leave as-is + continue; + } + } + } +} + +void ensureNonNullContentHistory(ov::genai::ChatHistory& chatHistory) { + for (size_t msgIdx = 0; msgIdx < chatHistory.size(); ++msgIdx) { + auto message = chatHistory[msgIdx]; + if (!message.contains("tool_calls")) { + continue; + } + if (!message.contains("content")) { + message["content"] = ""; + } else if (message["content"].is_null()) { + message["content"] = ""; + } + } +} + +void applyToHistory(const ChatTemplateCaps& caps, const std::string& modelFamily, ov::genai::ChatHistory& chatHistory) { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Applying input workarounds (ChatHistory path) for model family '{}': " + "requiresObjectArguments={}, requiresNonNullContent={}", + modelFamily.empty() ? "(none)" : modelFamily, + caps.requiresObjectArguments, caps.requiresNonNullContent); + if (caps.requiresObjectArguments) { + funcArgsToObjectHistory(chatHistory); + } + if (caps.requiresNonNullContent) { + ensureNonNullContentHistory(chatHistory); + } +} + +} // namespace input_workarounds +} // namespace ovms diff --git a/src/llm/input_workarounds.hpp b/src/llm/input_workarounds.hpp new file mode 100644 index 0000000000..c6171c8d30 --- /dev/null +++ b/src/llm/input_workarounds.hpp @@ -0,0 +1,60 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#pragma once + +#include + +#include +#include + +#include "chat_template_caps.hpp" + +namespace ovms { +namespace input_workarounds { + +// --- Individual workaround functions (JSON path) --- +// Each operates on the full request document containing "messages" array. +// Exposed individually for unit testing and for selective use during refactoring. + +// Convert tool_call arguments from JSON string to parsed JSON object. +// Models like Gemma require arguments as a dict/object, not a stringified JSON. +void funcArgsToObjectJson(rapidjson::Document& doc); + +// Ensure assistant messages with tool_calls have non-null content field. +// Some templates require content="" rather than content=null. +void ensureNonNullContentJson(rapidjson::Document& doc); + +// --- Individual workaround functions (ChatHistory path) --- +// Operates on ov::genai::ChatHistory for the GenAI C++ tokenizer path. + +// Convert tool_call arguments from string to object in ChatHistory. +void funcArgsToObjectHistory(ov::genai::ChatHistory& chatHistory); + +// Ensure assistant messages with tool_calls have non-null content in ChatHistory. +void ensureNonNullContentHistory(ov::genai::ChatHistory& chatHistory); + +// --- Aggregate application --- + +// Apply all relevant workarounds to the JSON document (Python Jinja path). +// Modifies the document in-place based on detected capabilities. +void applyToJson(const ChatTemplateCaps& caps, const std::string& modelFamily, rapidjson::Document& doc); + +// Apply all relevant workarounds to the ChatHistory (GenAI C++ tokenizer path). +// Modifies the chat history in-place based on detected capabilities. +void applyToHistory(const ChatTemplateCaps& caps, const std::string& modelFamily, ov::genai::ChatHistory& chatHistory); + +} // namespace input_workarounds +} // namespace ovms diff --git a/src/llm/llm_calculator.proto b/src/llm/llm_calculator.proto index ce252ea899..c3e38f3408 100644 --- a/src/llm/llm_calculator.proto +++ b/src/llm/llm_calculator.proto @@ -145,7 +145,6 @@ message LLMCalculatorOptions { MINJA = 0; // Use Python Jinja2 module for chat template processing. // For builds with Python, default for LLM pipelines, selectible for VLM pipelines. - // TODO(dkalinow): once we have server-side workaround, make it default for VLM as well JINJA = 1; } diff --git a/src/llm/servable.cpp b/src/llm/servable.cpp index 0d934cce0a..91fe6c74dc 100644 --- a/src/llm/servable.cpp +++ b/src/llm/servable.cpp @@ -36,6 +36,7 @@ #include "../profiler.hpp" #include "apis/openai_completions.hpp" #include "apis/openai_responses.hpp" +#include "input_workarounds.hpp" #include "ovms_text_streamer.hpp" #include "servable.hpp" #include "text_utils.hpp" @@ -184,17 +185,44 @@ absl::Status GenAiServable::prepareInputs(std::shared_ptr std::string { + const auto& caps = getProperties()->chatTemplateCaps; + if (!caps.requiresObjectArguments && !caps.requiresNonNullContent) { + return jsonBody; // no workarounds needed + } + rapidjson::Document doc; + doc.Parse(jsonBody.c_str()); + if (doc.HasParseError()) { + return jsonBody; + } + input_workarounds::applyToJson(caps, getProperties()->detectedModelFamily, doc); + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc.Accept(writer); + return buffer.GetString(); + }; + std::string inputText; switch (executionContext->endpoint) { case Endpoint::CHAT_COMPLETIONS: { #if (PYTHON_DISABLE == 0) if (getProperties()->chatTemplateMode == ChatTemplateMode::JINJA) { bool success; + auto tplStart = std::chrono::steady_clock::now(); if (executionContext->apiHandler->getProcessedJson().size() > 0) { - success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, executionContext->apiHandler->getProcessedJson(), inputText); + std::string modifiedJson = applyInputWorkarounds(executionContext->apiHandler->getProcessedJson()); + success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, modifiedJson, inputText); } else { - success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, executionContext->payload.body, inputText); + std::string modifiedJson = applyInputWorkarounds(executionContext->payload.body); + success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, modifiedJson, inputText); } + auto tplEnd = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "applyChatTemplate Jinja (chat): {} us", + std::chrono::duration_cast(tplEnd - tplStart).count()); if (!success) { return absl::Status(absl::StatusCode::kInvalidArgument, inputText); } @@ -202,7 +230,8 @@ absl::Status GenAiServable::prepareInputs(std::shared_ptrapiHandler->getChatHistory(); - constexpr bool addGenerationPrompt = true; // confirm it should be hardcoded + input_workarounds::applyToHistory(getProperties()->chatTemplateCaps, getProperties()->detectedModelFamily, chatHistory); + constexpr bool addGenerationPrompt = true; auto toolParsingResult = executionContext->apiHandler->parseToolsToJsonContainer(); if (!toolParsingResult.ok()) { return toolParsingResult.status(); @@ -214,7 +243,11 @@ absl::Status GenAiServable::prepareInputs(std::shared_ptrtokenizer.apply_chat_template(chatHistory, addGenerationPrompt, {}, tools, chatTemplateKwargs); + auto tplEnd = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "apply_chat_template (chat): {} us", + std::chrono::duration_cast(tplEnd - tplStart).count()); } catch (const std::exception& e) { SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to apply chat template: {}", e.what()); return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to apply chat template. The model either does not have chat template or has an invalid one."); @@ -232,7 +265,12 @@ absl::Status GenAiServable::prepareInputs(std::shared_ptrapiHandler->getChatHistory().size() > 0) { #if (PYTHON_DISABLE == 0) if (getProperties()->chatTemplateMode == ChatTemplateMode::JINJA) { - bool success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, executionContext->apiHandler->getProcessedJson(), inputText); + std::string modifiedJson = applyInputWorkarounds(executionContext->apiHandler->getProcessedJson()); + auto tplStart = std::chrono::steady_clock::now(); + bool success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, modifiedJson, inputText); + auto tplEnd = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "applyChatTemplate Jinja (responses): {} us", + std::chrono::duration_cast(tplEnd - tplStart).count()); if (!success) { return absl::Status(absl::StatusCode::kInvalidArgument, inputText); } @@ -240,6 +278,7 @@ absl::Status GenAiServable::prepareInputs(std::shared_ptrapiHandler->getChatHistory(); + input_workarounds::applyToHistory(getProperties()->chatTemplateCaps, getProperties()->detectedModelFamily, chatHistory); constexpr bool addGenerationPrompt = true; auto toolParsingResult = executionContext->apiHandler->parseToolsToJsonContainer(); if (!toolParsingResult.ok()) { @@ -252,7 +291,11 @@ absl::Status GenAiServable::prepareInputs(std::shared_ptrtokenizer.apply_chat_template(chatHistory, addGenerationPrompt, {}, tools, chatTemplateKwargs); + auto tplEnd = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "apply_chat_template (responses): {} us", + std::chrono::duration_cast(tplEnd - tplStart).count()); } catch (const std::exception& e) { SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to apply chat template: {}", e.what()); return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to apply chat template. The model either does not have chat template or has an invalid one."); diff --git a/src/llm/servable.hpp b/src/llm/servable.hpp index 6d1669735a..2f658edc32 100644 --- a/src/llm/servable.hpp +++ b/src/llm/servable.hpp @@ -36,6 +36,7 @@ #include "../http_payload.hpp" #include "../sse_utils.hpp" #include "apis/openai_api_handler.hpp" +#include "chat_template_caps.hpp" #include "io_processing/generation_config_builder.hpp" #if (PYTHON_DISABLE == 0) #include "py_jinja_template_processor.hpp" @@ -171,6 +172,9 @@ struct GenAiServableProperties { #else ChatTemplateMode chatTemplateMode = ChatTemplateMode::MINJA; #endif + // Chat template analysis + ChatTemplateCaps chatTemplateCaps; + std::string detectedModelFamily; // Sampling DecodingMethod decodingMethod; std::optional maxTokensLimit; diff --git a/src/llm/servable_initializer.cpp b/src/llm/servable_initializer.cpp index 90673fdbc3..58a74e14a5 100644 --- a/src/llm/servable_initializer.cpp +++ b/src/llm/servable_initializer.cpp @@ -14,6 +14,8 @@ // limitations under the License. //***************************************************************************** #include +#include +#include #include #include #include @@ -39,6 +41,8 @@ #include "../logging.hpp" #include "../mediapipe_internal/mediapipe_utils.hpp" #include "../status.hpp" +#include "chat_template_analyzer.hpp" +#include "chat_template_probe.hpp" #include "src/filesystem/filesystem.hpp" #include "../stringutils.hpp" #include "language_model/continuous_batching/servable.hpp" @@ -53,6 +57,80 @@ namespace ovms { static const std::string CHAT_TEMPLATE_WARNING_MESSAGE = "Warning: Chat template has not been loaded properly. Servable will not respond to /chat/completions endpoint."; +// Dry-run probes: render the chat template with synthetic inputs to empirically +// detect what the template requires. This is model-agnostic and tests actual +// template behavior rather than relying solely on string pattern matching. +// Probes for: +// - requiresObjectArguments (workaround: string→object conversion of tool_call arguments) +static void probeServableChatTemplateCaps(std::shared_ptr properties) { + if (properties->tokenizer.get_chat_template().empty()) { + return; + } + if (!properties->chatTemplateCaps.supportsToolCalls) { + return; + } + +#if (PYTHON_DISABLE == 0) + if (properties->chatTemplateMode == ChatTemplateMode::JINJA && properties->templateProcessor.chatTemplate != nullptr) { + // Probe via Python Jinja — this path has NO polyfills, tests raw template behavior + const std::string argNeedle = "probe_needle_xK9m"; + std::string strArgsOutput; + std::string objArgsOutput; + bool strArgsSuccess = false; + bool objArgsSuccess = false; + + std::string strArgsJson = R"({"messages":[{"role":"user","content":"Hello"},{"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"probe_fn","arguments":"{\")" + argNeedle + R"(\":\"val\"}"}}]}]})"; + std::string objArgsJson = R"({"messages":[{"role":"user","content":"Hello"},{"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"probe_fn","arguments":{")" + argNeedle + R"(":"val"}}}]}]})"; + + try { + auto t0 = std::chrono::steady_clock::now(); + strArgsSuccess = PyJinjaTemplateProcessor::applyChatTemplate(properties->templateProcessor, properties->modelsPath, strArgsJson, strArgsOutput); + auto t1 = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe Jinja (string args): {} us", + std::chrono::duration_cast(t1 - t0).count()); + } catch (const std::exception& e) { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe Jinja (string args): exception: {}", e.what()); + } catch (...) {} + + try { + auto t0 = std::chrono::steady_clock::now(); + objArgsSuccess = PyJinjaTemplateProcessor::applyChatTemplate(properties->templateProcessor, properties->modelsPath, objArgsJson, objArgsOutput); + auto t1 = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe Jinja (object args): {} us", + std::chrono::duration_cast(t1 - t0).count()); + } catch (const std::exception& e) { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe Jinja (object args): exception: {}", e.what()); + } catch (...) {} + + auto rendersNativeArgs = [&argNeedle](const std::string& output) -> bool { + return output.find("\"" + argNeedle + "\": ") != std::string::npos || + output.find("'" + argNeedle + "': ") != std::string::npos || + output.find("") != std::string::npos || + output.find(argNeedle + ":<|") != std::string::npos; + }; + + bool strArgsRendersNative = strArgsSuccess && rendersNativeArgs(strArgsOutput); + bool objArgsRendersNative = objArgsSuccess && rendersNativeArgs(objArgsOutput); + + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Dry-run probe requiresObjectArguments: strRendersNative={}, objRendersNative={}", + strArgsRendersNative, objArgsRendersNative); + + if (strArgsRendersNative || objArgsRendersNative) { + bool probeResult = objArgsRendersNative; + if (probeResult != properties->chatTemplateCaps.requiresObjectArguments) { + SPDLOG_LOGGER_INFO(llm_calculator_logger, "Dry-run probe overrides requiresObjectArguments: {} -> {}", + properties->chatTemplateCaps.requiresObjectArguments, probeResult); + } + properties->chatTemplateCaps.requiresObjectArguments = probeResult; + } + return; + } +#endif + + // Minja path — use the shared probe component + probeChatTemplateCaps(properties->tokenizer, properties->chatTemplateCaps); +} + void GenAiServableInitializer::loadChatTemplate(std::shared_ptr properties, const std::string& chatTemplateDirectory) { #if (PYTHON_DISABLE == 0) if (properties->chatTemplateMode == ChatTemplateMode::JINJA) { @@ -82,6 +160,42 @@ void GenAiServableInitializer::loadChatTemplate(std::shared_ptrtokenizer.get_chat_template(); + if (!templateSource.empty()) { + auto analysisResult = ChatTemplateAnalyzer::analyze(templateSource); + properties->chatTemplateCaps = analysisResult.caps; + properties->detectedModelFamily = analysisResult.detectedModelFamily; + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Chat template analysis: detectedModelFamily={}, " + "supportsSystemRole={}, supportsTools={}, supportsToolCalls={}, supportsToolResponses={}, " + "requiresObjectArguments={}, requiresNonNullContent={}, requiresTypedContent={}, " + "supportsParallelToolCalls={}, supportsToolCallId={}", + analysisResult.detectedModelFamily.empty() ? "(none)" : analysisResult.detectedModelFamily, + analysisResult.caps.supportsSystemRole, + analysisResult.caps.supportsTools, + analysisResult.caps.supportsToolCalls, + analysisResult.caps.supportsToolResponses, + analysisResult.caps.requiresObjectArguments, + analysisResult.caps.requiresNonNullContent, + analysisResult.caps.requiresTypedContent, + analysisResult.caps.supportsParallelToolCalls, + analysisResult.caps.supportsToolCallId); + // Auto-detect tool parser if not explicitly configured + if (properties->toolParserName.empty() && analysisResult.detectedToolParser.has_value()) { + properties->toolParserName = analysisResult.detectedToolParser.value(); + SPDLOG_LOGGER_INFO(llm_calculator_logger, "Auto-detected tool_parser: {}", properties->toolParserName); + } + // Auto-detect reasoning parser if not explicitly configured + if (properties->reasoningParserName.empty() && analysisResult.detectedReasoningParser.has_value()) { + properties->reasoningParserName = analysisResult.detectedReasoningParser.value(); + SPDLOG_LOGGER_INFO(llm_calculator_logger, "Auto-detected reasoning_parser: {}", properties->reasoningParserName); + } + + // Dry-run probes: empirically verify requiresObjectArguments and requiresNonNullContent + // by rendering synthetic messages through GenAI's minja and checking the output. + probeServableChatTemplateCaps(properties); + } } #if (PYTHON_DISABLE == 0) diff --git a/src/llm/visual_language_model/continuous_batching/servable.cpp b/src/llm/visual_language_model/continuous_batching/servable.cpp index 8b65ac7fe0..d82c40b7e8 100644 --- a/src/llm/visual_language_model/continuous_batching/servable.cpp +++ b/src/llm/visual_language_model/continuous_batching/servable.cpp @@ -16,6 +16,7 @@ #include "servable.hpp" +#include #include #include #include @@ -28,8 +29,9 @@ #include "../../../config.hpp" #include "../../../logging.hpp" -#include "../../../tokenize/tokenize_parser.hpp" +#include "../../input_workarounds.hpp" #include "../../text_utils.hpp" +#include "../../../tokenize/tokenize_parser.hpp" #if (PYTHON_DISABLE == 0) #include "../../py_jinja_template_processor.hpp" #endif @@ -132,14 +134,31 @@ absl::Status VisualLanguageModelServable::prepareInputs(std::shared_ptrchatTemplateCaps, getProperties()->detectedModelFamily, workaroundDoc); + rapidjson::StringBuffer wBuf; + rapidjson::Writer wWriter(wBuf); + workaroundDoc.Accept(wWriter); + jsonForTemplate = wBuf.GetString(); + } + } + auto tplStart = std::chrono::steady_clock::now(); bool success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, jsonForTemplate, vlmExecutionContext->inputText); + auto tplEnd = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "applyChatTemplate Jinja (vlm-cb): {} us", + std::chrono::duration_cast(tplEnd - tplStart).count()); if (!success) { return absl::Status(absl::StatusCode::kInvalidArgument, vlmExecutionContext->inputText); } } else // NOLINT(readability/braces) #endif { - constexpr bool addGenerationPrompt = true; // confirm it should be hardcoded + input_workarounds::applyToHistory(getProperties()->chatTemplateCaps, getProperties()->detectedModelFamily, chatHistory); + constexpr bool addGenerationPrompt = true; auto toolParsingResult = vlmExecutionContext->apiHandler->parseToolsToJsonContainer(); if (!toolParsingResult.ok()) { return toolParsingResult.status(); @@ -159,7 +178,11 @@ absl::Status VisualLanguageModelServable::prepareInputs(std::shared_ptrinputText = properties->tokenizer.apply_chat_template(chatHistory, addGenerationPrompt, {}, tools, chatTemplateKwargs); + auto tplEnd = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "apply_chat_template (vlm-cb): {} us", + std::chrono::duration_cast(tplEnd - tplStart).count()); } catch (const std::exception& e) { SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to apply chat template: {}", e.what()); return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to apply chat template. The model either does not have chat template or has an invalid one."); diff --git a/src/llm/visual_language_model/legacy/servable.cpp b/src/llm/visual_language_model/legacy/servable.cpp index 399682d6c5..af9caadb30 100644 --- a/src/llm/visual_language_model/legacy/servable.cpp +++ b/src/llm/visual_language_model/legacy/servable.cpp @@ -14,6 +14,7 @@ // limitations under the License. //***************************************************************************** +#include #include #include #include @@ -42,6 +43,7 @@ #include "../../../config.hpp" #include "../../../http_payload.hpp" #include "../../../mediapipe_internal/mediapipe_utils.hpp" +#include "../../input_workarounds.hpp" #include "../../text_utils.hpp" #include "../../../tokenize/tokenize_parser.hpp" #if (PYTHON_DISABLE == 0) @@ -377,14 +379,31 @@ absl::Status VisualLanguageModelLegacyServable::prepareInputs(std::shared_ptrchatTemplateCaps, getProperties()->detectedModelFamily, workaroundDoc); + rapidjson::StringBuffer wBuf; + rapidjson::Writer wWriter(wBuf); + workaroundDoc.Accept(wWriter); + jsonForTemplate = wBuf.GetString(); + } + } + auto tplStart = std::chrono::steady_clock::now(); bool success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, jsonForTemplate, vlmExecutionContext->inputText); + auto tplEnd = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "applyChatTemplate Jinja (vlm-legacy): {} us", + std::chrono::duration_cast(tplEnd - tplStart).count()); if (!success) { return absl::Status(absl::StatusCode::kInvalidArgument, vlmExecutionContext->inputText); } } else // NOLINT(readability/braces) #endif { - constexpr bool addGenerationPrompt = true; // confirm it should be hardcoded + input_workarounds::applyToHistory(getProperties()->chatTemplateCaps, getProperties()->detectedModelFamily, chatHistory); + constexpr bool addGenerationPrompt = true; auto toolParsingResult = vlmExecutionContext->apiHandler->parseToolsToJsonContainer(); if (!toolParsingResult.ok()) { return toolParsingResult.status(); @@ -396,7 +415,11 @@ absl::Status VisualLanguageModelLegacyServable::prepareInputs(std::shared_ptrinputText = properties->tokenizer.apply_chat_template(chatHistory, addGenerationPrompt, {}, tools, chatTemplateKwargs); + auto tplEnd = std::chrono::steady_clock::now(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "apply_chat_template (vlm-legacy): {} us", + std::chrono::duration_cast(tplEnd - tplStart).count()); } catch (const std::exception& e) { SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to apply chat template: {}", e.what()); return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to apply chat template. The model either does not have chat template or has an invalid one."); diff --git a/src/test/llm/chat_template_analyzer_test.cpp b/src/test/llm/chat_template_analyzer_test.cpp new file mode 100644 index 0000000000..1e8d69179d --- /dev/null +++ b/src/test/llm/chat_template_analyzer_test.cpp @@ -0,0 +1,223 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include + +#include + +#include "../../llm/chat_template_analyzer.hpp" + +using namespace ovms; + +class ChatTemplateAnalyzerTest : public ::testing::Test {}; + +// --- Empty template --- + +TEST_F(ChatTemplateAnalyzerTest, emptyTemplateReturnsDefaults) { + auto result = ChatTemplateAnalyzer::analyze(""); + EXPECT_TRUE(result.detectedModelFamily.empty()); + EXPECT_FALSE(result.detectedToolParser.has_value()); + EXPECT_FALSE(result.detectedReasoningParser.has_value()); + EXPECT_FALSE(result.caps.supportsToolCalls); + EXPECT_FALSE(result.caps.requiresObjectArguments); +} + +// --- GPT-OSS --- + +TEST_F(ChatTemplateAnalyzerTest, detectsGptOss) { + std::string tmpl = R"({% if message.role == 'assistant' %}<|channel|>{% endif %})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "gptoss"); + EXPECT_EQ(result.detectedToolParser.value(), "gptoss"); + EXPECT_EQ(result.detectedReasoningParser.value(), "gptoss"); + EXPECT_TRUE(result.caps.supportsToolCalls); + EXPECT_TRUE(result.caps.supportsTools); + EXPECT_TRUE(result.caps.supportsToolResponses); +} + +// --- Gemma4 --- + +TEST_F(ChatTemplateAnalyzerTest, detectsGemma4SingleQuote) { + std::string tmpl = R"({% if tool_call %}'<|tool_call>call:'{{ tool_call.name }}{% endif %})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "gemma4"); + EXPECT_EQ(result.detectedToolParser.value(), "gemma4"); + EXPECT_EQ(result.detectedReasoningParser.value(), "gemma4"); + EXPECT_TRUE(result.caps.requiresObjectArguments); +} + +TEST_F(ChatTemplateAnalyzerTest, detectsGemma4NoQuote) { + std::string tmpl = R"({% if tool_call %}<|tool_call>call:{{ tool_call.name }}{% endif %})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "gemma4"); +} + +// --- Qwen3-Coder --- + +TEST_F(ChatTemplateAnalyzerTest, detectsQwen3Coder) { + std::string tmpl = R"({% for param in func.params %}{{ param.value }}{% endfor %})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "qwen3coder"); + EXPECT_EQ(result.detectedToolParser.value(), "qwen3coder"); + EXPECT_TRUE(result.caps.supportsToolCalls); +} + +TEST_F(ChatTemplateAnalyzerTest, detectsQwen3CoderWithThinkTags) { + std::string tmpl = R"({{ p.value }}reasoning)"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "qwen3coder"); + EXPECT_EQ(result.detectedToolParser.value(), "qwen3coder"); + EXPECT_EQ(result.detectedReasoningParser.value(), "qwen3"); +} + +// --- LFM2 --- + +TEST_F(ChatTemplateAnalyzerTest, detectsLfm2AssistantToolCall) { + std::string tmpl = R"({% if role == 'assistant' %}<|assistant_tool_call|>{% endif %})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "lfm2"); + EXPECT_EQ(result.detectedToolParser.value(), "lfm2"); +} + +TEST_F(ChatTemplateAnalyzerTest, detectsLfm2ToolCallStart) { + std::string tmpl = R"(<|tool_call_start|>{{ tool_calls }}<|tool_call_end|>)"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "lfm2"); + EXPECT_EQ(result.detectedToolParser.value(), "lfm2"); +} + +// --- Phi-4 --- + +TEST_F(ChatTemplateAnalyzerTest, detectsPhi4) { + // Using raw bytes for ▁ (U+2581, UTF-8: E2 96 81) + std::string tmpl = "some template with <|tool\xe2\x96\x81" "call\xe2\x96\x81" "begin|> tag"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "phi4"); + EXPECT_EQ(result.detectedToolParser.value(), "phi4"); +} + +// --- Devstral --- + +TEST_F(ChatTemplateAnalyzerTest, detectsDevstral) { + std::string tmpl = R"({% if tool_calls %}[TOOL_CALLS]{{ tool_calls }}{% endif %}{% if tool_result %}[TOOL_RESULTS]{{ result }}{% endif %})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "devstral"); + EXPECT_EQ(result.detectedToolParser.value(), "devstral"); +} + +// --- Mistral --- + +TEST_F(ChatTemplateAnalyzerTest, detectsMistralWithToolCalls) { + std::string tmpl = R"({% if tool_calls %}[TOOL_CALLS]{{ tool_calls }}{% endif %} some other stuff)"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "mistral"); + EXPECT_EQ(result.detectedToolParser.value(), "mistral"); +} + +TEST_F(ChatTemplateAnalyzerTest, detectsMistralWithAvailableTools) { + std::string tmpl = R"([AVAILABLE_TOOLS]{{ tools }}[/AVAILABLE_TOOLS] template body)"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "mistral"); + EXPECT_EQ(result.detectedToolParser.value(), "mistral"); +} + +// --- Llama3 --- + +TEST_F(ChatTemplateAnalyzerTest, detectsLlama3) { + std::string tmpl = R"({% if tool_calls %}<|python_tag|>{{ tool_calls }}{% endif %})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "llama3"); + EXPECT_EQ(result.detectedToolParser.value(), "llama3"); + EXPECT_TRUE(result.caps.requiresNonNullContent); +} + +// --- Hermes3/Qwen --- + +TEST_F(ChatTemplateAnalyzerTest, detectsHermes3) { + std::string tmpl = R"({% if tool_call %}{{ tool_call }}{% endif %})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "hermes3"); + EXPECT_EQ(result.detectedToolParser.value(), "hermes3"); + EXPECT_FALSE(result.detectedReasoningParser.has_value()); +} + +TEST_F(ChatTemplateAnalyzerTest, detectsHermes3WithQwen3Reasoning) { + std::string tmpl = R"({{ tool_call }} and also reasoning)"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "hermes3"); + EXPECT_EQ(result.detectedToolParser.value(), "hermes3"); + EXPECT_EQ(result.detectedReasoningParser.value(), "qwen3"); +} + +TEST_F(ChatTemplateAnalyzerTest, detectsHermes3WithContentSplitThink) { + std::string tmpl = R"({{ tool_call }} and content.split('') logic)"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "hermes3"); + EXPECT_EQ(result.detectedReasoningParser.value(), "qwen3"); +} + +// --- Reasoning-only --- + +TEST_F(ChatTemplateAnalyzerTest, detectsReasoningOnlyWithThinkTags) { + std::string tmpl = R"({% if reasoning %}{{ reasoning }}{% endif %} no tool call markers)"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_TRUE(result.detectedModelFamily.empty()); + EXPECT_FALSE(result.detectedToolParser.has_value()); + EXPECT_EQ(result.detectedReasoningParser.value(), "qwen3"); +} + +TEST_F(ChatTemplateAnalyzerTest, detectsReasoningOnlyWithContentSplit) { + std::string tmpl = R"(some template with content.split('') logic)"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_TRUE(result.detectedModelFamily.empty()); + EXPECT_EQ(result.detectedReasoningParser.value(), "qwen3"); +} + +// --- No detection --- + +TEST_F(ChatTemplateAnalyzerTest, unknownTemplateReturnsEmpty) { + std::string tmpl = R"({% for message in messages %}{{ message.content }}{% endfor %})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_TRUE(result.detectedModelFamily.empty()); + EXPECT_FALSE(result.detectedToolParser.has_value()); + EXPECT_FALSE(result.detectedReasoningParser.has_value()); + EXPECT_FALSE(result.caps.supportsToolCalls); +} + +// --- Priority: Devstral over Mistral --- + +TEST_F(ChatTemplateAnalyzerTest, devstralTakesPriorityOverMistral) { + // Both have [TOOL_CALLS] but Devstral also has [TOOL_RESULTS] + std::string tmpl = R"([TOOL_CALLS]{{ tool_calls }}[TOOL_RESULTS]{{ results }})"; + auto result = ChatTemplateAnalyzer::analyze(tmpl); + EXPECT_EQ(result.detectedModelFamily, "devstral"); + EXPECT_EQ(result.detectedToolParser.value(), "devstral"); +} + +// --- Capabilities struct defaults --- + +TEST_F(ChatTemplateAnalyzerTest, defaultCapsValues) { + ChatTemplateCaps caps; + EXPECT_TRUE(caps.supportsSystemRole); + EXPECT_FALSE(caps.supportsTools); + EXPECT_FALSE(caps.supportsToolCalls); + EXPECT_FALSE(caps.supportsToolResponses); + EXPECT_FALSE(caps.requiresObjectArguments); + EXPECT_FALSE(caps.requiresNonNullContent); + EXPECT_FALSE(caps.requiresTypedContent); + EXPECT_FALSE(caps.supportsParallelToolCalls); + EXPECT_FALSE(caps.supportsToolCallId); +} diff --git a/src/test/llm/chat_template_end_to_end_test.cpp b/src/test/llm/chat_template_end_to_end_test.cpp new file mode 100644 index 0000000000..860eaed560 --- /dev/null +++ b/src/test/llm/chat_template_end_to_end_test.cpp @@ -0,0 +1,236 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include +#include +#include +#include + +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#include +#pragma GCC diagnostic pop + +#include "../../llm/chat_template_analyzer.hpp" +#include "../../llm/chat_template_caps.hpp" +#include "../../llm/chat_template_probe.hpp" +#include "../../llm/input_workarounds.hpp" +#include "../platform_utils.hpp" + +using namespace ovms; + +// Chat template applicator type +enum class TemplateApplicator { + MINJA, + JINJA // Not implemented yet +}; + +// Test fixture providing end-to-end: analyze → probe → apply workarounds → apply template +class ChatTemplateEndToEndTest : public ::testing::Test { +protected: + const std::string& tokenizerPath = getGenericFullPathForSrcTest("/ovms/src/test/llm_testing/facebook/opt-125m", false); + const std::string& chatTemplatesPath = getGenericFullPathForSrcTest("/ovms/src/test/llm/chat_templates", false); + + std::string savedLogLevel; + + void SetUp() override { + const char* prev = std::getenv("OPENVINO_LOG_LEVEL"); + savedLogLevel = prev ? prev : ""; + setenv("OPENVINO_LOG_LEVEL", "0", 1); + } + + void TearDown() override { + if (savedLogLevel.empty()) { + unsetenv("OPENVINO_LOG_LEVEL"); + } else { + setenv("OPENVINO_LOG_LEVEL", savedLogLevel.c_str(), 1); + } + } + + // --- Inputs (set by each test) --- + std::string chatTemplate; + TemplateApplicator applicator = TemplateApplicator::MINJA; + ov::genai::ChatHistory chatHistory; + + // --- Derived state (populated by run()) --- + ChatTemplateAnalysisResult analysisResult; + ChatTemplateCaps caps; + std::string appliedOutput; + bool applySuccess = false; + + // Load template from file + static std::string loadTemplateFile(const std::string& path) { + std::ifstream file(path); + if (!file.is_open()) { + return ""; + } + return std::string((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + } + + // Run the full pipeline: analyze → probe → workarounds → apply + void run(bool addGenerationPrompt = true) { + ASSERT_FALSE(chatTemplate.empty()) << "chatTemplate must be set before calling run()"; + ASSERT_FALSE(chatHistory.empty()) << "chatHistory must be set before calling run()"; + + // Step 1: Static analysis + analysisResult = ChatTemplateAnalyzer::analyze(chatTemplate); + caps = analysisResult.caps; + + std::cout << "=== Analysis ===" << std::endl; + std::cout << " modelFamily: " << analysisResult.detectedModelFamily << std::endl; + std::cout << " toolParser: " << analysisResult.detectedToolParser.value_or("(none)") << std::endl; + std::cout << " reasoningParser: " << analysisResult.detectedReasoningParser.value_or("(none)") << std::endl; + std::cout << " supportsToolCalls: " << caps.supportsToolCalls << std::endl; + std::cout << " requiresObjectArguments: " << caps.requiresObjectArguments << std::endl; + std::cout << " requiresNonNullContent: " << caps.requiresNonNullContent << std::endl; + + // Step 2: Probe (only if template supports tools) + if (caps.supportsToolCalls) { + ov::genai::Tokenizer probeTokenizer(tokenizerPath); + probeTokenizer.set_chat_template(chatTemplate); + probeChatTemplateCaps(probeTokenizer, caps); + } + + std::cout << "=== After Probe ===" << std::endl; + std::cout << " requiresObjectArguments: " << caps.requiresObjectArguments << std::endl; + + // Step 3: Apply workarounds to the chat history + if (applicator == TemplateApplicator::MINJA) { + input_workarounds::applyToHistory(caps, analysisResult.detectedModelFamily, chatHistory); + } else { + GTEST_SKIP() << "JINJA applicator not implemented yet"; + } + + // Step 4: Apply chat template + if (applicator == TemplateApplicator::MINJA) { + ov::genai::Tokenizer tokenizer(tokenizerPath); + tokenizer.set_chat_template(chatTemplate); + try { + appliedOutput = tokenizer.apply_chat_template(chatHistory, addGenerationPrompt); + applySuccess = true; + } catch (const std::exception& e) { + std::cout << "apply_chat_template FAILED: " << e.what() << std::endl; + applySuccess = false; + } + } + + std::cout << "=== Result ===" << std::endl; + std::cout << appliedOutput << std::endl; + } +}; + +// ============================================================================= +// Example: gpt-oss-20b with tool call containing string arguments +// The probe should detect requiresObjectArguments=true, workaround should convert +// string args to object, and the final template should render them natively. +// ============================================================================= +TEST_F(ChatTemplateEndToEndTest, GptOss_ToolCallWithStringArgs) { + // Load the real gpt-oss chat template + chatTemplate = loadTemplateFile(chatTemplatesPath + "/chat_template_gpt_oss.jinja"); + ASSERT_FALSE(chatTemplate.empty()) << "Failed to load gpt-oss template"; + + // Simulate a request with tool_calls where arguments are a JSON string + // (as sent by most OpenAI-compatible clients) + chatHistory.push_back(ov::genai::JsonContainer::from_json_string( + R"({"role":"user","content":"What's the weather in Paris?"})")); + chatHistory.push_back(ov::genai::JsonContainer::from_json_string( + R"({"role":"assistant","content":"","tool_calls":[{"id":"call_abc123","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Paris\",\"unit\":\"celsius\"}"}}]})")); + + run(false); + + ASSERT_TRUE(applySuccess); + + std::string expectedOutput = R"(<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI. +Knowledge cutoff: 2024-06 +Current date: 2026-06-25 + +Reasoning: medium + +# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>What's the weather in Paris?<|end|><|start|>assistant to=functions.get_weather <|channel|>commentary json<|message|>{"location": "Paris", "unit": "celsius"}<|end|>)"; + EXPECT_EQ(appliedOutput, expectedOutput); +} + +// ============================================================================= +// Example: Qwen3.6-35B-A3B-int4-ov with tool call containing string arguments +// The probe should detect requiresObjectArguments=true, workaround should convert +// string args to object, and the final template should render them natively. +// ============================================================================= +TEST_F(ChatTemplateEndToEndTest, Qwen36_ToolCallWithStringArgs) { + // Load the real qwen chat template + chatTemplate = loadTemplateFile(chatTemplatesPath + "/chat_template_qwen36.jinja"); + ASSERT_FALSE(chatTemplate.empty()) << "Failed to load qwen36 template"; + + // Simulate a request with tool_calls where arguments are a JSON string + // (as sent by most OpenAI-compatible clients) + chatHistory.push_back(ov::genai::JsonContainer::from_json_string( + R"({"role":"user","content":"What's the weather in Paris?"})")); + chatHistory.push_back(ov::genai::JsonContainer::from_json_string( + R"({"role":"assistant","content":"","tool_calls":[{"id":"call_abc123","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Paris\",\"unit\":\"celsius\"}"}}]})")); + + run(false); + + ASSERT_TRUE(applySuccess); + + std::string expectedOutput = R"(<|im_start|>user +What's the weather in Paris?<|im_end|> +<|im_start|>assistant + + + + + + + +Paris + + +celsius + + +<|im_end|> +)"; + EXPECT_EQ(appliedOutput, expectedOutput); +} + +// ============================================================================= +// Example: Gemma4 with tool call containing string arguments +// The probe should detect requiresObjectArguments=true via the needle:<| pattern, +// workaround should convert string args to object, and template should render +// them in Gemma's native key:<|"|>value<|"|> format. +// ============================================================================= +TEST_F(ChatTemplateEndToEndTest, Gemma4_ToolCallWithStringArgs) { + chatTemplate = loadTemplateFile(chatTemplatesPath + "/chat_template_gemma.jinja"); + ASSERT_FALSE(chatTemplate.empty()) << "Failed to load gemma template"; + + chatHistory.push_back(ov::genai::JsonContainer::from_json_string( + R"({"role":"user","content":"What's the weather in Paris?"})")); + chatHistory.push_back(ov::genai::JsonContainer::from_json_string( + R"({"role":"assistant","content":"","tool_calls":[{"id":"call_abc123","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Paris\",\"unit\":\"celsius\"}"}}]})")); + + run(false); + + ASSERT_TRUE(applySuccess); + + // FIXME: Why is here? because of facebook-opt125? + std::string expectedOutput = R"(<|turn>user +What's the weather in Paris? +<|turn>model +<|tool_call>call:get_weather{location:<|"|>Paris<|"|>,unit:<|"|>celsius<|"|>}<|tool_response>)"; + EXPECT_EQ(appliedOutput, expectedOutput); +} diff --git a/src/test/llm/chat_templates/chat_template_gemma.jinja b/src/test/llm/chat_templates/chat_template_gemma.jinja new file mode 100644 index 0000000000..f62ca843a4 --- /dev/null +++ b/src/test/llm/chat_templates/chat_template_gemma.jinja @@ -0,0 +1,354 @@ +{%- macro format_parameters(properties, required, filter_keys=false) -%} + {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in properties | dictsort -%} + {%- set add_comma = false -%} + {%- if not filter_keys or key not in standard_keys -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {{ key }}:{ + {%- if value['description'] -%} + description:<|"|>{{ value['description'] }}<|"|> + {%- set add_comma = true -%} + {%- endif -%} + {%- if value['type'] | upper == 'STRING' -%} + {%- if value['enum'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + enum:{{ format_argument(value['enum']) }} + {%- endif -%} + {%- elif value['type'] | upper == 'ARRAY' -%} + {%- if value['items'] is mapping and value['items'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + items:{ + {%- set ns_items = namespace(found_first=false) -%} + {%- for item_key, item_value in value['items'] | dictsort -%} + {%- if item_value is not none -%} + {%- if ns_items.found_first %},{% endif -%} + {%- set ns_items.found_first = true -%} + {%- if item_key == 'properties' -%} + properties:{ + {%- if item_value is mapping -%} + {{- format_parameters(item_value, value['items']['required'] | default([])) -}} + {%- endif -%} + } + {%- elif item_key == 'required' -%} + required:[ + {%- for req_item in item_value -%} + <|"|>{{- req_item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- elif item_key == 'type' -%} + {%- if item_value is string -%} + type:{{ format_argument(item_value | upper) }} + {%- else -%} + type:{{ format_argument(item_value | map('upper') | list) }} + {%- endif -%} + {%- else -%} + {{ item_key }}:{{ format_argument(item_value) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + } + {%- endif -%} + {%- endif -%} + {%- if value['nullable'] %} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + nullable:true + {%- endif -%} + {%- if value['type'] | upper == 'OBJECT' -%} + {%- if value['properties'] is defined and value['properties'] is mapping -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + properties:{ + {{- format_parameters(value['properties'], value['required'] | default([])) -}} + } + {%- elif value is mapping -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + properties:{ + {{- format_parameters(value, value['required'] | default([]), filter_keys=true) -}} + } + {%- endif -%} + {%- if value['required'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + required:[ + {%- for item in value['required'] | default([]) -%} + <|"|>{{- item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- endif -%} + {%- endif -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + type:<|"|>{{ value['type'] | upper }}<|"|>} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} +{%- macro format_function_declaration(tool_data) -%} + declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|> + {%- set params = tool_data['function']['parameters'] -%} + {%- if params -%} + ,parameters:{ + {%- if params['properties'] -%} + properties:{ {{- format_parameters(params['properties'], params['required']) -}} }, + {%- endif -%} + {%- if params['required'] -%} + required:[ + {%- for item in params['required'] -%} + <|"|>{{- item -}}<|"|> + {{- ',' if not loop.last -}} + {%- endfor -%} + ], + {%- endif -%} + {%- if params['type'] -%} + type:<|"|>{{- params['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + {%- if 'response' in tool_data['function'] -%} + {%- set response_declaration = tool_data['function']['response'] -%} + ,response:{ + {%- if response_declaration['description'] -%} + description:<|"|>{{- response_declaration['description'] -}}<|"|>, + {%- endif -%} + {%- if response_declaration['type'] | upper == 'OBJECT' -%} + type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + } +{%- endmacro -%} +{%- macro format_argument(argument, escape_keys=True) -%} + {%- if argument is string -%} + {{- '<|"|>' + argument + '<|"|>' -}} + {%- elif argument is boolean -%} + {{- 'true' if argument else 'false' -}} + {%- elif argument is mapping -%} + {{- '{' -}} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in argument | dictsort -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {%- if escape_keys -%} + {{- '<|"|>' + key + '<|"|>' -}} + {%- else -%} + {{- key -}} + {%- endif -%} + :{{- format_argument(value, escape_keys=escape_keys) -}} + {%- endfor -%} + {{- '}' -}} + {%- elif argument is sequence -%} + {{- '[' -}} + {%- for item in argument -%} + {{- format_argument(item, escape_keys=escape_keys) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- ']' -}} + {%- else -%} + {{- argument -}} + {%- endif -%} +{%- endmacro -%} +{%- macro strip_thinking(text) -%} + {%- set ns = namespace(result='') -%} + {%- for part in text.split('') -%} + {%- if '<|channel>' in part -%} + {%- set ns.result = ns.result + part.split('<|channel>')[0] -%} + {%- else -%} + {%- set ns.result = ns.result + part -%} + {%- endif -%} + {%- endfor -%} + {{- ns.result | trim -}} +{%- endmacro -%} + +{%- macro format_tool_response_block(tool_name, response) -%} + {{- '<|tool_response>' -}} + {%- if response is mapping -%} + {{- 'response:' + tool_name + '{' -}} + {%- for key, value in response | dictsort -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- '}' -}} + {%- else -%} + {{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}} + {%- endif -%} + {{- '' -}} +{%- endmacro -%} + +{%- set ns = namespace(prev_message_type=None) -%} +{%- set loop_messages = messages -%} +{{- bos_token -}} +{#- Handle System/Tool Definitions Block -#} +{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%} + {{- '<|turn>system\n' -}} + {#- Inject Thinking token at the very top of the FIRST system turn -#} + {%- if enable_thinking is defined and enable_thinking -%} + {{- '<|think|>\n' -}} + {%- set ns.prev_message_type = 'think' -%} + {%- endif -%} + {%- if messages[0]['role'] in ['system', 'developer'] -%} + {%- if messages[0]['content'] is string -%} + {{- messages[0]['content'] | trim -}} + {%- elif messages[0]['content'] is sequence -%} + {%- for item in messages[0]['content'] -%} + {{- item['text'] | trim + ' '-}} + {%- endfor -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} + {%- endif -%} + {%- if tools -%} + {%- for tool in tools %} + {{- '<|tool>' -}} + {{- format_function_declaration(tool) | trim -}} + {{- '' -}} + {%- endfor %} + {%- set ns.prev_message_type = 'tool' -%} + {%- endif -%} + {{- '\n' -}} +{%- endif %} + +{#- Pre-scan: find last user message index for reasoning guard -#} +{%- set ns_turn = namespace(last_user_idx=-1) -%} +{%- for i in range(loop_messages | length) -%} + {%- if loop_messages[i]['role'] == 'user' -%} + {%- set ns_turn.last_user_idx = i -%} + {%- endif -%} +{%- endfor -%} + +{#- Loop through messages -#} +{%- for message in loop_messages -%} + {%- if message['role'] != 'tool' -%} + {%- set ns.prev_message_type = None -%} + {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%} + {#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#} + {%- set prev_nt = namespace(role=None, found=false) -%} + {%- if loop.index0 > 0 -%} + {%- for j in range(loop.index0 - 1, -1, -1) -%} + {%- if not prev_nt.found -%} + {%- if loop_messages[j]['role'] != 'tool' -%} + {%- set prev_nt.role = loop_messages[j]['role'] -%} + {%- set prev_nt.found = true -%} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%} + {%- if not continue_same_model_turn -%} + {{- '<|turn>' + role + '\n' }} + {%- endif -%} + + {#- Render reasoning/reasoning_content as thinking channel -#} + {%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%} + {%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%} + {{- '<|channel>thought\n' + thinking_text + '\n' -}} + {%- endif -%} + + {%- if message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {%- set function = tool_call['function'] -%} + {{- '<|tool_call>call:' + function['name'] + '{' -}} + {%- if function['arguments'] is mapping -%} + {%- set ns_args = namespace(found_first=false) -%} + {%- for key, value in function['arguments'] | dictsort -%} + {%- if ns_args.found_first %},{% endif -%} + {%- set ns_args.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {%- elif function['arguments'] is string -%} + {{- function['arguments'] -}} + {%- endif -%} + {{- '}' -}} + {%- endfor -%} + {%- set ns.prev_message_type = 'tool_call' -%} + {%- endif -%} + + {%- set ns_tr_out = namespace(flag=false) -%} + {%- if message.get('tool_responses') -%} + {#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#} + {%- for tool_response in message['tool_responses'] -%} + {{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}} + {%- set ns_tr_out.flag = true -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endfor -%} + {%- elif message.get('tool_calls') -%} + {#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#} + {%- set ns_tool_scan = namespace(stopped=false) -%} + {%- for k in range(loop.index0 + 1, loop_messages | length) -%} + {%- if ns_tool_scan.stopped -%} + {%- elif loop_messages[k]['role'] != 'tool' -%} + {%- set ns_tool_scan.stopped = true -%} + {%- else -%} + {%- set follow = loop_messages[k] -%} + {#- Resolve tool_call_id to function name -#} + {%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%} + {%- for tc in message['tool_calls'] -%} + {%- if tc.get('id') == follow.get('tool_call_id') -%} + {%- set ns_tname.name = tc['function']['name'] -%} + {%- endif -%} + {%- endfor -%} + {#- Handle content as string or content-parts array -#} + {%- set tool_body = follow.get('content') -%} + {%- if tool_body is string -%} + {{- format_tool_response_block(ns_tname.name, tool_body) -}} + {%- elif tool_body is sequence and tool_body is not string -%} + {%- set ns_txt = namespace(s='') -%} + {%- for part in tool_body -%} + {%- if part.get('type') == 'text' -%} + {%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%} + {%- endif -%} + {%- endfor -%} + {{- format_tool_response_block(ns_tname.name, ns_txt.s) -}} + {%- else -%} + {{- format_tool_response_block(ns_tname.name, tool_body) -}} + {%- endif -%} + {%- set ns_tr_out.flag = true -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + + {%- set captured_content -%} + {%- if message['content'] is string -%} + {%- if role == 'model' -%} + {{- strip_thinking(message['content']) -}} + {%- else -%} + {{- message['content'] | trim -}} + {%- endif -%} + {%- elif message['content'] is sequence -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'text' -%} + {%- if role == 'model' -%} + {{- strip_thinking(item['text']) -}} + {%- else -%} + {{- item['text'] | trim -}} + {%- endif -%} + {%- elif item['type'] == 'image' -%} + {{- '<|image|>' -}} + {%- set ns.prev_message_type = 'image' -%} + {%- elif item['type'] == 'audio' -%} + {{- '<|audio|>' -}} + {%- set ns.prev_message_type = 'audio' -%} + {%- elif item['type'] == 'video' -%} + {{- '<|video|>' -}} + {%- set ns.prev_message_type = 'video' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- endset -%} + + {{- captured_content -}} + {%- set has_content = captured_content | trim | length > 0 -%} + + {%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%} + {{- '<|tool_response>' -}} + {%- elif not (ns_tr_out.flag and not has_content) -%} + {{- '\n' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%} + {{- '<|turn>model\n' -}} + {%- if not enable_thinking | default(false) -%} + {{- '<|channel>thought\n' -}} + {%- endif -%} + {%- endif -%} +{%- endif -%} \ No newline at end of file diff --git a/src/test/llm/chat_templates/chat_template_gpt_oss.jinja b/src/test/llm/chat_templates/chat_template_gpt_oss.jinja new file mode 100644 index 0000000000..f8a00c230e --- /dev/null +++ b/src/test/llm/chat_templates/chat_template_gpt_oss.jinja @@ -0,0 +1,367 @@ +{#- + Modifications to original chat template: + * Allowing reasoning_effort=none; this automatically adds empty reasoning channel at the end. It should force the model to follow with regular response/tool call immediately. + IMPORTANT: When none used, reasoning_effort is rendered as low in the reasoning definition slot (as it is the lowest possible reasoning during model training). + * Removed exception when chat history contains both: content and reasoning_content. + BFCL benchmark requests contain both: reasoning & regular content in SINGLE history turn. + Instead of exception, the regular content is rendered and reasoning_content omitted (assuming regular content might be more insightful) + * Replaced thinking with reasoning_content. For some reason OpenAI used thinking field to render reasoning from history. + Replaced with reasoning_content which is present in history when benchmarking with BFCL. + * Added special handling for empty tool_calls field when rendering chat history. + In some cases gpt-oss generates reasoning/content but no tool calls. + BFCL sends empty array and chat template accessed index 0 assuming there always is some tool call. New chat template ignores empty arrays now. + !* Removed "|tojson" from tool argument rendering. This introduced string escaping drastically influenced following generations. OpenAI Harmony format assumes no escaping of arguments. + This was related to both: function call output (result from mcp servers) and function call arguments (input to mcp servers) in chat history. +#} +{#- + In addition to the normal inputs of `messages` and `tools`, this template also accepts the + following kwargs: + - "builtin_tools": A list, can contain "browser" and/or "python". + - "model_identity": A string that optionally describes the model identity. + - "reasoning_effort": A string that describes the reasoning effort, defaults to "medium". +#} + +{#- Tool Definition Rendering ============================================== #} +{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%} + {%- if param_spec.type == "array" -%} + {%- if param_spec['items'] -%} + {%- if param_spec['items']['type'] == "string" -%} + {{- "string[]" }} + {%- elif param_spec['items']['type'] == "number" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "integer" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "boolean" -%} + {{- "boolean[]" }} + {%- else -%} + {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%} + {%- if inner_type == "object | object" or inner_type|length > 50 -%} + {{- "any[]" }} + {%- else -%} + {{- inner_type + "[]" }} + {%- endif -%} + {%- endif -%} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- else -%} + {{- "any[]" }} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%} + {#- Handle array of types like ["object", "object"] from Union[dict, list] #} + {%- if param_spec.type | length > 1 -%} + {{- param_spec.type | join(" | ") }} + {%- else -%} + {{- param_spec.type[0] }} + {%- endif -%} + {%- elif param_spec.oneOf -%} + {#- Handle oneOf schemas - check for complex unions and fallback to any #} + {%- set has_object_variants = false -%} + {%- for variant in param_spec.oneOf -%} + {%- if variant.type == "object" -%} + {%- set has_object_variants = true -%} + {%- endif -%} + {%- endfor -%} + {%- if has_object_variants and param_spec.oneOf|length > 1 -%} + {{- "any" }} + {%- else -%} + {%- for variant in param_spec.oneOf -%} + {{- render_typescript_type(variant, required_params) -}} + {%- if variant.description %} + {{- "// " + variant.description }} + {%- endif -%} + {%- if variant.default is defined %} + {{ "// default: " + variant.default|tojson }} + {%- endif -%} + {%- if not loop.last %} + {{- " | " }} + {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- elif param_spec.type == "string" -%} + {%- if param_spec.enum -%} + {{- '"' + param_spec.enum|join('" | "') + '"' -}} + {%- else -%} + {{- "string" }} + {%- if param_spec.nullable %} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type == "number" -%} + {{- "number" }} + {%- elif param_spec.type == "integer" -%} + {{- "number" }} + {%- elif param_spec.type == "boolean" -%} + {{- "boolean" }} + + {%- elif param_spec.type == "object" -%} + {%- if param_spec.properties -%} + {{- "{\n" }} + {%- for prop_name, prop_spec in param_spec.properties.items() -%} + {{- prop_name -}} + {%- if prop_name not in (param_spec.required or []) -%} + {{- "?" }} + {%- endif -%} + {{- ": " }} + {{ render_typescript_type(prop_spec, param_spec.required or []) }} + {%- if not loop.last -%} + {{-", " }} + {%- endif -%} + {%- endfor -%} + {{- "}" }} + {%- else -%} + {{- "object" }} + {%- endif -%} + {%- else -%} + {{- "any" }} + {%- endif -%} +{%- endmacro -%} + +{%- macro render_tool_namespace(namespace_name, tools) -%} + {{- "## " + namespace_name + "\n\n" }} + {{- "namespace " + namespace_name + " {\n\n" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- "// " + tool.description + "\n" }} + {{- "type "+ tool.name + " = " }} + {%- if tool.parameters and tool.parameters.properties %} + {{- "(_: {\n" }} + {%- for param_name, param_spec in tool.parameters.properties.items() %} + {%- if param_spec.description %} + {{- "// " + param_spec.description + "\n" }} + {%- endif %} + {{- param_name }} + {%- if param_name not in (tool.parameters.required or []) -%} + {{- "?" }} + {%- endif -%} + {{- ": " }} + {{- render_typescript_type(param_spec, tool.parameters.required or []) }} + {%- if param_spec.default is defined -%} + {%- if param_spec.enum %} + {{- ", // default: " + param_spec.default }} + {%- elif param_spec.oneOf %} + {{- "// default: " + param_spec.default }} + {%- else %} + {{- ", // default: " + param_spec.default|tojson }} + {%- endif -%} + {%- endif -%} + {%- if not loop.last %} + {{- ",\n" }} + {%- else %} + {{- ",\n" }} + {%- endif -%} + {%- endfor %} + {{- "}) => any;\n\n" }} + {%- else -%} + {{- "() => any;\n\n" }} + {%- endif -%} + {%- endfor %} + {{- "} // namespace " + namespace_name }} +{%- endmacro -%} + +{%- macro render_builtin_tools(browser_tool, python_tool) -%} + {%- if browser_tool %} + {{- "## browser\n\n" }} + {{- "// Tool for browsing.\n" }} + {{- "// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\n" }} + {{- "// Cite information from the tool using the following format:\n" }} + {{- "// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\n" }} + {{- "// Do not quote more than 10 words directly from the tool output.\n" }} + {{- "// sources=web (default: web)\n" }} + {{- "namespace browser {\n\n" }} + {{- "// Searches for information related to `query` and displays `topn` results.\n" }} + {{- "type search = (_: {\n" }} + {{- "query: string,\n" }} + {{- "topn?: number, // default: 10\n" }} + {{- "source?: string,\n" }} + {{- "}) => any;\n\n" }} + {{- "// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\n" }} + {{- "// Valid link ids are displayed with the formatting: `【{id}†.*】`.\n" }} + {{- "// If `cursor` is not provided, the most recent page is implied.\n" }} + {{- "// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\n" }} + {{- "// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\n" }} + {{- "// Use this function without `id` to scroll to a new location of an opened page.\n" }} + {{- "type open = (_: {\n" }} + {{- "id?: number | string, // default: -1\n" }} + {{- "cursor?: number, // default: -1\n" }} + {{- "loc?: number, // default: -1\n" }} + {{- "num_lines?: number, // default: -1\n" }} + {{- "view_source?: boolean, // default: false\n" }} + {{- "source?: string,\n" }} + {{- "}) => any;\n\n" }} + {{- "// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\n" }} + {{- "type find = (_: {\n" }} + {{- "pattern: string,\n" }} + {{- "cursor?: number, // default: -1\n" }} + {{- "}) => any;\n\n" }} + {{- "} // namespace browser\n\n" }} + {%- endif -%} + + {%- if python_tool %} + {{- "## python\n\n" }} + {{- "Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\n" }} + {{- "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\n\n" }} + {%- endif -%} +{%- endmacro -%} + +{#- System Message Construction ============================================ #} +{%- macro build_system_message() -%} + {%- if model_identity is not defined %} + {%- set model_identity = "You are ChatGPT, a large language model trained by OpenAI." %} + {%- endif %} + {{- model_identity + "\n" }} + {{- "Knowledge cutoff: 2024-06\n" }} + {{- "Current date: " + strftime_now("%Y-%m-%d") + "\n\n" }} + {%- if reasoning_effort is not defined %} + {%- set reasoning_effort = "medium" %} + {%- endif %} + {#- {{- "Reasoning: " + reasoning_effort + "\n\n" }} #} + {%- set display_effort = reasoning_effort %} + {%- if reasoning_effort == "none" %} + {%- set display_effort = "low" %} + {%- endif %} + {{- "Reasoning: " + display_effort + "\n\n" }} + {%- if builtin_tools %} + {{- "# Tools\n\n" }} + {%- set available_builtin_tools = namespace(browser=false, python=false) %} + {%- for tool in builtin_tools %} + {%- if tool == "browser" %} + {%- set available_builtin_tools.browser = true %} + {%- elif tool == "python" %} + {%- set available_builtin_tools.python = true %} + {%- endif %} + {%- endfor %} + {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }} + {%- endif -%} + {{- "# Valid channels: analysis, commentary, final. Channel must be included for every message." }} + {%- if tools -%} + {{- "\nCalls to these tools must go to the commentary channel: 'functions'." }} + {%- endif -%} +{%- endmacro -%} + +{#- Main Template Logic ================================================= #} +{#- Set defaults #} + +{#- Render system message #} +{{- "<|start|>system<|message|>" }} +{{- build_system_message() }} +{{- "<|end|>" }} + +{#- Extract developer message #} +{%- if messages[0].role == "developer" or messages[0].role == "system" %} + {%- set developer_message = messages[0].content %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set developer_message = "" %} + {%- set loop_messages = messages %} +{%- endif %} + +{#- Render developer message #} +{%- if developer_message or tools %} + {{- "<|start|>developer<|message|>" }} + {%- if developer_message %} + {{- "# Instructions\n\n" }} + {{- developer_message }} + {{- "\n\n" }} + {%- endif %} + {%- if tools -%} + {{- "# Tools\n\n" }} + {{- render_tool_namespace("functions", tools) }} + {%- endif -%} + {{- "<|end|>" }} +{%- endif %} + +{#- Render messages #} +{%- set last_tool_call = namespace(name=none) %} +{%- for message in loop_messages -%} + {#- At this point only assistant/user/tool messages should remain #} + {%- if message.role == 'assistant' -%} + {#- Checks to ensure the messages are being passed in the format we expect #} + {%- if "content" in message and message.content %} + {%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %} + {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'reasoning_content' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }} + {%- endif %} + {%- endif %} + {%- if "reasoning_content" in message and message.reasoning_content %} + {%- if "<|channel|>analysis<|message|>" in message.reasoning_content or "<|channel|>final<|message|>" in message.reasoning_content %} + {{- raise_exception("You have passed a message containing <|channel|> tags in the reasoning_content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'reasoning_content' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }} + {%- endif %} + {%- endif %} + {%- if "tool_calls" in message %} + {#- We need very careful handling here - we want to drop the tool call analysis message if the model #} + {#- has output a later <|final|> message, but otherwise we want to retain it. This is the only case #} + {#- when we render CoT/analysis messages in inference. #} + {%- set future_final_message = namespace(found=false) %} + {%- for future_message in loop_messages[loop.index:] %} + {%- if future_message.role == 'assistant' and "tool_calls" not in future_message %} + {%- set future_final_message.found = true %} + {%- endif %} + {%- endfor %} + {%- if message.content and message.reasoning_content and not future_final_message.found %} + {#- Original: {{- raise_exception("Cannot pass both content and reasoning_content in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }} #} + {#- Mod: Exception suppressed, multi-turn BFCL benchmark contains such situations. #} + {#- Prefer rendering content over reasoning when both are available, looks like it contains more information. #} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }} + {%- elif message.content and not future_final_message.found %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }} + {%- elif message.reasoning_content and not future_final_message.found %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.reasoning_content + "<|end|>" }} + {%- endif %} + {#- Mod: this check was not present, causing crashes if tool_calls array was empty #} + {%- if message.tool_calls|length > 0 %} + {#- We assume max 1 tool call per message, and so we infer the tool call name #} + {#- in "tool" messages from the most recent assistant tool call name #} + {%- set tool_call = message.tool_calls[0] %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- "<|start|>assistant to=" }} + {{- "functions." + tool_call.name + " <|channel|>commentary " }} + {{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }} + {{- tool_call.arguments|tojson }} + {#- Original: {{- "<|call|>" }} #} + {#- https://cookbook.openai.com/articles/openai-harmony#handling-tool-calls #} + {#- Found out in OpenAI Harmony docs it should be replaced with <|end|> in history rendering: #} + {{- "<|end|>" }} + {%- set last_tool_call.name = tool_call.name %} + {%- endif %} + + {%- elif loop.last and not add_generation_prompt %} + {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #} + {#- This is a situation that should only occur in training, never in inference. #} + {%- if "thinking" in message %} + {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }} + {%- endif %} + {#- <|return|> indicates the end of generation, but <|end|> does not #} + {#- <|return|> should never be an input to the model, but we include it as the final token #} + {#- when training, so the model learns to emit it. #} + {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|return|>" }} + {%- else %} + {#- CoT is dropped during all previous turns, so we never render it for inference #} + {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }} + {%- set last_tool_call.name = none %} + {%- endif %} + {%- elif message.role == 'tool' -%} + {%- if last_tool_call.name is none %} + {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} + {%- endif %} + {{- "<|start|>functions." + last_tool_call.name }} + {#- Original: {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }} #} + {#- Actual version that works, does not escape and allows non-json: #} + {{- " to=assistant<|channel|>commentary<|message|>" + message.content + "<|end|>" -}} + {%- elif message.role == 'user' -%} + {{- "<|start|>user<|message|>" + message.content + "<|end|>" }} + {%- endif -%} +{%- endfor -%} + +{#- Generation prompt #} +{%- if add_generation_prompt -%} + {%- if reasoning_effort == "none" -%} + {{- "<|start|>assistant<|channel|>analysis<|message|><|end|><|start|>assistant" }} + {%- else -%} + {{- "<|start|>assistant" }} + {%- endif -%} +{%- endif -%} \ No newline at end of file diff --git a/src/test/llm/chat_templates/chat_template_qwen36.jinja b/src/test/llm/chat_templates/chat_template_qwen36.jinja new file mode 100644 index 0000000000..dc0d1b21a1 --- /dev/null +++ b/src/test/llm/chat_templates/chat_template_qwen36.jinja @@ -0,0 +1,176 @@ +{#- + Modifications to original chat template: + * Expanded tool_call args serialization to handle booleans/None correctly — + Jinja's tojson outputs JSON (true/false/null) but the model expects + Python literals (True/False/None), so we patch them explicitly. +#}{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + {%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if (preserve_thinking is defined and preserve_thinking is true) or (loop.index0 > ns.last_query_index) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {#- BEGINNING OF PATCH #} + {#- {%- set args_value = args_value | string if args_value is string else args_value | tojson | safe %} #} + {#- Here we ensure boolean values/None end up capital letter #} + {%- if args_value is string %} + {# leave as-is #} + {%- elif args_value is none %} + {%- set args_value = 'None' %} + {%- elif args_value is mapping or (args_value is iterable and args_value is not string) %} + {%- set args_value = args_value | tojson | safe %} + {%- else %} + {# scalar non-string: bool or number — tojson then patch booleans to Python style #} + {%- set args_value = args_value | tojson | safe %} + {%- if args_value == 'true' %} + {%- set args_value = 'True' %} + {%- elif args_value == 'false' %} + {%- set args_value = 'False' %} + {%- endif %} + {%- endif %} + {#- END OF PATCH #} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/src/test/llm/input_workarounds_test.cpp b/src/test/llm/input_workarounds_test.cpp new file mode 100644 index 0000000000..5cb778cbb6 --- /dev/null +++ b/src/test/llm/input_workarounds_test.cpp @@ -0,0 +1,276 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include + +#include +#include +#include +#include + +#include "../../llm/input_workarounds.hpp" + +using namespace ovms; + +class InputWorkaroundsTest : public ::testing::Test { +protected: + rapidjson::Document parseJson(const std::string& json) { + rapidjson::Document doc; + doc.Parse(json.c_str()); + EXPECT_FALSE(doc.HasParseError()) << "Failed to parse test JSON"; + return doc; + } + + std::string serializeJson(const rapidjson::Document& doc) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc.Accept(writer); + return buffer.GetString(); + } +}; + +// --- funcArgsToObjectJson --- + +TEST_F(InputWorkaroundsTest, funcArgsToObjectConvertsStringArgs) { + auto doc = parseJson(R"({ + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": null, "tool_calls": [ + {"id": "call_1", "type": "function", "function": { + "name": "get_weather", + "arguments": "{\"city\": \"London\", \"units\": \"celsius\"}" + }} + ]} + ] + })"); + + input_workarounds::funcArgsToObjectJson(doc); + + auto& args = doc["messages"][1]["tool_calls"][0]["function"]["arguments"]; + ASSERT_TRUE(args.IsObject()); + ASSERT_TRUE(args.HasMember("city")); + EXPECT_STREQ(args["city"].GetString(), "London"); + ASSERT_TRUE(args.HasMember("units")); + EXPECT_STREQ(args["units"].GetString(), "celsius"); +} + +TEST_F(InputWorkaroundsTest, funcArgsToObjectHandlesMultipleToolCalls) { + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "content": null, "tool_calls": [ + {"id": "call_1", "function": {"name": "fn1", "arguments": "{\"a\": 1}"}}, + {"id": "call_2", "function": {"name": "fn2", "arguments": "{\"b\": true}"}} + ]} + ] + })"); + + input_workarounds::funcArgsToObjectJson(doc); + + auto& args1 = doc["messages"][0]["tool_calls"][0]["function"]["arguments"]; + ASSERT_TRUE(args1.IsObject()); + EXPECT_EQ(args1["a"].GetInt(), 1); + + auto& args2 = doc["messages"][0]["tool_calls"][1]["function"]["arguments"]; + ASSERT_TRUE(args2.IsObject()); + EXPECT_TRUE(args2["b"].GetBool()); +} + +TEST_F(InputWorkaroundsTest, funcArgsToObjectSkipsAlreadyObjectArgs) { + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "tool_calls": [ + {"function": {"name": "fn", "arguments": {"key": "value"}}} + ]} + ] + })"); + + input_workarounds::funcArgsToObjectJson(doc); + + auto& args = doc["messages"][0]["tool_calls"][0]["function"]["arguments"]; + ASSERT_TRUE(args.IsObject()); + EXPECT_STREQ(args["key"].GetString(), "value"); +} + +TEST_F(InputWorkaroundsTest, funcArgsToObjectSkipsInvalidJsonString) { + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "tool_calls": [ + {"function": {"name": "fn", "arguments": "not valid json {"}} + ]} + ] + })"); + + input_workarounds::funcArgsToObjectJson(doc); + + auto& args = doc["messages"][0]["tool_calls"][0]["function"]["arguments"]; + EXPECT_TRUE(args.IsString()); +} + +TEST_F(InputWorkaroundsTest, funcArgsToObjectNoopWithoutMessages) { + auto doc = parseJson(R"({"model": "test"})"); + input_workarounds::funcArgsToObjectJson(doc); + // Should not crash + EXPECT_TRUE(doc.HasMember("model")); +} + +TEST_F(InputWorkaroundsTest, funcArgsToObjectNoopWithoutToolCalls) { + auto doc = parseJson(R"({ + "messages": [ + {"role": "user", "content": "hello"} + ] + })"); + input_workarounds::funcArgsToObjectJson(doc); + EXPECT_STREQ(doc["messages"][0]["content"].GetString(), "hello"); +} + +// --- ensureNonNullContentJson --- + +TEST_F(InputWorkaroundsTest, ensureNonNullContentSetsNullToEmpty) { + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "content": null, "tool_calls": [ + {"function": {"name": "fn"}} + ]} + ] + })"); + + input_workarounds::ensureNonNullContentJson(doc); + + ASSERT_TRUE(doc["messages"][0]["content"].IsString()); + EXPECT_STREQ(doc["messages"][0]["content"].GetString(), ""); +} + +TEST_F(InputWorkaroundsTest, ensureNonNullContentAddsMissingContent) { + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "tool_calls": [ + {"function": {"name": "fn"}} + ]} + ] + })"); + + input_workarounds::ensureNonNullContentJson(doc); + + ASSERT_TRUE(doc["messages"][0].HasMember("content")); + ASSERT_TRUE(doc["messages"][0]["content"].IsString()); + EXPECT_STREQ(doc["messages"][0]["content"].GetString(), ""); +} + +TEST_F(InputWorkaroundsTest, ensureNonNullContentPreservesExistingString) { + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "content": "some text", "tool_calls": [ + {"function": {"name": "fn"}} + ]} + ] + })"); + + input_workarounds::ensureNonNullContentJson(doc); + + ASSERT_TRUE(doc["messages"][0]["content"].IsString()); + EXPECT_STREQ(doc["messages"][0]["content"].GetString(), "some text"); +} + +TEST_F(InputWorkaroundsTest, ensureNonNullContentSkipsMessagesWithoutToolCalls) { + auto doc = parseJson(R"({ + "messages": [ + {"role": "user", "content": null} + ] + })"); + + input_workarounds::ensureNonNullContentJson(doc); + + // User message without tool_calls should not be modified + EXPECT_TRUE(doc["messages"][0]["content"].IsNull()); +} + +// --- applyToJson --- + +TEST_F(InputWorkaroundsTest, applyToJsonAppliesObjectArgsWhenRequired) { + ChatTemplateCaps caps; + caps.requiresObjectArguments = true; + + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "tool_calls": [ + {"function": {"name": "fn", "arguments": "{\"x\": 42}"}} + ]} + ] + })"); + + input_workarounds::applyToJson(caps, "gemma4", doc); + + ASSERT_TRUE(doc["messages"][0]["tool_calls"][0]["function"]["arguments"].IsObject()); + EXPECT_EQ(doc["messages"][0]["tool_calls"][0]["function"]["arguments"]["x"].GetInt(), 42); +} + +TEST_F(InputWorkaroundsTest, applyToJsonAppliesNonNullContentWhenRequired) { + ChatTemplateCaps caps; + caps.requiresNonNullContent = true; + + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "content": null, "tool_calls": [ + {"function": {"name": "fn"}} + ]} + ] + })"); + + input_workarounds::applyToJson(caps, "llama3", doc); + + ASSERT_TRUE(doc["messages"][0]["content"].IsString()); + EXPECT_STREQ(doc["messages"][0]["content"].GetString(), ""); +} + +TEST_F(InputWorkaroundsTest, applyToJsonDoesNothingWhenNoCapsSet) { + ChatTemplateCaps caps; // all defaults (false) + + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "content": null, "tool_calls": [ + {"function": {"name": "fn", "arguments": "{\"x\": 1}"}} + ]} + ] + })"); + + std::string before = serializeJson(doc); + input_workarounds::applyToJson(caps, "", doc); + std::string after = serializeJson(doc); + + EXPECT_EQ(before, after); +} + +TEST_F(InputWorkaroundsTest, applyToJsonAppliesBothWorkarounds) { + ChatTemplateCaps caps; + caps.requiresObjectArguments = true; + caps.requiresNonNullContent = true; + + auto doc = parseJson(R"({ + "messages": [ + {"role": "assistant", "content": null, "tool_calls": [ + {"function": {"name": "fn", "arguments": "{\"key\": \"val\"}"}} + ]} + ] + })"); + + input_workarounds::applyToJson(caps, "test", doc); + + // Arguments should be converted to object + ASSERT_TRUE(doc["messages"][0]["tool_calls"][0]["function"]["arguments"].IsObject()); + // Content should be non-null + ASSERT_TRUE(doc["messages"][0]["content"].IsString()); + EXPECT_STREQ(doc["messages"][0]["content"].GetString(), ""); +}