diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 05c4dd81b7..ef29ea8a62 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -65,7 +65,10 @@ When analyzing a Pull Request, follow this protocol: - **Keep headers self-contained but minimal**: each header must compile on its own, but should not pull in transitive dependencies that callers don't need. - **Prefer opaque types / Pimpl**: for complex implementation details, consider the Pimpl idiom to keep implementation-only types out of the public header entirely. - **Never include a header solely for a typedef or enum**: forward-declare the enum (`enum class Foo;` in C++17) or relocate the typedef to a lightweight `fwd.hpp`-style header. -13. Be mindful when accepting `const T&` in constructors or functions that store the reference: verify that the referenced object's lifetime outlives the usage to avoid dangling references. +13. **No dangling references or temporaries bound to `const T&`**: + - Never use `const T&` parameters with default arguments that construct temporaries (e.g. `const std::string& param = ""`). This binds a reference to a temporary — use a function overload instead, or pass by value. + - When accepting `const T&` in constructors or functions that store the reference, verify that the referenced object's lifetime outlives the usage to avoid dangling references. + - Prefer overloads over default arguments for non-trivial types passed by reference. 
## Build System diff --git a/demos/common/export_models/export_model.py b/demos/common/export_models/export_model.py index cc83582ada..007ba45a80 100644 --- a/demos/common/export_models/export_model.py +++ b/demos/common/export_models/export_model.py @@ -86,6 +86,13 @@ def add_common_arguments(parser): parser_image_generation.add_argument('--max_num_images_per_prompt', type=int, default=0, help='Max allowed number of images client is allowed to request for a given prompt', dest='max_num_images_per_prompt') parser_image_generation.add_argument('--default_num_inference_steps', type=int, default=0, help='Default number of inference steps when not specified by client', dest='default_num_inference_steps') parser_image_generation.add_argument('--max_num_inference_steps', type=int, default=0, help='Max allowed number of inference steps client is allowed to request for a given prompt', dest='max_num_inference_steps') +parser_image_generation.add_argument('--source_loras', default=None, + help='LoRA adapters to apply. Format: alias1=org1/repo1[:alpha],alias2=org2/repo2[@file.safetensors][:alpha],' + 'composite=@alias1:alpha+@alias2:alpha. ' + '@filename specifies which .safetensors file (auto-detected when repo has exactly one). ' + ':alpha sets adapter weight (default 1.0). ' + 'Composite entries (source starts with @) blend multiple adapters. 
Only for image_generation task.', + dest='source_loras') parser_text2speech = subparsers.add_parser('text2speech', help='export model for text2speech endpoint') add_common_arguments(parser_text2speech) @@ -339,6 +346,17 @@ def add_common_arguments(parser): default_num_inference_steps: {{default_num_inference_steps}},{% endif %} {%- if max_num_inference_steps > 0 %} max_num_inference_steps: {{max_num_inference_steps}},{% endif %} + {%- for lora in lora_adapters %} + lora_adapters { alias: "{{lora.alias}}" path: "{{lora.path}}"{% if lora.alpha is not none %} alpha: {{lora.alpha}}{% endif %} mode: DYNAMIC } + {%- endfor %} + {%- for composite in composite_lora_adapters %} + composite_lora_adapters { + alias: "{{composite.alias}}" + {%- for comp in composite.components %} + components { adapter_alias: "{{comp.adapter_alias}}"{% if comp.alpha != 1.0 %} alpha: {{comp.alpha}}{% endif %} } + {%- endfor %} + } + {%- endfor %} } } }""" @@ -616,7 +634,7 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi add_servable_to_config(config_file_path, model_name, os.path.relpath(os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path))) -def export_image_generation_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, num_streams): +def export_image_generation_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, num_streams, source_loras): model_path = "./" target_path = os.path.join(model_repository_path, model_name) model_index_path = os.path.join(target_path, 'model_index.json') @@ -629,6 +647,74 @@ def export_image_generation_model(model_repository_path, source_model, model_nam if os.system(optimum_command): raise ValueError("Failed to export image generation model", source_model) + # Download and resolve LoRA adapters + lora_adapters = [] + composite_lora_adapters = [] + if source_loras: + from huggingface_hub import 
snapshot_download + entries = source_loras.split(',') + for entry in entries: + entry = entry.strip() + if '=' in entry: + alias, source = entry.split('=', 1) + else: + source = entry + alias = entry.split('/')[-1] if '/' in entry else entry + + # Composite LoRA: source starts with @ + if source.startswith('@'): + components = [] + for comp_token in source.split('+'): + comp_token = comp_token.strip().lstrip('@') + if ':' in comp_token: + ref, alpha_str = comp_token.rsplit(':', 1) + alpha = float(alpha_str) + else: + ref = comp_token + alpha = 1.0 + components.append({'adapter_alias': ref, 'alpha': alpha}) + composite_lora_adapters.append({'alias': alias, 'components': components}) + print(f"Composite LoRA: {alias} -> {components}") + continue + + # Parse optional alpha (trailing :float after repo or filename) + alpha = None + repo_and_file = source + # Check for alpha suffix: alias=org/repo:0.8 or alias=org/repo@file.safetensors:0.8 + if ':' in repo_and_file: + last_colon = repo_and_file.rfind(':') + potential_alpha = repo_and_file[last_colon + 1:] + try: + alpha = float(potential_alpha) + repo_and_file = repo_and_file[:last_colon] + except ValueError: + pass # Not an alpha suffix (could be part of URL) + + safetensors_file = '' + if '@' in repo_and_file: + repo, safetensors_file = repo_and_file.rsplit('@', 1) + else: + repo = repo_and_file + lora_dir = os.path.join(target_path, 'loras', repo) + if not os.path.isdir(lora_dir): + print(f"Downloading LoRA adapter: {repo} to {lora_dir}") + snapshot_download(repo_id=repo, local_dir=lora_dir) + else: + print(f"LoRA adapter directory already exists: {lora_dir}") + if not safetensors_file: + st_files = [f for f in os.listdir(lora_dir) if f.endswith('.safetensors')] + if len(st_files) == 0: + raise ValueError(f"No .safetensors files found in LoRA adapter: {repo}") + if len(st_files) > 1: + raise ValueError(f"Multiple .safetensors files in LoRA adapter: {repo}. 
Use @filename to specify.") + safetensors_file = st_files[0] + lora_path = 'loras/' + repo + '/' + safetensors_file + lora_entry = {'alias': alias, 'path': lora_path, 'alpha': alpha} + lora_adapters.append(lora_entry) + print(f"LoRA adapter: {alias} -> {lora_path}" + (f" (alpha={alpha})" if alpha else "")) + task_parameters['lora_adapters'] = lora_adapters + task_parameters['composite_lora_adapters'] = composite_lora_adapters + plugin_config = {} assert num_streams >= 0, "num_streams should be a non-negative integer" if num_streams > 0: @@ -711,4 +797,4 @@ def export_image_generation_model(model_repository_path, source_model, model_nam 'max_num_inference_steps', 'extra_quantization_params' ]} - export_image_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['num_streams']) + export_image_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['num_streams'], args['source_loras']) diff --git a/demos/image_generation/README.md b/demos/image_generation/README.md index 4943524c84..994ae55151 100644 --- a/demos/image_generation/README.md +++ b/demos/image_generation/README.md @@ -397,7 +397,7 @@ A single servable exposes the following endpoints: > **Note:** Inpainting/outpainting requests are processed sequentially — concurrent requests will be queued. -> **Note:** For inpainting/outpainting, dedicated inpainting models (e.g. `stable-diffusion-v1-5/stable-diffusion-inpainting`) only support the `images/edits` endpoint. Check [supported models](https://openvinotoolkit.github.io/openvino.genai/docs/supported-models/#image-generation-models). +> **Note:** Dedicated inpainting models (e.g. `stable-diffusion-v1-5/stable-diffusion-inpainting`) only support the `images/edits` endpoint — they cannot be used for text-to-image generation via `images/generations`. 
General-purpose models (e.g. SDXL) support both endpoints. Check [supported models](https://openvinotoolkit.github.io/openvino.genai/docs/supported-models/#image-generation-models). All requests are processed in unary format, with no streaming capabilities. @@ -528,6 +528,12 @@ Output file (`edit_output.png`): Inpainting replaces a masked region in an image based on the prompt. The `mask` is a black-and-white image where white pixels mark the area to repaint. +Download sample images: +```console +curl -O https://raw.githubusercontent.com/openvinotoolkit/model_server/main/demos/image_generation/cat.png +curl -O https://raw.githubusercontent.com/openvinotoolkit/model_server/main/demos/image_generation/cat_mask.png +``` + ![cat](./cat.png) ![cat_mask](./cat_mask.png) ::::{tab-set} @@ -599,6 +605,12 @@ Outpainting extends an image beyond its original borders. Prepare two images: - **outpaint_input.png** — the original image centered on a larger canvas (e.g. 768×768) with black borders - **outpaint_mask.png** — white where the new content should be generated (the borders), black where the original image is +Download sample images: +```console +curl -O https://raw.githubusercontent.com/openvinotoolkit/model_server/main/demos/image_generation/outpaint_input.png +curl -O https://raw.githubusercontent.com/openvinotoolkit/model_server/main/demos/image_generation/outpaint_mask.png +``` + ![outpaint_input](./outpaint_input.png) ![outpaint_mask](./outpaint_mask.png) ::::{tab-set} @@ -718,6 +730,190 @@ ovms --rest_port 8000 ^ Please follow [OpenVINO notebook](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/notebooks/image-to-image-genai/image-to-image-genai.ipynb) to understand how other parameters affect editing. +## Multi-LoRA Image Generation + +This section demonstrates how to serve multiple LoRA adapters with a single SDXL base model, enabling per-request style selection. 
This replicates the [Multi LoRA Image Generation notebook](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/notebooks/multilora-image-generation/multilora-image-generation.ipynb) but using OVMS for serving. + +### Start Server with Multiple LoRA Adapters + +The following command starts OVMS with Stable Diffusion XL and 5 LoRA adapters for different artistic styles: + +::::{tab-set} +:::{tab-item} Docker (Linux) +:sync: docker +```bash +mkdir -p models + +docker run -d --rm --user $(id -u):$(id -g) -p 8000:8000 -v $(pwd)/models:/models/:rw \ + -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e no_proxy=$no_proxy \ + openvino/model_server:latest \ + --rest_port 8000 \ + --model_repository_path /models/ \ + --task image_generation \ + --source_model stabilityai/stable-diffusion-xl-base-1.0 \ + --source_loras "xray=DoctorDiffusion/doctor-diffusion-s-xray-xl-lora@DD-xray-v1.safetensors,thepoint=alvdansen/the-point@araminta_k_the_point.safetensors,ukiyo=KappaNeuro/ukiyo-e-art@Ukiyo-e Art.safetensors,vector=DoctorDiffusion/doctor-diffusion-s-controllable-vector-art-xl-lora@DD-vector-v2.safetensors,chalk=Norod78/sdxl-chalkboarddrawing-lora@SDXL_ChalkBoardDrawing_LoRA_r8.safetensors" +``` +::: + +:::{tab-item} Bare metal (Windows) +:sync: bare-metal +```bat +mkdir models + +ovms --rest_port 8000 ^ + --model_repository_path ./models/ ^ + --task image_generation ^ + --source_model stabilityai/stable-diffusion-xl-base-1.0 ^ + --source_loras "xray=DoctorDiffusion/doctor-diffusion-s-xray-xl-lora@DD-xray-v1.safetensors,thepoint=alvdansen/the-point@araminta_k_the_point.safetensors,ukiyo=KappaNeuro/ukiyo-e-art@Ukiyo-e Art.safetensors,vector=DoctorDiffusion/doctor-diffusion-s-controllable-vector-art-xl-lora@DD-vector-v2.safetensors,chalk=Norod78/sdxl-chalkboarddrawing-lora@SDXL_ChalkBoardDrawing_LoRA_r8.safetensors" +``` +::: + +:::: + +The registered adapters and their recommended use: + +| Alias | Repository | Style | Recommended Weight | Prompt Template | 
+|-------|-----------|-------|-------------------|-----------------| +| `xray` | DoctorDiffusion/doctor-diffusion-s-xray-xl-lora | X-Ray style | 0.8 | `xray <subject>` | +| `thepoint` | alvdansen/the-point | Artistic illustration | 0.6 | `<subject>` | +| `ukiyo` | KappaNeuro/ukiyo-e-art | Ukiyo-e Japanese art | 0.8 | `an illustration of <subject> in Ukiyo-e Art style` | +| `vector` | DoctorDiffusion/doctor-diffusion-s-controllable-vector-art-xl-lora | Vector art | 0.8 | `vector <subject>` | +| `chalk` | Norod78/sdxl-chalkboarddrawing-lora | Chalkboard drawing | 0.45 | `A colorful chalkboard drawing of <subject>` | + +### Generate Images with Different Styles + +Use the adapter alias as the `model` field to select which adapter to apply per request. The adapter is activated via **model name routing** — when the `model` field matches a registered LoRA alias, that adapter is automatically applied. + +**X-Ray style:** +```bash +curl http://localhost:8000/v3/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "model": "xray", + "prompt": "xray a cute cat in sunglasses", + "num_inference_steps": 20, + "guidance_scale": 0.0, + "size": "1024x1024" + }' | jq -r '.data[0].b64_json' | base64 --decode > xray_cat.png +``` + +**Ukiyo-e Japanese art:** +```bash +curl http://localhost:8000/v3/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "model": "ukiyo", + "prompt": "an illustration of a cute cat in sunglasses in Ukiyo-e Art style", + "num_inference_steps": 20, + "guidance_scale": 0.0, + "size": "1024x1024" + }' | jq -r '.data[0].b64_json' | base64 --decode > ukiyo_cat.png +``` + +**Chalkboard drawing:** +```bash +curl http://localhost:8000/v3/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "model": "chalk", + "prompt": "A colorful chalkboard drawing of a cute cat in sunglasses", + "num_inference_steps": 20, + "guidance_scale": 0.0, + "size": "1024x1024" + }' | jq -r '.data[0].b64_json' | base64 --decode > chalk_cat.png +``` + +Optionally override the adapter 
alpha using `lora_alphas`: +```bash +curl http://localhost:8000/v3/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "model": "xray", + "prompt": "xray a cute cat in sunglasses", + "lora_alphas": {"xray": 0.5}, + "num_inference_steps": 20, + "guidance_scale": 0.0, + "size": "1024x1024" + }' | jq -r '.data[0].b64_json' | base64 --decode > xray_cat_half_weight.png +``` +### Using OpenAI Python Client with LoRA + +```python +from openai import OpenAI +import base64 +from io import BytesIO +from PIL import Image + +client = OpenAI( + base_url="http://localhost:8000/v3", + api_key="unused" +) + +# Define LoRA styles — the adapter alias is used as the model name +styles = { + "xray": {"prompt": "xray {subject}"}, + "thepoint": {"prompt": "{subject}"}, + "ukiyo": {"prompt": "an illustration of {subject} in Ukiyo-e Art style"}, + "vector": {"prompt": "vector {subject}"}, + "chalk": {"prompt": "A colorful chalkboard drawing of {subject}"}, +} + +subject = "a cute cat in sunglasses" + +for style_name, style_config in styles.items(): + prompt = style_config["prompt"].format(subject=subject) + response = client.images.generate( + model=style_name, # adapter alias activates the LoRA + prompt=prompt, + extra_body={ + "num_inference_steps": 20, + "guidance_scale": 0.0, + "size": "1024x1024", + } + ) + image_data = base64.b64decode(response.data[0].b64_json) + image = Image.open(BytesIO(image_data)) + image.save(f'{style_name}_cat.png') + print(f"Saved {style_name}_cat.png") +``` + +### Blending Multiple Adapters + +To blend multiple adapters, define a **composite adapter** at startup using the `@alias:weight` syntax: + +```bash +--source_loras="xray=...,ukiyo=...,blend=@xray:0.5+@ukiyo:0.4" +``` + +Then use the composite alias as the model name: +```python +response = client.images.generate( + model="blend", # activates both xray and ukiyo + prompt="a cute cat in sunglasses", + extra_body={ + "num_inference_steps": 20, + "guidance_scale": 0.0, + "size": 
"1024x1024", + } +) +``` + +You can override individual component weights at request time: +```python +response = client.images.generate( + model="blend", + prompt="a cute cat in sunglasses", + extra_body={ + "lora_alphas": {"xray": 0.8, "ukiyo": 0.2}, + "num_inference_steps": 20, + "guidance_scale": 0.0, + "size": "1024x1024", + } +) +``` + +> **Note:** For more details on LoRA adapter configuration, see the [Image Generation reference documentation](../../docs/image_generation/reference.md#lora-adapters). + ## References - [Image Generation API](../../docs/model_server_rest_api_image_generation.md) - [Image Edit API](../../docs/model_server_rest_api_image_edit.md) diff --git a/docs/image_generation/reference.md b/docs/image_generation/reference.md index ee4f775b00..d7226ddeea 100644 --- a/docs/image_generation/reference.md +++ b/docs/image_generation/reference.md @@ -67,6 +67,23 @@ Static model resolution settings: - `optional uint64 num_images_per_prompt` - used together with max_resolution, to define batch size in static model shape. - `optional float guidance_scale` - used together with max_resolution +LoRA adapter settings: +- `repeated LoraAdapterEntry lora_adapters` - list of LoRA adapters to load. Each entry defines: + - `required string alias` - unique name used for request routing (the `model` field in API requests) + - `required string path` - path to the `.safetensors` file (absolute, or relative to the graph directory) + - `optional float alpha` - adapter weight/strength [default = 1.0] + - `optional LoraLoadMode mode` - how the adapter is loaded [default = DYNAMIC]. Possible values: + - `DYNAMIC` - adapter is applied/removed at inference time (hot-swap between requests). Used on CPU and GPU. + - `STATIC` - adapter is compiled into the model with fixed alpha at load time. No runtime switching is possible. This is the mode used on NPU. + - `FUSE` - adapter is permanently merged into the base model weights. 
Always active, not selectable via routing, and irreversible. +- `repeated CompositeLoraAdapterEntry composite_lora_adapters` - composite adapters that blend multiple individual adapters. Each entry defines: + - `required string alias` - composite name used for request routing + - `repeated CompositeLoraComponent components` - list of component adapters with: + - `required string adapter_alias` - reference to a registered `lora_adapters` alias + - `optional float alpha` - component weight [default = 1.0]. Only effective in DYNAMIC mode. + +> **Note:** When using `--source_loras` CLI parameter, the `lora_adapters` and `composite_lora_adapters` fields in `graph.pbtxt` are generated automatically. The `mode` field is set based on the target device: NPU → `STATIC`, everything else → `DYNAMIC`. Manual editing is only needed for advanced configurations like `FUSE` mode. + ## Models Directory In node configuration we set `models_path` indicating location of the directory with files loaded by LLM engine. It loads following files: @@ -165,6 +182,185 @@ We recommend using [export script](../../demos/common/export_models/README.md) t Check [tested models](https://github.com/openvinotoolkit/openvino.genai/blob/master/tests/python_tests/models/real_models). +## LoRA Adapters + +[LoRA (Low-Rank Adaptation)](https://arxiv.org/abs/2106.09685) adapters allow fine-tuning image generation models without retraining the full model. OVMS supports loading multiple LoRA adapters at startup and dynamically selecting/blending them per request. + +### Registering LoRA Adapters + +LoRA adapters are registered at server startup via the `--source_loras` CLI parameter. The format is a comma-separated list of `alias=source` entries: + +``` +--source_loras=alias1=source1,alias2=source2,... 
+``` + +**Supported source types:** + +| Source Type | Format | Example | +|------------|--------|---------| +| HuggingFace repo | `org/repo` | `pokemon=juliensimon/sd-pokemon-lora` | +| HuggingFace repo with explicit file | `org/repo@filename.safetensors` | `xray=DoctorDiffusion/doctor-diffusion-s-xray-xl-lora@DD-xray-v1.safetensors` | +| Direct URL | `https://...` | `style=https://huggingface.co/user/repo/resolve/main/model.safetensors` | +| Local file path (Linux) | `/path/to/file.safetensors` | `custom=/models/loras/my_style.safetensors` | +| Local file path (Windows) | `C:\path\to\file.safetensors` | `custom=C:\models\loras\my_style.safetensors` | +| Relative local path | `./path/to/file.safetensors` | `custom=./loras/my_style.safetensors` | + +**Source type detection rules:** + +The source type is determined automatically based on the source string: + +1. If the source starts with `https://` or `http://` → **Direct URL** +2. If the source starts with `/` (Unix absolute), `./` or `.\` (relative), or matches `X:\` / `X:/` (Windows drive letter) → **Local file path** +3. Otherwise → **HuggingFace repository** (with optional `@filename` suffix) + +**Default alpha (adapter weight):** + +Each individual adapter can optionally specify a default alpha weight by appending `:alpha` to the source: + +``` +--source_loras="alias=source:alpha" +``` + +The alpha value controls how strongly the adapter influences generation (default: `1.0`). Examples: + +```bash +# Linux - adapter with alpha 0.6 +--source_loras="pokemon=/models/loras/pokemon.safetensors:0.6" + +# Windows - adapter with alpha 0.75 +--source_loras="pokemon=C:\models\loras\pokemon.safetensors:0.75" + +# HuggingFace repo with alpha +--source_loras="pokemon=juliensimon/sd-pokemon-lora:0.8" +``` + +> **Note:** For composite adapters, alpha is specified per-component using the `@ref:alpha` syntax (see [Composite Adapters](#composite-adapters)). The `:alpha` suffix on the source applies only to individual adapters. 
+> +> **Important:** Alpha must be specified at only one level — either on the individual adapter OR on the composite components, not both. If both have non-default values, the server will reject the configuration with an error. + +**Example:** +```bash +ovms --rest_port 8000 \ + --model_repository_path /models/ \ + --task image_generation \ + --source_model stabilityai/stable-diffusion-xl-base-1.0 \ + --source_loras "xray=DoctorDiffusion/doctor-diffusion-s-xray-xl-lora@DD-xray-v1.safetensors,ukiyo=KappaNeuro/ukiyo-e-art@Ukiyo-e Art.safetensors,vector=DoctorDiffusion/doctor-diffusion-s-controllable-vector-art-xl-lora@DD-vector-v2.safetensors" +``` + +> **Important:** LoRA adapters must be compatible with the base model architecture. For example, SDXL adapters can only be used with an SDXL base model. + +### Composite Adapters + +You can define composite adapters that blend multiple adapters with specified weights: + +``` +--source_loras="pokemon=juliensimon/sd-pokemon-lora,anime=user/anime-lora,mix=@pokemon:0.7+@anime:0.5" +``` + +The `mix` adapter is a composite that blends `pokemon` at weight 0.7 and `anime` at weight 0.5. + +### Per-Request LoRA Selection via Model Name Routing + +Adapter selection is driven by the `model` field in the request. When the `model` field matches a registered adapter alias, that adapter is automatically applied: + +```bash +curl http://localhost:8000/v3/images/generations \ + -H "Content-Type: application/json" \ + -d '{"model": "xray", "prompt": "xray a human hand", "num_inference_steps": 20}' +``` + +In this example, `xray` is the alias defined in `--source_loras` (e.g. `xray=DoctorDiffusion/doctor-diffusion-s-xray-xl-lora@DD-xray-v1.safetensors`). The adapter is applied with its default weight. 
+ +When the `model` field matches a **composite** adapter alias, all component adapters are activated with their pre-defined weights: + +```bash +curl http://localhost:8000/v3/images/generations \ + -H "Content-Type: application/json" \ + -d '{"model": "mix", "prompt": "a landscape"}' +``` + +When the `model` field is the **base model name** (not matching any adapter alias), generation proceeds without any LoRA adapter applied (base model only). + +### Overriding Adapter Alphas with `lora_alphas` + +The `lora_alphas` field in the request body allows overriding the default alpha of the active adapter(s). It does **not** independently select which adapters to activate — adapter selection is always based on the `model` field. + +**Override a single adapter weight:** +```json +{ + "model": "xray", + "prompt": "xray a cute cat in sunglasses", + "lora_alphas": {"xray": 0.5}, + "num_inference_steps": 20 +} +``` + +**Override component weights in a composite adapter:** +```json +{ + "model": "mix", + "prompt": "a landscape in mixed style", + "lora_alphas": {"pokemon": 0.3, "anime": 0.8} +} +``` + +### Blending Multiple Adapters + +To blend multiple adapters simultaneously, define a **composite adapter** at startup: + +``` +--source_loras="xray=DoctorDiffusion/doctor-diffusion-s-xray-xl-lora@DD-xray-v1.safetensors,ukiyo=KappaNeuro/ukiyo-e-art@Ukiyo-e Art.safetensors,blend=@xray:0.5+@ukiyo:0.4" +``` + +Then use the composite alias in requests: +```bash +curl http://localhost:8000/v3/images/generations \ + -H "Content-Type: application/json" \ + -d '{"model": "blend", "prompt": "a cat"}' +``` + +You can override individual component alphas at request time via `lora_alphas`: +```json +{ + "model": "blend", + "prompt": "a cat", + "lora_alphas": {"xray": 0.8, "ukiyo": 0.2} +} +``` + +### LoRA Adapter Modes + +The adapter loading mode determines how LoRA weights interact with the base model. The mode is set automatically to `DYNAMIC` unless the target device is NPU, in which case `STATIC` is used. 
This can be adjusted manually in `graph.pbtxt`. + +| Mode | Device | Behavior | +|------|--------|----------| +| `DYNAMIC` | CPU, GPU | Default. Adapters are applied/removed per request. Multiple adapters can be hot-swapped. Base model is accessible without any adapter. | +| `STATIC` | NPU (default) | Adapters are compiled into the model at load time with fixed alpha values. No runtime switching — all adapters are always active. Base model is not independently accessible. | +| `FUSE` | Any | Adapter is permanently merged into base weights. Always active, invisible to routing, and irreversible. Only configurable via manual `graph.pbtxt` editing. | + +**DYNAMIC mode (CPU/GPU):** +- Adapters are registered at compile time but activated/deactivated per request based on the `model` field. +- `lora_alphas` in the request body can override adapter strengths at runtime. +- Sending `"model": ""` disables all adapters (pure base model). + +**STATIC mode (NPU):** +- All adapters are compiled with their configured `alpha` and remain active permanently. +- The `alpha` value determines the fixed adapter strength — it cannot be changed at runtime. +- `lora_alphas` in requests is **rejected** — alphas are baked in at compile time and cannot be overridden per request. +- The base model is **not accessible** (always has adapters applied). +- With a single adapter: only the adapter's alias is a valid `model` name. +- With multiple adapters: composites are **required**. Only composite aliases are valid `model` names. +- Alpha source priority: if alpha is specified only at the individual adapter level, it is used. If alpha is specified only at the composite component level (individual stays at default 1.0), the composite alpha is used for compilation. Specifying alpha at both levels is an error. + +**FUSE mode:** +- The adapter is merged into base weights during model compilation using `MODE_FUSE`. +- It is always active — the base model without the adapter is **not accessible**. 
+- Does not appear in the list of routable adapters and cannot be selected or deselected via the `model` field. +- Typically combined with DYNAMIC adapters: the FUSE adapter permanently enhances the base, while DYNAMIC adapters can be hot-swapped on top. +- Only configurable via manual `graph.pbtxt` editing. + +> **Important:** STATIC mode is automatically applied when targeting NPU via `--source_loras`. On NPU with multiple LoRAs, composite definitions are mandatory to define the routing aliases. The `alpha` specified per adapter in `--source_loras` (e.g., `pokemon=org/repo:0.8`) is the compile-time weight that gets permanently baked into the model. + ## References - [Image Generation API](../model_server_rest_api_image_generation.md) - [Image Edit API](../model_server_rest_api_image_edit.md) diff --git a/docs/model_server_rest_api_image_generation.md b/docs/model_server_rest_api_image_generation.md index 1e25123214..928433511c 100644 --- a/docs/model_server_rest_api_image_generation.md +++ b/docs/model_server_rest_api_image_generation.md @@ -70,6 +70,7 @@ curl http://localhost:8000/v3/images/generations \ | strength | ❌ | ❌ | float (optional) min: 0.0, max: 1.0 | **Only for [image editing](./model_server_rest_api_image_edit.md) endpoints.** Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `image` parameter. | | rng_seed | ✅ | ❌ | integer (optional) | Seed for random generator. | | max_sequence_length | ✅ | ❌ | integer (optional) | This parameters limits max sequence length for T5 encoder for SD3 and FLUX models. 
T5 tokenizer output is padded with pad tokens to 'max_sequence_length' within a pipeline. So, for better performance, you can specify this parameter to lower value to speed-up T5 encoder inference as well as inference of transformer denoising model. For optimal performance it can be set to a number of tokens for `prompt_3` / `negative_prompt_3` for SD3 or `prompt_2` for FLUX. | +| lora_alphas | ✅ | ❌ | object (optional) | A JSON object mapping LoRA adapter aliases to alpha overrides for the active adapter(s). Adapter selection is driven by the `model` field (model name routing). This field only overrides the alphas of adapters already activated by model name routing — it does not independently select adapters. Aliases must match adapters registered via `--source_loras` at server start. Example: `{"pokemon": 0.5}` (when `model` is `"pokemon"` or a composite containing it). See [LoRA Adapters documentation](./image_generation/reference.md#lora-adapters). | ## Response diff --git a/docs/pull_hf_models.md b/docs/pull_hf_models.md index 86860a909e..77ab0dfb60 100644 --- a/docs/pull_hf_models.md +++ b/docs/pull_hf_models.md @@ -94,6 +94,29 @@ In case you want to setup model and start server in one step, follow [instructio > **Note:** When using pull mode you need both read and write access rights to models repository. +## Pulling Image Generation Models with LoRA Adapters + +For image generation tasks, you can additionally specify LoRA adapters to be downloaded alongside the base model using the `--source_loras` parameter: + +```text +ovms --rest_port 8000 \ + --model_repository_path /models/ \ + --task image_generation \ + --source_model stabilityai/stable-diffusion-xl-base-1.0 \ + --source_loras "xray=DoctorDiffusion/doctor-diffusion-s-xray-xl-lora@DD-xray-v1.safetensors,ukiyo=KappaNeuro/ukiyo-e-art@Ukiyo-e Art.safetensors" +``` + +The `--source_loras` format is a comma-separated list of `alias=source[:alpha]` entries. 
Supported source types: +- HuggingFace repository: `alias=org/repo` or `alias=org/repo@filename.safetensors` +- Direct URL: `alias=https://url/to/file.safetensors` +- Local file (Linux): `alias=/path/to/file.safetensors` +- Local file (Windows): `alias=C:\path\to\file.safetensors` +- Relative local file: `alias=./path/to/file.safetensors` + +Each adapter can optionally specify a default alpha weight: `alias=source:0.7` (default: `1.0`). + +For more details, see the [LoRA Adapters documentation](./image_generation/reference.md#lora-adapters). + ## Resuming an interrupted pull Pulling Generative AI models from Hugging Face often involves transferring multi-gigabyte LFS files (e.g. `openvino_model.bin`). To make this robust against network errors and operator interventions, OVMS pull mode persists the in-progress download state on disk and resumes from where it stopped on the next `--pull` invocation. No extra flags are required — simply re-run the same `--pull` command against the same `--model_repository_path` and OVMS will continue any partially downloaded LFS files instead of starting from scratch. @@ -147,4 +170,3 @@ On startup OVMS logs the resolved configuration, e.g.: ``` > **Note:** Resume relies on the remote server honoring HTTP `Range` requests. Hugging Face Hub supports this by default; private mirrors must allow ranged GETs for resume to work. - diff --git a/docs/starting_server.md b/docs/starting_server.md index d5f2cb16f9..44010498d2 100644 --- a/docs/starting_server.md +++ b/docs/starting_server.md @@ -39,6 +39,11 @@ Server will detect the type of requested servable (classic model, generative mod When the model is generative, like copied from Hugging Face or exported using optimum-cli, all the pipeline runtime parameters can be defined with --tasks followed by task specific options. 
+::::{tab-set} +:::{tab-item} With Docker +:sync: docker +**Required:** Docker Engine installed + ```text docker run -d --rm -v ${PWD}/:/model -p 8000:8000 openvino/model_server:latest \ --model_path /model --model_name --rest_port 8000 --log_level DEBUG \ diff --git a/src/BUILD b/src/BUILD index 8f33d63a62..af756165a8 100644 --- a/src/BUILD +++ b/src/BUILD @@ -2567,6 +2567,7 @@ cc_test( ":pull_gguf_hf_model_test", ":listdirectorymodels_test", ":graph_export_test", + ":lora_graph_export_test", ":config_export_test", ":config_export_full_test", ":test_constructor_enabled_model_manager", @@ -2850,6 +2851,7 @@ cc_library( "//src/pull_module:hf_pull_model_module", "//src:ovms_lib", "libovmsstring_utils", + "@cpp_httplib//:cpp_httplib", "@com_google_googletest//:gtest", ], local_defines = COMMON_LOCAL_DEFINES, @@ -2909,6 +2911,23 @@ cc_library( local_defines = COMMON_LOCAL_DEFINES, copts = COPTS_TESTS, ) +cc_library( + name = "lora_graph_export_test", + linkstatic = 1, + alwayslink = True, + srcs = ["test/lora_graph_export_test.cpp"], + linkopts = [], + deps = [ + ":test_light_test_utils", + ":test_test_with_temp_dir", + "//src/graph_export:graph_export", + "//src/graph_export:image_generation_graph_cli_parser", + "//src:libovms_server_settings", + "@com_google_googletest//:gtest", + ], + local_defines = COMMON_LOCAL_DEFINES, + copts = COPTS_TESTS, +) cc_library( name = "config_export_test", linkstatic = 1, diff --git a/src/capi_frontend/server_settings.hpp b/src/capi_frontend/server_settings.hpp index 2c0a364bb9..8e4e4132d1 100644 --- a/src/capi_frontend/server_settings.hpp +++ b/src/capi_frontend/server_settings.hpp @@ -144,6 +144,36 @@ struct RerankGraphSettingsImpl { uint64_t maxAllowedChunks = 10000; }; +enum class LoraSourceType { + HF_REPO, + DIRECT_URL, + LOCAL_FILE +}; + +struct LoraAdapterSettings { + std::string alias; + std::string sourceLora; // HF repo, direct URL, or local file path + std::optional safetensorsFile; // user-specified filename (via 
@filename, extracted from URL/path) + LoraSourceType sourceType = LoraSourceType::HF_REPO; + std::optional alpha; // user-specified adapter weight; std::nullopt = use default (1.0) + std::optional resolvedSafetensorsFile; // auto-resolved by HF API during pull + + // Returns the effective filename: user-specified takes priority over auto-resolved. + const std::optional& effectiveSafetensorsFile() const { + return safetensorsFile.has_value() ? safetensorsFile : resolvedSafetensorsFile; + } +}; + +struct CompositeLoraComponent { + std::string adapterAlias; // references a LoraAdapterSettings alias + float alpha = 1.0f; +}; + +struct CompositeLoraSettings { + std::string alias; + std::vector components; +}; + struct ImageGenerationGraphSettingsImpl { std::string resolution = ""; std::string maxResolution = ""; @@ -153,6 +183,8 @@ struct ImageGenerationGraphSettingsImpl { std::optional maxNumberImagesPerPrompt; std::optional defaultNumInferenceSteps; std::optional maxNumInferenceSteps; + std::vector loraAdapters; + std::vector compositeLoraAdapters; }; struct ExportSettings { @@ -170,6 +202,7 @@ struct HFSettingsImpl { std::string sourceModel = ""; std::optional ggufFilename; std::string downloadPath = ""; + std::optional sourceLoras; bool overwriteModels = false; ModelDownlaodType downloadType = GIT_CLONE_DOWNLOAD; GraphExportType task = UNKNOWN_GRAPH; diff --git a/src/cli_parser.cpp b/src/cli_parser.cpp index fc0545219f..5bc0a6c9df 100644 --- a/src/cli_parser.cpp +++ b/src/cli_parser.cpp @@ -213,6 +213,10 @@ std::variant> CLIParser::parse(int argc, char* "HF source model path", cxxopts::value(), "HF_SOURCE") + ("source_loras", + "LoRA adapters for image generation. 
Format: alias1=org1/repo1,alias2=org2/repo2@file.safetensors,alias3=https://url/file.safetensors,alias4=/local/path/file.safetensors", + cxxopts::value(), + "SOURCE_LORAS") ("gguf_filename", "Name of the GGUF file", cxxopts::value(), @@ -734,6 +738,9 @@ void CLIParser::prepareGraph(ServerSettingsImpl& serverSettings, HFSettingsImpl& // (when model_path is set, user wants to use local model without HF pull) hfSettings.sourceModel = result->operator[]("model_name").as(); } + if (result->count("source_loras")) { + hfSettings.sourceLoras = result->operator[]("source_loras").as(); + } if ((result->count("weight-format") || result->count("extra_quantization_params")) && isOptimumCliDownload(hfSettings.sourceModel, hfSettings.ggufFilename)) { hfSettings.downloadType = OPTIMUM_CLI_DOWNLOAD; } diff --git a/src/graph_export/graph_export.cpp b/src/graph_export/graph_export.cpp index 64d80bc23c..d1cd30b7b8 100644 --- a/src/graph_export/graph_export.cpp +++ b/src/graph_export/graph_export.cpp @@ -509,6 +509,47 @@ node: { max_num_inference_steps: )" << graphSettings.maxNumInferenceSteps.value(); } + bool targetIsNPU = exportSettings.targetDevice.find("NPU") != std::string::npos; + + for (const auto& adapter : graphSettings.loraAdapters) { + std::string loraPath; + if (adapter.sourceType == LoraSourceType::LOCAL_FILE) { + loraPath = adapter.sourceLora; + } else if (!adapter.effectiveSafetensorsFile().has_value()) { + SPDLOG_ERROR("LoRA adapter '{}': safetensors filename not resolved. " + "For HF repos, use @filename syntax (e.g. 
org/repo@weights.safetensors) or run with --pull to auto-resolve.", + adapter.alias); + return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID; + } else if (adapter.sourceType == LoraSourceType::HF_REPO) { + loraPath = "loras/" + adapter.sourceLora + "/" + adapter.effectiveSafetensorsFile().value(); + } else { // cURL direct link + loraPath = "loras/" + adapter.alias + "/" + adapter.effectiveSafetensorsFile().value(); + } + oss << R"( + lora_adapters { alias: ")" << adapter.alias << R"(" path: ")" << loraPath << R"(")"; + if (adapter.alpha.has_value()) { + oss << R"( alpha: )" << adapter.alpha.value(); + } + oss << (targetIsNPU ? R"( mode: STATIC)" : R"( mode: DYNAMIC)"); + oss << R"( })"; + } + + for (const auto& composite : graphSettings.compositeLoraAdapters) { + oss << R"( + composite_lora_adapters { + alias: ")" << composite.alias << R"(" +)"; + for (const auto& component : composite.components) { + oss << R"( components { adapter_alias: ")" << component.adapterAlias << R"(")"; + if (component.alpha != 1.0f) { + oss << R"( alpha: )" << component.alpha; + } + oss << R"( } +)"; + } + oss << R"( })"; + } + oss << R"( } } diff --git a/src/graph_export/image_generation_graph_cli_parser.cpp b/src/graph_export/image_generation_graph_cli_parser.cpp index ed0d1b91ef..74ae672831 100644 --- a/src/graph_export/image_generation_graph_cli_parser.cpp +++ b/src/graph_export/image_generation_graph_cli_parser.cpp @@ -16,20 +16,37 @@ #include "image_generation_graph_cli_parser.hpp" #include +#include +#include #include #include #include +#include #include #include #include #include #include "../capi_frontend/server_settings.hpp" +#include "src/logging.hpp" #include "../ovms_exit_codes.hpp" #include "../status.hpp" +#include "../stringutils.hpp" namespace ovms { +static bool isValidLoraAlias(const std::string& alias) { + if (alias.empty()) { + return false; + } + for (char c : alias) { + if (!std::isalnum(static_cast(c)) && c != '-' && c != '_' && c != '.') { + return false; + } 
+ } + return true; +} + static bool isValidResolution(const std::string& resolution) { static const std::regex pattern(R"(\d+x\d+)"); return std::regex_match(resolution, pattern); @@ -164,6 +181,159 @@ void ImageGenerationGraphCLIParser::prepare(ServerSettingsImpl& serverSettings, } } + // Parse --source_loras + // Supports three source types plus composite aliases: + // alias=org/repo (HF_REPO) + // alias=org/repo@file.safetensors (HF_REPO with explicit file) + // alias=https://url/file.safetensors (DIRECT_URL) + // alias=/path/to/file.safetensors (LOCAL_FILE) + // alias=@ref1:0.7+@ref2:0.5 (COMPOSITE - references other aliases) + if (hfSettings.sourceLoras.has_value() && !hfSettings.sourceLoras.value().empty()) { + auto entries = ovms::tokenize(hfSettings.sourceLoras.value(), ','); + // First pass: collect all real adapters + for (const auto& entry : entries) { + auto eqPos = entry.find('='); + if (eqPos == std::string::npos) { + throw std::invalid_argument("Missing alias in --source_loras entry: '" + entry + "'. Expected format: alias=source"); + } + std::string alias = entry.substr(0, eqPos); + std::string source = entry.substr(eqPos + 1); + if (alias.empty() || source.empty()) { + throw std::invalid_argument("Invalid --source_loras entry: '" + entry + "'. Alias and source must not be empty."); + } + if (!isValidLoraAlias(alias)) { + throw std::invalid_argument("Invalid LoRA alias '" + alias + "' in --source_loras entry: '" + entry + + "'. 
Alias must contain only alphanumeric characters, hyphens, underscores, or dots."); + } + // Skip composite entries in first pass + if (source[0] == '@') { + continue; + } + + LoraAdapterSettings adapter; + adapter.alias = alias; + // Parse optional :alpha suffix + auto lastColon = source.rfind(':'); + if (lastColon != std::string::npos && lastColon > 1) { + std::string alphaStr = source.substr(lastColon + 1); + // Skip protocol colons (https:// or http://) + if (alphaStr.substr(0, 2) != "//") { + auto alpha = ovms::stof(alphaStr); + if (!alpha.has_value()) { + throw std::invalid_argument("Invalid alpha value '" + alphaStr + "' in --source_loras entry: '" + entry + "'"); + } + adapter.alpha = alpha.value(); + source = source.substr(0, lastColon); + } + } + // Detect source type + if (source.substr(0, 8) == "https://" || source.substr(0, 7) == "http://") { + adapter.sourceType = LoraSourceType::DIRECT_URL; + adapter.sourceLora = source; + SPDLOG_DEBUG("LoRA '{}': detected source type DIRECT_URL (source: {})", alias, source); + auto lastSlash = source.rfind('/'); + if (lastSlash == std::string::npos || lastSlash == source.size() - 1) { + throw std::invalid_argument("Cannot extract filename from URL in --source_loras entry: '" + entry + "'"); + } + adapter.safetensorsFile = source.substr(lastSlash + 1); + if (!endsWith(adapter.safetensorsFile.value(), ".safetensors")) { + throw std::invalid_argument("URL must point to a .safetensors file in --source_loras entry: '" + entry + "'"); + } + } else if (ovms::isLocalFilePath(source)) { + adapter.sourceType = LoraSourceType::LOCAL_FILE; + adapter.sourceLora = source; + SPDLOG_DEBUG("LoRA '{}': detected source type LOCAL_FILE (source: {})", alias, source); + if (!endsWith(source, ".safetensors")) { + throw std::invalid_argument("Local path must point to a .safetensors file in --source_loras entry: '" + entry + "'"); + } + if (!std::filesystem::exists(source)) { + throw std::invalid_argument("Local LoRA file does not exist: '" 
+ source + "' in --source_loras entry: '" + entry + "'"); + } + auto lastSlash = source.find_last_of("/\\"); + adapter.safetensorsFile = (lastSlash != std::string::npos) ? source.substr(lastSlash + 1) : source; + } else { + adapter.sourceType = LoraSourceType::HF_REPO; + SPDLOG_DEBUG("LoRA '{}': detected source type HF_REPO (source: {})", alias, source); + auto atPos = source.find('@'); + if (atPos != std::string::npos) { + adapter.sourceLora = source.substr(0, atPos); + adapter.safetensorsFile = source.substr(atPos + 1); + if (adapter.safetensorsFile.value().empty()) { + throw std::invalid_argument("Empty filename after @ in --source_loras entry: '" + entry + "'"); + } + } else { + adapter.sourceLora = source; + } + if (adapter.sourceLora.empty()) { + throw std::invalid_argument("Invalid --source_loras entry: '" + entry + "'. HF repo source must not be empty."); + } + } + imageGenerationGraphSettings.loraAdapters.push_back(std::move(adapter)); + } + + // Collect known adapter aliases for validation + std::set knownAliases; + for (const auto& adapter : imageGenerationGraphSettings.loraAdapters) { + knownAliases.insert(adapter.alias); + } + + // Second pass: parse composite entries (source starts with @) + for (const auto& entry : entries) { + auto eqPos = entry.find('='); + std::string alias = entry.substr(0, eqPos); + std::string source = entry.substr(eqPos + 1); + if (source[0] != '@') { + continue; + } + CompositeLoraSettings composite; + composite.alias = alias; + // Parse @ref1:0.7+@ref2:0.5 + auto componentTokens = ovms::tokenize(source, '+'); + for (const auto& compToken : componentTokens) { + if (compToken.empty() || compToken[0] != '@') { + throw std::invalid_argument("Invalid composite LoRA component '" + compToken + "' in entry: '" + entry + "'. 
Each component must start with @"); + } + CompositeLoraComponent component; + std::string ref = compToken.substr(1); // strip @ + auto colonPos = ref.find(':'); + if (colonPos != std::string::npos) { + component.adapterAlias = ref.substr(0, colonPos); + std::string alphaStr = ref.substr(colonPos + 1); + try { + component.alpha = std::stof(alphaStr); + } catch (...) { + throw std::invalid_argument("Invalid alpha '" + alphaStr + "' in composite LoRA component: '" + compToken + "'"); + } + } else { + component.adapterAlias = ref; + } + if (component.adapterAlias.empty()) { + throw std::invalid_argument("Empty adapter reference in composite LoRA component: '" + compToken + "'"); + } + if (knownAliases.find(component.adapterAlias) == knownAliases.end()) { + throw std::invalid_argument("Composite LoRA references unknown adapter '" + component.adapterAlias + "' in entry: '" + entry + "'"); + } + composite.components.push_back(std::move(component)); + } + if (composite.components.empty()) { + throw std::invalid_argument("Composite LoRA entry has no components: '" + entry + "'"); + } + imageGenerationGraphSettings.compositeLoraAdapters.push_back(std::move(composite)); + } + } + + // NPU + LoRA validation: all adapters are compiled with STATIC mode, no runtime switching. + // Multiple LoRAs on NPU require composites to define the static blend ratios. + bool targetHasNPU = hfSettings.exportSettings.targetDevice.find("NPU") != std::string::npos; + if (targetHasNPU && !imageGenerationGraphSettings.loraAdapters.empty()) { + if (imageGenerationGraphSettings.loraAdapters.size() > 1 && imageGenerationGraphSettings.compositeLoraAdapters.empty()) { + throw std::invalid_argument( + "NPU device with multiple LoRA adapters requires composite definitions to specify " + "blend ratios. All adapters are loaded with STATIC mode and runtime switching is unavailable. 
" + "Add composite entries to --source_loras or use a single adapter."); + } + } + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); } diff --git a/src/http_payload.hpp b/src/http_payload.hpp index b4415a1616..1a8465d8f7 100644 --- a/src/http_payload.hpp +++ b/src/http_payload.hpp @@ -32,6 +32,7 @@ namespace ovms { struct HttpPayload { std::string uri; + std::string modelName; // resolved model name from request (JSON model field, multipart, or URI) std::unordered_map headers; std::string body; // always std::shared_ptr parsedJson; // pre-parsed body = null diff --git a/src/http_rest_api_handler.cpp b/src/http_rest_api_handler.cpp index 69dee2e4a2..31939a3df5 100644 --- a/src/http_rest_api_handler.cpp +++ b/src/http_rest_api_handler.cpp @@ -579,6 +579,7 @@ static Status createV3HttpPayload( request.body = request_body; request.parsedJson = std::move(parsedJson); request.uri = std::string(uri); + request.modelName = modelName; request.client = std::make_shared(serverReaderWriter); request.multipartParser = std::move(multiPartParser); diff --git a/src/image_gen/http_image_gen_calculator.cc b/src/image_gen/http_image_gen_calculator.cc index 4e4381f2a6..586cbd6f18 100644 --- a/src/image_gen/http_image_gen_calculator.cc +++ b/src/image_gen/http_image_gen_calculator.cc @@ -30,6 +30,7 @@ #include "pipelines.hpp" #include "imagegenutils.hpp" +#include #pragma warning(push) #pragma warning(disable : 6001 4324 6385 6386) @@ -45,6 +46,68 @@ using ImageGenerationPipelinesMap = std::unordered_map& loraAdapters, + const std::unordered_map>>& compositeLoraAdapters, + const ImageGenPipelineArgs& args, + ov::AnyMap& requestOptions, + const std::unordered_map& loraAlphasOverride = {}) { + if (loraAdapters.empty()) { + return; + } + // Adapters are registered at compile time with their configured alpha values. 
+ // At generate time we explicitly build the adapter config: + // - If modelName matches a composite alias: activate all component adapters with their default alphas. + // - If modelName matches a single adapter alias: activate that adapter with its configured alpha. + // - Otherwise: disable all adapters (alpha=0) so the base model runs clean. + // lora_alphas from request body can override default alphas. + ov::genai::AdapterConfig adapterConfig; + + auto compositeIt = compositeLoraAdapters.find(modelName); + if (compositeIt != compositeLoraAdapters.end()) { + // Composite adapter — activate multiple adapters + for (const auto& [compAlias, defaultAlpha] : compositeIt->second) { + auto adapterIt = loraAdapters.find(compAlias); + if (adapterIt == loraAdapters.end()) { + SPDLOG_LOGGER_WARN(llm_calculator_logger, "Composite LoRA '{}' references unknown adapter '{}', skipping", modelName, compAlias); + continue; + } + float alpha = defaultAlpha; + auto overrideIt = loraAlphasOverride.find(compAlias); + if (overrideIt != loraAlphasOverride.end()) { + alpha = overrideIt->second; + } + adapterConfig.add(adapterIt->second, alpha); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Composite LoRA '{}': applied adapter '{}' with alpha: {}", modelName, compAlias, alpha); + } + } else { + auto adapterIt = loraAdapters.find(modelName); + if (adapterIt != loraAdapters.end()) { + float alpha = DEFAULT_ALPHA; + auto overrideIt = loraAlphasOverride.find(modelName); + if (overrideIt != loraAlphasOverride.end()) { + alpha = overrideIt->second; + } else { + for (const auto& info : args.loraAdapters) { + if (info.alias == modelName) { + alpha = info.alpha; + break; + } + } + } + adapterConfig.add(adapterIt->second, alpha); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Applied LoRA adapter: {} with alpha: {}", modelName, alpha); + } else { + // Disable all adapters that were registered at compile time + for (const auto& [alias, adapter] : loraAdapters) { + adapterConfig.add(adapter, 0.0f); 
+ } + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No LoRA adapter matched for model: {}, disabling all adapters", modelName); + } + } + requestOptions[ov::genai::adapters.name()] = adapterConfig; +} + static bool progress_bar(size_t step, size_t num_steps, ov::Tensor&) { SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Image Generation Step: {}/{}", step + 1, num_steps); return false; @@ -177,12 +240,43 @@ class ImageGenCalculator : public CalculatorBase { } SET_OR_RETURN(std::string, prompt, getPromptField(*payload.parsedJson)); - SET_OR_RETURN(ov::AnyMap, requestOptions, getImageGenerationRequestOptions(*payload.parsedJson, pipe->args)); + bool hasDynamicAdapters = !pipe->loraAdapters.empty() && !pipe->npuLoraStaticMode; + SET_OR_RETURN(ov::AnyMap, requestOptions, getImageGenerationRequestOptions(*payload.parsedJson, pipe->args, hasDynamicAdapters)); + + // Parse optional lora_alphas from request body + auto loraAlphasOverride = ovms::parseLoraAlphasOverride(*payload.parsedJson); + // Apply LoRA adapter if the requested model name matches an alias. + // Under NPU MODE_STATIC adapters are always active — reject requests + // that don't target a valid LoRA alias since the base model is unavailable. + if (pipe->npuLoraStaticMode) { + if (!pipe->loraAdapters.empty()) { + if (!pipe->compositeLoraAdapters.empty()) { + // Multi-LoRA NPU: only composite aliases are valid targets + if (pipe->compositeLoraAdapters.find(payload.modelName) == pipe->compositeLoraAdapters.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Model '", payload.modelName, "' uses NPU with statically fused LoRA adapters. " + "Send requests to the composite LoRA alias name instead.")); + } + } else { + // Single LoRA NPU: only the individual alias is a valid target + if (pipe->loraAdapters.find(payload.modelName) == pipe->loraAdapters.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Model '", payload.modelName, "' uses NPU with statically fused LoRA. 
" + "Send requests to the LoRA alias name instead.")); + } + } + } + } else { + applyLoraAdapterIfNeeded(payload.modelName, pipe->loraAdapters, pipe->compositeLoraAdapters, pipe->args, requestOptions, loraAlphasOverride); + } if (!pipe->text2ImagePipeline) return absl::FailedPreconditionError("Text-to-image pipeline is not available for this model"); - auto t2i = pipe->text2ImagePipeline->clone(); - auto status = generateTensor(t2i, prompt, requestOptions, images); + absl::Status status; + { + auto t2i = pipe->text2ImagePipeline->clone(); + status = generateTensor(t2i, prompt, requestOptions, images); + } if (!status.ok()) { return status; } @@ -201,7 +295,32 @@ class ImageGenCalculator : public CalculatorBase { return status; } - SET_OR_RETURN(ov::AnyMap, requestOptions, getImageEditRequestOptions(*payload.multipartParser, pipe->args)); + SET_OR_RETURN(ov::AnyMap, requestOptions, getImageEditRequestOptions(*payload.multipartParser, pipe->args, !pipe->loraAdapters.empty() && !pipe->npuLoraStaticMode)); + + // Apply LoRA adapter if the requested model name matches an alias. + // Under NPU MODE_STATIC adapters are always active — reject requests + // that don't target a valid LoRA alias since the base model is unavailable. + if (pipe->npuLoraStaticMode) { + if (!pipe->loraAdapters.empty()) { + if (!pipe->compositeLoraAdapters.empty()) { + // Multi-LoRA NPU: only composite aliases are valid targets + if (pipe->compositeLoraAdapters.find(payload.modelName) == pipe->compositeLoraAdapters.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Model '", payload.modelName, "' uses NPU with statically fused LoRA adapters. " + "Send requests to the composite LoRA alias name instead.")); + } + } else { + // Single LoRA NPU: only the individual alias is a valid target + if (pipe->loraAdapters.find(payload.modelName) == pipe->loraAdapters.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Model '", payload.modelName, "' uses NPU with statically fused LoRA. 
" + "Send requests to the LoRA alias name instead.")); + } + } + } + } else { + applyLoraAdapterIfNeeded(payload.modelName, pipe->loraAdapters, pipe->compositeLoraAdapters, pipe->args, requestOptions); + } SET_OR_RETURN(std::optional, mask, getFileFromPayload(*payload.multipartParser, "mask")); SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "ImageGenCalculator [Node: {}] Mask present: {}", cc->NodeName(), mask.has_value() && !mask.value().empty()); @@ -218,14 +337,16 @@ class ImageGenCalculator : public CalculatorBase { return status; } SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "ImageGenCalculator [Node: {}] Inpainting: mask tensor decoded, acquiring inpainting queue slot", cc->NodeName()); - InpaintingQueueGuard inpaintingGuard(*pipe->inpaintingQueue); + PipelineSlotGuard inpaintingGuard(*pipe->inpaintingQueue); SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "ImageGenCalculator [Node: {}] Inpainting: queue slot acquired, invoking generate()", cc->NodeName()); status = generateTensorInpainting(*pipe->inpaintingPipeline, prompt, imageTensor, maskTensor, requestOptions, images); } else { if (!pipe->image2ImagePipeline) return absl::FailedPreconditionError("Image-to-image pipeline is not available for this model"); - auto i2i = pipe->image2ImagePipeline->clone(); - status = generateTensorImg2Img(i2i, prompt, imageTensor, requestOptions, images); + { + auto i2i = pipe->image2ImagePipeline->clone(); + status = generateTensorImg2Img(i2i, prompt, imageTensor, requestOptions, images); + } } if (!status.ok()) { return status; diff --git a/src/image_gen/image_gen_calculator.proto b/src/image_gen/image_gen_calculator.proto index c69f3ce97e..025259aa98 100644 --- a/src/image_gen/image_gen_calculator.proto +++ b/src/image_gen/image_gen_calculator.proto @@ -43,4 +43,33 @@ message ImageGenCalculatorOptions { optional string resolution = 9; optional int64 num_images_per_prompt = 10; optional float guidance_scale = 11; + + // LoRA adapters + repeated LoraAdapterEntry lora_adapters = 12; + 
+ // Composite LoRA adapters (multi-LoRA presets) + repeated CompositeLoraAdapterEntry composite_lora_adapters = 13; +} + +enum LoraLoadMode { + DYNAMIC = 0; // Default: apply/remove at inference time (hot-swap between requests) + STATIC = 1; // Compile into model with fixed alpha (NPU - no runtime switching) + FUSE = 2; // Permanently merge into base weights (irreversible, always active) +} + +message LoraAdapterEntry { + required string alias = 1; + required string path = 2; + optional float alpha = 3 [default = 1.0]; + optional LoraLoadMode mode = 4 [default = DYNAMIC]; +} + +message CompositeLoraComponent { + required string adapter_alias = 1; + optional float alpha = 2 [default = 1.0]; +} + +message CompositeLoraAdapterEntry { + required string alias = 1; + repeated CompositeLoraComponent components = 2; } diff --git a/src/image_gen/image_gen_node_initializer.cpp b/src/image_gen/image_gen_node_initializer.cpp index 72e680fb72..28047cc80a 100644 --- a/src/image_gen/image_gen_node_initializer.cpp +++ b/src/image_gen/image_gen_node_initializer.cpp @@ -73,6 +73,22 @@ class ImageGenNodeInitializer : public NodeInitializer { return StatusCode::INTERNAL_ERROR; } imageGenPipelinesMap.insert(std::pair>(nodeName, std::move(servable))); + // Register LoRA aliases for routing + const auto& args = std::get(statusOrArgs); + bool hasFusedOrStaticAdapter = false; + for (const auto& adapter : args.loraAdapters) { + sidePackets.loraAliases.push_back(adapter.alias); + if (adapter.mode == LoraLoadMode::FUSE || adapter.mode == LoraLoadMode::STATIC) { + hasFusedOrStaticAdapter = true; + } + } + for (const auto& [compositeAlias, components] : args.compositeLoraAdapters) { + sidePackets.loraAliases.push_back(compositeAlias); + } + // When any adapter is fused/static, the base model is not independently usable + if (hasFusedOrStaticAdapter) { + sidePackets.hideBaseModelInRouting = true; + } return StatusCode::OK; } }; diff --git a/src/image_gen/imagegen_init.cpp 
b/src/image_gen/imagegen_init.cpp index d25bf9dfdd..7cce0e6dfa 100644 --- a/src/image_gen/imagegen_init.cpp +++ b/src/image_gen/imagegen_init.cpp @@ -17,11 +17,10 @@ #include #include +#include #include #include -#include - #include "absl/strings/str_replace.h" #include "absl/strings/ascii.h" @@ -116,6 +115,91 @@ static std::variant> getListOfResolutions(cons return result; } +// Validates LoRA adapter configuration: alpha consistency between individual and composite levels, +// and NPU-specific constraints (composites required for multi-LoRA, all adapters referenced). +// May modify args.loraAdapters[].alpha when composite alpha overrides individual default. +static Status validateLoraAdapterConfig(ImageGenPipelineArgs& args, bool isNPU) { + // Alpha validation: + // Alpha must come from exactly one source per adapter: + // - If only the individual adapter specifies non-default alpha → use it. + // - If only composite components specify non-default alpha → use it (override individual default). + // - If BOTH specify non-default alpha → error (ambiguous). + if (!args.compositeLoraAdapters.empty()) { + // Collect composite-level alphas per adapter alias + std::unordered_map compositeAlphaForAdapter; + for (const auto& [compositeAlias, components] : args.compositeLoraAdapters) { + for (const auto& [compAlias, compAlpha] : components) { + auto it = compositeAlphaForAdapter.find(compAlias); + if (isNPU && it != compositeAlphaForAdapter.end() && it->second != compAlpha) { + SPDLOG_LOGGER_ERROR(modelmanager_logger, + "NPU device: LoRA adapter '{}' is referenced by multiple composites with different alphas " + "({} vs {}). 
In STATIC mode only one compile-time alpha per adapter is possible.", + compAlias, it->second, compAlpha); + return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID; + } + compositeAlphaForAdapter[compAlias] = compAlpha; + } + } + + for (auto& adapter : args.loraAdapters) { + auto compIt = compositeAlphaForAdapter.find(adapter.alias); + if (compIt == compositeAlphaForAdapter.end()) { + continue; + } + bool adapterHasNonDefaultAlpha = adapter.alpha != DEFAULT_ALPHA; + bool compositeHasNonDefaultAlpha = (compIt->second != DEFAULT_ALPHA); + + if (adapterHasNonDefaultAlpha && compositeHasNonDefaultAlpha) { + SPDLOG_LOGGER_ERROR(modelmanager_logger, + "LoRA adapter '{}' has alpha={} and is also referenced in composite_lora_adapters with alpha={}. " + "Cannot specify alpha at both individual and composite level. " + "Set alpha on one level only.", + adapter.alias, adapter.alpha, compIt->second); + return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID; + } + + if (!adapterHasNonDefaultAlpha && compositeHasNonDefaultAlpha) { + // Composite provides the alpha — apply it to the adapter for compile time + adapter.alpha = compIt->second; + SPDLOG_LOGGER_INFO(modelmanager_logger, + "LoRA adapter '{}': using composite component alpha={} for compilation.", + adapter.alias, compIt->second); + } + } + } + + // NPU + LoRA validation: NPU uses MODE_STATIC (fixed alpha at compile time), so runtime + // adapter switching is impossible. Multiple LoRAs require a composite definition. + if (isNPU && !args.loraAdapters.empty()) { + if (args.loraAdapters.size() > 1 && args.compositeLoraAdapters.empty()) { + SPDLOG_LOGGER_ERROR(modelmanager_logger, + "NPU device with multiple LoRA adapters requires composite_lora_adapters definition. 
" + "All adapters are compiled with MODE_STATIC and runtime switching is unavailable."); + return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID; + } + // When composites exist on NPU, every adapter must appear in at least one composite + if (!args.compositeLoraAdapters.empty()) { + std::set referencedAdapters; + for (const auto& [alias, components] : args.compositeLoraAdapters) { + for (const auto& [compAlias, alpha] : components) { + referencedAdapters.insert(compAlias); + } + } + for (const auto& adapter : args.loraAdapters) { + if (referencedAdapters.find(adapter.alias) == referencedAdapters.end()) { + SPDLOG_LOGGER_ERROR(modelmanager_logger, + "NPU device: LoRA adapter '{}' is not referenced by any composite_lora_adapters entry. " + "On NPU all adapters are static and only composite aliases are routable.", + adapter.alias); + return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID; + } + } + } + } + + return StatusCode::OK; +} + std::variant prepareImageGenPipelineArgs(const google::protobuf::Any& calculatorOptions, const std::string& graphPath) { mediapipe::ImageGenCalculatorOptions nodeOptions; if (!calculatorOptions.UnpackTo(&nodeOptions)) { @@ -125,9 +209,9 @@ std::variant prepareImageGenPipelineArgs(const goo auto fsModelsPath = std::filesystem::path(nodeOptions.models_path()); std::string pipelinePath; if (fsModelsPath.is_relative()) { - pipelinePath = (std::filesystem::path(graphPath) / fsModelsPath).string(); + pipelinePath = (std::filesystem::path(graphPath) / fsModelsPath).generic_string(); } else { - pipelinePath = fsModelsPath.string(); + pipelinePath = fsModelsPath.generic_string(); } ImageGenPipelineArgs args; args.modelsPath = pipelinePath; @@ -237,7 +321,9 @@ std::variant prepareImageGenPipelineArgs(const goo args.defaultResolution = std::get>(defaultResOptOrStatus); if (args.defaultResolution.value().first > args.maxResolution.first || args.defaultResolution.value().second > args.maxResolution.second) { - 
SPDLOG_LOGGER_ERROR(modelmanager_logger, "Default resolution exceeds maximum allowed resolution: {} > {}", args.defaultResolution.value(), args.maxResolution); + SPDLOG_LOGGER_ERROR(modelmanager_logger, "Default resolution exceeds maximum allowed resolution: ({}x{}) > ({}x{})", + args.defaultResolution.value().first, args.defaultResolution.value().second, + args.maxResolution.first, args.maxResolution.second); return StatusCode::DEFAULT_EXCEEDS_MAXIMUM_ALLOWED_RESOLUTION; } // default resolution is not among the ones allowed @@ -245,7 +331,12 @@ std::variant prepareImageGenPipelineArgs(const goo auto& resolutions = args.staticReshapeSettings.value().resolution; auto it = std::find(resolutions.begin(), resolutions.end(), args.defaultResolution.value()); if (it == resolutions.end()) { - SPDLOG_LOGGER_ERROR(modelmanager_logger, "Default resolution {} is not among the static resolutions: {}", args.defaultResolution.value(), resolutions); + std::string resStr; + for (const auto& r : resolutions) { + resStr += "(" + std::to_string(r.first) + "x" + std::to_string(r.second) + ") "; + } + SPDLOG_LOGGER_ERROR(modelmanager_logger, "Default resolution ({}x{}) is not among the static resolutions: {}", + args.defaultResolution.value().first, args.defaultResolution.value().second, resStr); return StatusCode::SHAPE_WRONG_FORMAT; } } @@ -260,6 +351,50 @@ std::variant prepareImageGenPipelineArgs(const goo args.maxNumImagesPerPrompt = nodeOptions.max_num_images_per_prompt(); args.defaultNumInferenceSteps = nodeOptions.default_num_inference_steps(); args.maxNumInferenceSteps = nodeOptions.max_num_inference_steps(); + + for (int i = 0; i < nodeOptions.lora_adapters_size(); ++i) { + const auto& loraEntry = nodeOptions.lora_adapters(i); + LoraAdapterInfo info; + info.alias = loraEntry.alias(); + auto fsLoraPath = std::filesystem::path(loraEntry.path()); + if (fsLoraPath.is_relative()) { + info.path = (std::filesystem::path(graphPath) / fsLoraPath).generic_string(); + } else { + info.path 
= fsLoraPath.generic_string(); + } + info.alpha = loraEntry.has_alpha() ? loraEntry.alpha() : DEFAULT_ALPHA; + switch (loraEntry.mode()) { + case ::mediapipe::DYNAMIC: + info.mode = LoraLoadMode::DYNAMIC; + break; + case ::mediapipe::STATIC: + info.mode = LoraLoadMode::STATIC; + break; + case ::mediapipe::FUSE: + info.mode = LoraLoadMode::FUSE; + break; + default: + info.mode = LoraLoadMode::DYNAMIC; + break; + } + args.loraAdapters.push_back(std::move(info)); + } + + for (int i = 0; i < nodeOptions.composite_lora_adapters_size(); ++i) { + const auto& compositeEntry = nodeOptions.composite_lora_adapters(i); + std::vector> components; + for (int j = 0; j < compositeEntry.components_size(); ++j) { + const auto& comp = compositeEntry.components(j); + components.emplace_back(comp.adapter_alias(), comp.alpha()); + } + args.compositeLoraAdapters.emplace(compositeEntry.alias(), std::move(components)); + } + + auto loraValidationStatus = validateLoraAdapterConfig(args, isNPU); + if (!loraValidationStatus.ok()) { + return loraValidationStatus; + } + return std::move(args); } } // namespace ovms diff --git a/src/image_gen/imagegenpipelineargs.hpp b/src/image_gen/imagegenpipelineargs.hpp index 294e8ac826..ff33a81063 100644 --- a/src/image_gen/imagegenpipelineargs.hpp +++ b/src/image_gen/imagegenpipelineargs.hpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -39,6 +40,21 @@ struct StaticReshapeSettingsArgs { guidanceScale(guidance) {} }; +enum class LoraLoadMode { + DYNAMIC = 0, // Apply/remove at inference time (hot-swap between requests) + STATIC = 1, // Compile with fixed alpha (NPU - no runtime switching) + FUSE = 2 // Permanently merge into base weights (always active, not selectable) +}; + +constexpr float DEFAULT_ALPHA = 1.0f; + +struct LoraAdapterInfo { + std::string alias; + std::string path; // absolute path to .safetensors file + float alpha = DEFAULT_ALPHA; + LoraLoadMode mode = LoraLoadMode::DYNAMIC; +}; + struct ImageGenPipelineArgs { 
std::string modelsPath; std::vector device; @@ -51,5 +67,9 @@ struct ImageGenPipelineArgs { uint64_t maxNumInferenceSteps; std::optional staticReshapeSettings; + std::vector loraAdapters; + // Maps a composite alias to its component (adapter alias, weight) pairs. + using CompositeLoraMap = std::unordered_map>>; + CompositeLoraMap compositeLoraAdapters; }; } // namespace ovms diff --git a/src/image_gen/imagegenutils.cpp b/src/image_gen/imagegenutils.cpp index 2235e9fa44..f72c7a1c41 100644 --- a/src/image_gen/imagegenutils.cpp +++ b/src/image_gen/imagegenutils.cpp @@ -317,7 +317,7 @@ absl::Status ensureAcceptableAndDefaultsSetRequestOptions(ov::AnyMap& requestOpt return absl::OkStatus(); } -std::variant getImageGenerationRequestOptions(const rapidjson::Document& parser, const ovms::ImageGenPipelineArgs& args) { +std::variant getImageGenerationRequestOptions(const rapidjson::Document& parser, const ovms::ImageGenPipelineArgs& args, bool hasDynamicAdapters) { // NO -not handled yet // OpenAI parameters // https://platform.openai.com/docs/api-reference/images/create 15/05/2025 @@ -414,12 +414,19 @@ std::variant getImageGenerationRequestOptions(const ra "size", "height", "width", "n", "num_images_per_prompt", "response_format", // allowed, however only b64_json is supported - "num_inference_steps", "rng_seed", "strength", "guidance_scale", "max_sequence_length", "model"}; + "num_inference_steps", "rng_seed", "strength", "guidance_scale", "max_sequence_length", "model", + "lora_alphas"}; // per-request LoRA alpha overrides for (auto it = parser.MemberBegin(); it != parser.MemberEnd(); ++it) { if (acceptedFields.find(it->name.GetString()) == acceptedFields.end()) { return absl::InvalidArgumentError(absl::StrCat("Unhandled parameter: ", it->name.GetString())); } } + // Reject lora_alphas when no dynamic adapters are available (STATIC/FUSE modes) + auto loraAlphasOverride = parseLoraAlphasOverride(parser); + auto loraAlphaStatus = 
validateLoraAlphasAllowed(hasDynamicAdapters, loraAlphasOverride); + if (!loraAlphaStatus.ok()) { + return loraAlphaStatus; + } auto status = ensureAcceptableAndDefaultsSetRequestOptions(requestOptions, args); if (!status.ok()) { return status; @@ -434,7 +441,7 @@ std::variant getImageGenerationRequestOptions(const ra return std::move(requestOptions); } -std::variant getImageEditRequestOptions(const ovms::MultiPartParser& parser, const ovms::ImageGenPipelineArgs& args) { +std::variant getImageEditRequestOptions(const ovms::MultiPartParser& parser, const ovms::ImageGenPipelineArgs& args, bool hasDynamicAdapters) { // NO -not handled yet // OpenAI parameters // https://platform.openai.com/docs/api-reference/images/createEdit 20/05/2025 @@ -532,13 +539,19 @@ std::variant getImageEditRequestOptions(const ovms::Mu "size", "height", "width", "n", "num_images_per_prompt", "response_format", // allowed, however only b64_json is supported - "num_inference_steps", "rng_seed", "strength", "guidance_scale", "max_sequence_length", "model"}; + "num_inference_steps", "rng_seed", "strength", "guidance_scale", "max_sequence_length", "model", + "lora_alphas"}; // per-request LoRA alpha overrides auto fieldNames = parser.getAllFieldNames(); for (const auto& fieldName : fieldNames) { if (acceptedFields.find(fieldName) == acceptedFields.end()) { return absl::InvalidArgumentError(absl::StrCat("Unhandled parameter: ", fieldName)); } } + // Reject lora_alphas when no dynamic adapters are available (STATIC/FUSE modes) + auto loraAlphaStatus = validateLoraAlphasAllowed(hasDynamicAdapters, parser); + if (!loraAlphaStatus.ok()) { + return loraAlphaStatus; + } auto status = ensureAcceptableAndDefaultsSetRequestOptions(requestOptions, args); if (!status.ok()) { return status; @@ -595,4 +608,35 @@ std::unique_ptr generateJSONResponseFromB64Images(const std::vector << "]}" << std::endl; return std::make_unique(jsonStream.str()); } + +std::unordered_map parseLoraAlphasOverride(const 
rapidjson::Document& doc) { + std::unordered_map result; + auto it = doc.FindMember("lora_alphas"); + if (it != doc.MemberEnd() && it->value.IsObject()) { + for (auto member = it->value.MemberBegin(); member != it->value.MemberEnd(); ++member) { + if (member->value.IsNumber()) { + result[member->name.GetString()] = member->value.GetFloat(); + } + } + } + return result; +} + +absl::Status validateLoraAlphasAllowed(bool hasDynamicAdapters, const std::unordered_map& loraAlphasOverride) { + if (!hasDynamicAdapters && !loraAlphasOverride.empty()) { + return absl::InvalidArgumentError( + "lora_alphas is not supported when no dynamic LoRA adapters are available. " + "Alpha values cannot be overridden for STATIC (NPU) or FUSE mode adapters."); + } + return absl::OkStatus(); +} + +absl::Status validateLoraAlphasAllowed(bool hasDynamicAdapters, const ovms::MultiPartParser& parser) { + if (!hasDynamicAdapters && !parser.getFieldByName("lora_alphas").empty()) { + return absl::InvalidArgumentError( + "lora_alphas is not supported when no dynamic LoRA adapters are available. 
" + "Alpha values cannot be overridden for STATIC (NPU) or FUSE mode adapters."); + } + return absl::OkStatus(); +} } // namespace ovms diff --git a/src/image_gen/imagegenutils.hpp b/src/image_gen/imagegenutils.hpp index adb9cf7005..6ffdebfde3 100644 --- a/src/image_gen/imagegenutils.hpp +++ b/src/image_gen/imagegenutils.hpp @@ -16,6 +16,7 @@ //***************************************************************************** #include #include +#include #include #include #include @@ -65,10 +66,17 @@ std::variant> getSizetFromPayload(const ovms std::variant> getFloatFromPayload(const rapidjson::Document& doc, const std::string& keyName); std::variant> getFloatFromPayload(const ovms::MultiPartParser& payload, const std::string& keyName); -std::variant getImageGenerationRequestOptions(const rapidjson::Document& doc, const ImageGenPipelineArgs& args); -std::variant getImageEditRequestOptions(const ovms::MultiPartParser& payload, const ImageGenPipelineArgs& args); +std::variant getImageGenerationRequestOptions(const rapidjson::Document& doc, const ImageGenPipelineArgs& args, bool hasDynamicAdapters = true); +std::variant getImageEditRequestOptions(const ovms::MultiPartParser& payload, const ImageGenPipelineArgs& args, bool hasDynamicAdapters = true); std::unique_ptr generateJSONResponseFromB64Images(const std::vector& base64Images); std::variant> generateJSONResponseFromOvTensor(const ov::Tensor& tensor); + +std::unordered_map parseLoraAlphasOverride(const rapidjson::Document& doc); + +// Returns an error if lora_alphas override is present but no dynamic adapters exist. +// lora_alphas is only valid when adapters use DYNAMIC mode (runtime alpha switching). 
+absl::Status validateLoraAlphasAllowed(bool hasDynamicAdapters, const std::unordered_map& loraAlphasOverride); +absl::Status validateLoraAlphasAllowed(bool hasDynamicAdapters, const ovms::MultiPartParser& parser); } // namespace ovms diff --git a/src/image_gen/pipelines.cpp b/src/image_gen/pipelines.cpp index 65071fef60..c0ee593d85 100644 --- a/src/image_gen/pipelines.cpp +++ b/src/image_gen/pipelines.cpp @@ -15,6 +15,7 @@ //***************************************************************************** #include "pipelines.hpp" +#include #include #include @@ -22,6 +23,7 @@ #include #include "src/logging.hpp" +#include "src/stringutils.hpp" namespace ovms { @@ -30,7 +32,8 @@ namespace ovms { template static void reshapeAndCompile(PipelineT& pipeline, const ImageGenPipelineArgs& args, - const std::vector& device) { + const std::vector& device, + const ov::AnyMap& properties) { if (args.staticReshapeSettings.has_value() && args.staticReshapeSettings.value().resolution.size() == 1) { auto numImagesPerPrompt = args.staticReshapeSettings.value().numImagesPerPrompt.value_or(ov::genai::ImageGenerationConfig().num_images_per_prompt); auto guidanceScale = args.staticReshapeSettings.value().guidanceScale.value_or(ov::genai::ImageGenerationConfig().guidance_scale); @@ -47,10 +50,10 @@ static void reshapeAndCompile(PipelineT& pipeline, if (device.size() == 1) { SPDLOG_DEBUG("Image Generation Pipeline compiling to device: {}", device[0]); - pipeline.compile(device[0], args.pluginConfig); + pipeline.compile(device[0], properties); } else { SPDLOG_DEBUG("Image Generation Pipeline compiling to devices: text_encode={} denoise={} vae={}", device[0], device[1], device[2]); - pipeline.compile(device[0], device[1], device[2], args.pluginConfig); + pipeline.compile(device[0], device[1], device[2], properties); } } @@ -65,6 +68,90 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a SPDLOG_DEBUG("Image Generation Pipelines weights loading from: {}", 
args.modelsPath); + // --- Load LoRA adapters before pipeline compilation --- + // Adapters must be registered at compile time so that the AdapterController + // is initialized and can apply/disable them at inference time. + // FUSE adapters are loaded separately and use MODE_FUSE to permanently merge into weights. + std::vector> fuseAdapters; + for (const auto& loraInfo : args.loraAdapters) { + SPDLOG_INFO("Loading LoRA adapter: {} from: {} (mode: {})", loraInfo.alias, loraInfo.path, + loraInfo.mode == LoraLoadMode::FUSE ? "FUSE" : (loraInfo.mode == LoraLoadMode::STATIC ? "STATIC" : "DYNAMIC")); + try { + auto adapter = ov::genai::Adapter(loraInfo.path); + if (loraInfo.mode == LoraLoadMode::FUSE) { + fuseAdapters.emplace_back(std::move(adapter), loraInfo.alpha); + } else { + loraAdapters.emplace(loraInfo.alias, std::move(adapter)); + } + SPDLOG_INFO("LoRA adapter loaded: {}", loraInfo.alias); + } catch (const std::exception& e) { + throw std::runtime_error("Failed to load LoRA adapter '" + loraInfo.alias + "' from " + loraInfo.path + ": " + e.what()); + } + } + + // Build compile-time adapter properties so the pipeline's AdapterController + // knows about all adapters. At generate time we select which to activate. + ov::AnyMap compileProperties = args.pluginConfig; + + // FUSE adapters: permanently merged into base weights using MODE_FUSE. + // These are always active, not switchable, and invisible to request routing. + if (!fuseAdapters.empty()) { + ov::genai::AdapterConfig fuseConfig; + for (const auto& [adapter, alpha] : fuseAdapters) { + fuseConfig.add(adapter, alpha); + } + fuseConfig.set_mode(ov::genai::AdapterConfig::MODE_FUSE); + compileProperties.insert(ov::genai::adapters(fuseConfig)); + SPDLOG_INFO("Fused {} LoRA adapter(s) into base model weights (MODE_FUSE)", fuseAdapters.size()); + } + + // DYNAMIC/STATIC adapters: registered for runtime switching. 
+ if (!loraAdapters.empty()) { + ov::genai::AdapterConfig adapterConfig; + for (const auto& [alias, adapter] : loraAdapters) { + // Use the configured alpha from args for each adapter + float alpha = DEFAULT_ALPHA; + for (const auto& info : args.loraAdapters) { + if (info.alias == alias) { + alpha = info.alpha; + break; + } + } + adapterConfig.add(adapter, alpha); + } + // NPU requires MODE_STATIC — adapters are compiled with fixed alpha values. + // Runtime switching is not possible; all adapters remain active at their compile-time alpha. + bool hasNPU = std::find(device.begin(), device.end(), "NPU") != device.end(); + if (hasNPU) { + adapterConfig.set_mode(ov::genai::AdapterConfig::MODE_STATIC); + npuLoraStaticMode = true; + SPDLOG_INFO("NPU detected: LoRA adapters compiled with MODE_STATIC (no runtime switching)"); + } else { + // Check if any adapter explicitly requests STATIC mode + bool anyStatic = std::any_of(args.loraAdapters.begin(), args.loraAdapters.end(), + [](const LoraAdapterInfo& info) { return info.mode == LoraLoadMode::STATIC; }); + if (anyStatic) { + adapterConfig.set_mode(ov::genai::AdapterConfig::MODE_STATIC); + npuLoraStaticMode = true; + SPDLOG_INFO("STATIC mode requested: LoRA adapters compiled with MODE_STATIC"); + } + } + // FUSE adapters are permanently merged into base weights at compile time (irreversible). + // DYNAMIC/STATIC adapters are registered separately for runtime switching. + // GenAI handles fuse internally during compile — here we overwrite the adapter config + // property with the DYNAMIC/STATIC config since fuse is already applied to weights. 
+ if (compileProperties.count(ov::genai::adapters.name())) { + SPDLOG_INFO("Both FUSE and DYNAMIC/STATIC adapters present — overwriting adapter config (FUSE already applied to weights)"); + } + compileProperties.insert_or_assign(ov::genai::adapters.name(), ov::genai::adapters(adapterConfig).second); + } + + // Populate composite LoRA map from args + compositeLoraAdapters = args.compositeLoraAdapters; + for (const auto& [alias, components] : compositeLoraAdapters) { + SPDLOG_INFO("Registered composite LoRA adapter: {} with {} components", alias, components.size()); + } + // Pipeline construction strategy: // Preferred chain (weight-sharing, single model load): // INP(disk) → reshape+compile → I2I(INP) → T2I(I2I) @@ -78,7 +165,7 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a // --- Step 1: InpaintingPipeline from disk --- try { inpaintingPipeline = std::make_unique(args.modelsPath); - reshapeAndCompile(*inpaintingPipeline, args, device); + reshapeAndCompile(*inpaintingPipeline, args, device, compileProperties); SPDLOG_DEBUG("InpaintingPipeline created from disk"); } catch (const std::exception& e) { SPDLOG_WARN("Failed to create InpaintingPipeline from disk: {}", e.what()); @@ -97,7 +184,7 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a if (!image2ImagePipeline) { try { image2ImagePipeline = std::make_unique(args.modelsPath); - reshapeAndCompile(*image2ImagePipeline, args, device); + reshapeAndCompile(*image2ImagePipeline, args, device, compileProperties); SPDLOG_DEBUG("Image2ImagePipeline created from disk (fallback)"); } catch (const std::exception& e) { SPDLOG_WARN("Failed to create Image2ImagePipeline from disk: {}", e.what()); @@ -125,7 +212,7 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a if (!text2ImagePipeline) { try { text2ImagePipeline = std::make_unique(args.modelsPath); - reshapeAndCompile(*text2ImagePipeline, args, device); + 
reshapeAndCompile(*text2ImagePipeline, args, device, compileProperties); SPDLOG_DEBUG("Text2ImagePipeline created from disk (fallback)"); } catch (const std::exception& e) { SPDLOG_WARN("Failed to create Text2ImagePipeline from disk: {}", e.what()); @@ -144,9 +231,10 @@ ImageGenerationPipelines::ImageGenerationPipelines(const ImageGenPipelineArgs& a inpaintingQueue = std::make_unique>(1); } - SPDLOG_INFO("Image Generation Pipelines ready — T2I: {} | I2I: {} | INP: {}", + SPDLOG_INFO("Image Generation Pipelines ready — T2I: {} | I2I: {} | INP: {} | LoRAs: {}", text2ImagePipeline ? "OK" : "N/A", image2ImagePipeline ? "OK" : "N/A", - inpaintingPipeline ? "OK" : "N/A"); + inpaintingPipeline ? "OK" : "N/A", + loraAdapters.size()); } } // namespace ovms diff --git a/src/image_gen/pipelines.hpp b/src/image_gen/pipelines.hpp index cda14396a7..165428a645 100644 --- a/src/image_gen/pipelines.hpp +++ b/src/image_gen/pipelines.hpp @@ -17,10 +17,14 @@ #include #include +#include +#include +#include #include #include #include +#include #include "imagegenpipelineargs.hpp" #include "src/queue.hpp" @@ -28,19 +32,19 @@ namespace ovms { // RAII guard that acquires a slot from a Queue(1) on construction -// and returns it on destruction, serializing concurrent inpainting requests. -class InpaintingQueueGuard { +// and returns it on destruction, serializing concurrent pipeline access. +class PipelineSlotGuard { public: - // Blocks until an inpainting slot becomes available. - explicit InpaintingQueueGuard(Queue& queue) : + // Blocks until a pipeline slot becomes available. 
+ explicit PipelineSlotGuard(Queue& queue) : queue_(queue), streamId_(queue_.getIdleStream().get()) {} - ~InpaintingQueueGuard() { + ~PipelineSlotGuard() { queue_.returnStream(streamId_); } - InpaintingQueueGuard(const InpaintingQueueGuard&) = delete; - InpaintingQueueGuard& operator=(const InpaintingQueueGuard&) = delete; + PipelineSlotGuard(const PipelineSlotGuard&) = delete; + PipelineSlotGuard& operator=(const PipelineSlotGuard&) = delete; private: Queue& queue_; @@ -51,8 +55,15 @@ struct ImageGenerationPipelines { std::unique_ptr image2ImagePipeline; std::unique_ptr text2ImagePipeline; std::unique_ptr inpaintingPipeline; + std::unordered_map loraAdapters; // alias -> loaded adapter + // composite alias -> [(component adapter alias, weight)] + std::unordered_map>> compositeLoraAdapters; ImageGenPipelineArgs args; + // When true, LoRA adapters were compiled with MODE_STATIC (NPU). + // Runtime adapter switching is not possible — adapters are always active. + bool npuLoraStaticMode = false; + // Serializes concurrent inpainting requests (InpaintingPipeline lacks clone()). // Queue size = 1: only one inpainting inference runs at a time. 
std::unique_ptr> inpaintingQueue; diff --git a/src/mediapipe_internal/graph_side_packets.hpp b/src/mediapipe_internal/graph_side_packets.hpp index b9cdb147c9..6804974c81 100644 --- a/src/mediapipe_internal/graph_side_packets.hpp +++ b/src/mediapipe_internal/graph_side_packets.hpp @@ -17,6 +17,7 @@ #include #include #include +#include namespace ovms { class PythonNodeResources; @@ -43,6 +44,8 @@ struct GraphSidePackets { RerankServableMap rerankServableMap; SttServableMap sttServableMap; TtsServableMap ttsServableMap; + std::vector loraAliases; + bool hideBaseModelInRouting = false; void clear() { pythonNodeResourcesMap.clear(); genAiServableMap.clear(); @@ -51,6 +54,8 @@ struct GraphSidePackets { rerankServableMap.clear(); sttServableMap.clear(); ttsServableMap.clear(); + loraAliases.clear(); + hideBaseModelInRouting = false; } bool empty() { return (pythonNodeResourcesMap.empty() && diff --git a/src/mediapipe_internal/mediapipefactory.cpp b/src/mediapipe_internal/mediapipefactory.cpp index 21dcdb89da..f47c8c3f03 100644 --- a/src/mediapipe_internal/mediapipefactory.cpp +++ b/src/mediapipe_internal/mediapipefactory.cpp @@ -61,25 +61,43 @@ Status MediapipeFactory::createDefinition(const std::string& pipelineName, SPDLOG_LOGGER_ERROR(modelmanager_logger, "Mediapipe graph definition: {} is already created", pipelineName); return StatusCode::PIPELINE_DEFINITION_ALREADY_EXIST; } - std::shared_ptr graphDefinition = std::make_shared(pipelineName, config, metrics.getMetricRegistry(), &metrics.getMetricConfig(), pythonBackend); + std::shared_ptr graphDefinition = std::make_shared( + pipelineName, config, metrics.getMetricRegistry(), &metrics.getMetricConfig(), pythonBackend); auto stat = graphDefinition->validate(checker); if (stat.getCode() == StatusCode::MEDIAPIPE_GRAPH_NAME_OCCUPIED) { return stat; } std::unique_lock lock(definitionsMtx); definitions.insert({pipelineName, std::move(graphDefinition)}); + // Register LoRA aliases discovered during validation (image gen 
graphs) + auto* def = definitions[pipelineName].get(); + for (const auto& alias : def->getLoraAliases()) { + loraAliases[alias] = pipelineName; + SPDLOG_LOGGER_INFO(modelmanager_logger, "Registered LoRA alias: {} -> {}", alias, pipelineName); + } return stat; } bool MediapipeFactory::definitionExists(const std::string& name) const { std::shared_lock lock(definitionsMtx); - return this->definitions.find(name) != this->definitions.end(); + if (this->definitions.find(name) != this->definitions.end()) { + return true; + } + return loraAliases.find(name) != loraAliases.end(); } MediapipeGraphDefinition* MediapipeFactory::findDefinitionByName(const std::string& name) const { std::shared_lock lock(definitionsMtx); auto it = definitions.find(name); if (it == std::end(definitions)) { + // Check LoRA aliases + auto aliasIt = loraAliases.find(name); + if (aliasIt != loraAliases.end()) { + it = definitions.find(aliasIt->second); + if (it != std::end(definitions)) { + return it->second.get(); + } + } return nullptr; } else { return it->second.get(); @@ -95,13 +113,29 @@ Status MediapipeFactory::reloadDefinition(const std::string& name, return StatusCode::INTERNAL_ERROR; } SPDLOG_LOGGER_INFO(modelmanager_logger, "Reloading mediapipe graph: {}", name); - return mgd->reload(checker, config); + clearLoraAliases(name); + auto status = mgd->reload(checker, config); + if (status.ok()) { + std::unique_lock lock(definitionsMtx); + for (const auto& alias : mgd->getLoraAliases()) { + loraAliases[alias] = name; + SPDLOG_LOGGER_INFO(modelmanager_logger, "Registered LoRA alias: {} -> {}", alias, name); + } + } + return status; } Status MediapipeFactory::create(std::unique_ptr& pipeline, const std::string& name) const { std::shared_lock lock(definitionsMtx); auto it = definitions.find(name); + if (it == definitions.end()) { + // Check LoRA aliases + auto aliasIt = loraAliases.find(name); + if (aliasIt != loraAliases.end()) { + it = definitions.find(aliasIt->second); + } + } if (it == 
definitions.end()) { SPDLOG_LOGGER_DEBUG(dag_executor_logger, "Mediapipe with requested name: {} does not exist", name); return StatusCode::MEDIAPIPE_DEFINITION_NAME_MISSING; @@ -138,12 +172,52 @@ const std::vector MediapipeFactory::getNamesOfAvailableMediapipePip std::vector names; std::shared_lock lock(definitionsMtx); for (auto& [name, definition] : definitions) { - if (definition->getStatus().isAvailable()) { + if (definition->getStatus().isAvailable() && !definition->shouldHideBaseModelInRouting()) { names.push_back(definition->getName()); } } + // Add LoRA aliases that point to available definitions + for (const auto& [alias, graphName] : loraAliases) { + auto it = definitions.find(graphName); + if (it != definitions.end() && it->second->getStatus().isAvailable()) { + names.push_back(alias); + } + } return names; } +void MediapipeFactory::registerLoraAlias(const std::string& alias, const std::string& graphName) { + std::unique_lock lock(definitionsMtx); + loraAliases[alias] = graphName; + SPDLOG_LOGGER_INFO(modelmanager_logger, "Registered LoRA alias: {} -> {}", alias, graphName); +} + +void MediapipeFactory::clearLoraAliases(const std::string& graphName) { + std::unique_lock lock(definitionsMtx); + for (auto it = loraAliases.begin(); it != loraAliases.end();) { + if (it->second == graphName) { + SPDLOG_LOGGER_DEBUG(modelmanager_logger, "Removing LoRA alias: {} -> {}", it->first, graphName); + it = loraAliases.erase(it); + } else { + ++it; + } + } +} + +bool MediapipeFactory::aliasesConflictExcluding(const std::vector& aliases, const std::string& ownGraphName) const { + std::shared_lock lock(definitionsMtx); + for (const auto& alias : aliases) { + auto defIt = definitions.find(alias); + if (defIt != definitions.end() && defIt->first != ownGraphName) { + return true; + } + auto aliasIt = loraAliases.find(alias); + if (aliasIt != loraAliases.end() && aliasIt->second != ownGraphName) { + return true; + } + } + return false; +} + 
MediapipeFactory::~MediapipeFactory() = default; } // namespace ovms diff --git a/src/mediapipe_internal/mediapipefactory.hpp b/src/mediapipe_internal/mediapipefactory.hpp index 0d03fcc7a4..a9ac8ae9b0 100644 --- a/src/mediapipe_internal/mediapipefactory.hpp +++ b/src/mediapipe_internal/mediapipefactory.hpp @@ -37,6 +37,7 @@ class PythonBackend; class MediapipeFactory { std::map> definitions; + std::map loraAliases; // alias -> real graph definition name mutable std::shared_mutex definitionsMtx; PythonBackend* pythonBackend{nullptr}; @@ -55,6 +56,9 @@ class MediapipeFactory { const std::string& name) const; MediapipeGraphDefinition* findDefinitionByName(const std::string& name) const; + void registerLoraAlias(const std::string& alias, const std::string& graphName); + void clearLoraAliases(const std::string& graphName); + bool aliasesConflictExcluding(const std::vector& aliases, const std::string& ownGraphName) const; Status reloadDefinition(const std::string& pipelineName, const MediapipeGraphConfig& config, const ServableNameChecker& checker); diff --git a/src/mediapipe_internal/mediapipegraphdefinition.cpp b/src/mediapipe_internal/mediapipegraphdefinition.cpp index dd96a6ed0a..2bc062d8e0 100644 --- a/src/mediapipe_internal/mediapipegraphdefinition.cpp +++ b/src/mediapipe_internal/mediapipegraphdefinition.cpp @@ -168,6 +168,11 @@ Status MediapipeGraphDefinition::validate(const ServableNameChecker& checker) { return status; } + if (!this->loraAliases.empty() && checker.aliasesConflict(this->loraAliases, getName())) { + SPDLOG_LOGGER_ERROR(modelmanager_logger, "LoRA alias in graph '{}' conflicts with an existing servable", getName()); + return StatusCode::MEDIAPIPE_GRAPH_NAME_OCCUPIED; + } + lock.unlock(); notifier.passed = true; SPDLOG_LOGGER_DEBUG(modelmanager_logger, "Finished validation of mediapipe: {}", getName()); @@ -377,6 +382,9 @@ Status MediapipeGraphDefinition::initializeNodes() { } } } + // Register LoRA aliases for routing from initialized image gen 
pipelines + this->loraAliases = sidePacketMaps.loraAliases; + this->hideBaseModelInRouting = sidePacketMaps.hideBaseModelInRouting; success = true; return StatusCode::OK; } diff --git a/src/mediapipe_internal/mediapipegraphdefinition.hpp b/src/mediapipe_internal/mediapipegraphdefinition.hpp index 32b01a9266..e717ded5bf 100644 --- a/src/mediapipe_internal/mediapipegraphdefinition.hpp +++ b/src/mediapipe_internal/mediapipegraphdefinition.hpp @@ -60,6 +60,8 @@ class MediapipeGraphDefinition : public SingleVersionServableDefinition { const PipelineDefinitionStatus& getStatus() const override { return this->status; } + const std::vector& getLoraAliases() const { return loraAliases; } + bool shouldHideBaseModelInRouting() const { return hideBaseModelInRouting; } const PipelineDefinitionStateCode getStateCode() const { return status.getStateCode(); } bool isAvailable() const override { return status.isAvailable(); } @@ -133,6 +135,9 @@ class MediapipeGraphDefinition : public SingleVersionServableDefinition { std::vector outputNames; std::vector inputSidePacketNames; + std::vector loraAliases; + bool hideBaseModelInRouting = false; + PythonBackend* pythonBackend; std::unique_ptr reporter; diff --git a/src/modelmanager.cpp b/src/modelmanager.cpp index ad1d9bc94a..faf50b8d27 100644 --- a/src/modelmanager.cpp +++ b/src/modelmanager.cpp @@ -1657,6 +1657,20 @@ bool ModelManager::servableExists(const std::string& name, ServableQueryType che return false; } +bool ModelManager::aliasesConflict(const std::vector& aliases, const std::string& ownGraphName) const { + for (const auto& alias : aliases) { + if (servableExists(alias, ServableQueryType::Model | ServableQueryType::Pipeline)) { + return true; + } + } +#if (MEDIAPIPE_DISABLE == 0) + if (mediapipeFactory->aliasesConflictExcluding(aliases, ownGraphName)) { + return true; + } +#endif + return false; +} + const PipelineFactory& ModelManager::getPipelineFactory() const { return *pipelineFactory; } diff --git a/src/modelmanager.hpp 
b/src/modelmanager.hpp index 1dc9fa86ea..843725eba1 100644 --- a/src/modelmanager.hpp +++ b/src/modelmanager.hpp @@ -486,6 +486,7 @@ class ModelManager : public ServableNameChecker, public MetricProvider, public M void cleanupResources() override; bool servableExists(const std::string& name, ServableQueryType check = ServableQueryType::All) const override; + bool aliasesConflict(const std::vector& aliases, const std::string& ownGraphName) const override; ServableDefinition* findServableDefinition(const std::string& name) const; diff --git a/src/pull_module/BUILD b/src/pull_module/BUILD index cf8fd5eeba..1b6c896f94 100644 --- a/src/pull_module/BUILD +++ b/src/pull_module/BUILD @@ -55,14 +55,26 @@ ovms_cc_library( ], visibility = ["//visibility:public"], ) +ovms_cc_library( + name = "curl_downloader", + srcs = ["curl_downloader.cpp"], + hdrs = ["curl_downloader.hpp"], + deps = [ + "//third_party:curl", + "@ovms//src:libovmslogging", + "@ovms//src:libovmsstatus", + "@ovms//src:libovms_version", + ], + visibility = ["//visibility:public"], +) + ovms_cc_library( name = "gguf_downloader", srcs = ["gguf_downloader.cpp"], hdrs = ["gguf_downloader.hpp"], deps = [ + ":curl_downloader", ":model_downloader", - "//third_party:curl", - "@nlohmann_json//:json", "@ovms//src:libovmslogging", "@ovms//src:libovmsstatus", "@ovms//src:libovmsstring_utils", @@ -97,6 +109,7 @@ ovms_cc_library( srcs = ["hf_pull_model_module.cpp"], hdrs = ["hf_pull_model_module.hpp"], deps = [ + ":curl_downloader", ":libgit2", "gguf_downloader", ":optimum_export", @@ -106,6 +119,9 @@ ovms_cc_library( "@ovms//src:libovmslogging", "@ovms//src:libovms_server_settings", "@ovms//src:libovms_module", + "@ovms//src:libovms_version", + "//third_party:curl", + "@nlohmann_json//:json", ], visibility = ["//visibility:public"], ) diff --git a/src/pull_module/curl_downloader.cpp b/src/pull_module/curl_downloader.cpp new file mode 100644 index 0000000000..5c0243550b --- /dev/null +++ 
b/src/pull_module/curl_downloader.cpp
@@ -0,0 +1,252 @@
+//*****************************************************************************
+// Copyright 2026 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#include "curl_downloader.hpp"
+
+#include <cstddef>
+#include <cstdio>
+#include <ctime>
+#include <filesystem>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <system_error>
+
+#include <curl/curl.h>
+
+#include "src/logging.hpp"
+#include "src/status.hpp"
+#include "src/version.hpp"
+
+namespace ovms {
+
+static const char* sizeUnits[] = {"B", "KB", "MB", "GB", "TB", nullptr};
+
+// Scales `value` (a byte count or byte rate) down by factors of 1000 until it fits the
+// largest applicable entry of sizeUnits, and returns the label of that unit.
+static const char* scale_to_size_unit(double& value) {
+    size_t unitIdx = 0;
+    while (value > 1000 && sizeUnits[unitIdx + 1]) {
+        value /= 1000.0;
+        unitIdx++;
+    }
+    return sizeUnits[unitIdx];
+}
+
+// Prints the average transfer rate for `received_size` bytes over `elapsed_time` seconds.
+static void print_download_speed_info(size_t received_size, size_t elapsed_time) {
+    double rate = static_cast<double>(received_size);
+    // With zero elapsed seconds the bytes received so far are reported as the rate
+    if (elapsed_time) {
+        rate /= static_cast<double>(elapsed_time);
+    }
+    const char* unit = scale_to_size_unit(rate);
+    printf(" [%.2f %s/s] ", rate, unit);
+}
+
+// Redraws the single-line progress report on stdout.
+//   count        - bytes received so far
+//   max          - expected total number of bytes; 0 means the total is unknown
+//   first_run    - true for the initial draw (nothing received yet)
+//   elapsed_time - seconds since the download started
+static void print_progress(size_t count, size_t max, bool first_run, size_t elapsed_time) {
+    if (max == 0) {
+        // libcurl passes 0 for a total it does not know. count / max would then be inf/NaN and
+        // converting that to the int bar length is undefined, so only amount and speed are shown.
+        double received = static_cast<double>(count);
+        const char* receivedUnit = scale_to_size_unit(received);
+        printf("\rProgress: %.2f %s received", received, receivedUnit);
+        print_download_speed_info(count, elapsed_time);
+        fflush(stdout);
+        return;
+    }
+    float progress = static_cast<float>(count) / max;
+    if (progress > 1.0f) {
+        progress = 1.0f;  // never draw past the end of the bar
+    }
+    if (!first_run && progress < 0.01 && count > 0)
+        return;
+
+    const int bar_width = 50;
+    const int bar_length = static_cast<int>(progress * bar_width);
+
+    printf("\rProgress: [");
+    for (int i = 0; i < bar_width; ++i) {
+        putchar(i < bar_length ? '#' : ' ');
+    }
+    double totalSize = static_cast<double>(max);
+    const char* totalUnit = scale_to_size_unit(totalSize);
+    printf("] %.2f%% of %.2f %s", progress * 100, totalSize, totalUnit);
+    print_download_speed_info(count, elapsed_time);
+    if (progress == 1.0)
+        printf("\n");
+    fflush(stdout);
+}
+
+// Write target handed to libcurl. The file is opened lazily by the write callback and is
+// removed again on destruction unless `success` was set.
+// `filename` is not owned: the pointed-to string must outlive this object.
+struct CurlDownloadFile {
+    const char* filename;
+    FILE* stream;
+    CurlDownloadFile() = delete;
+    CurlDownloadFile(const CurlDownloadFile&) = delete;
+    CurlDownloadFile& operator=(const CurlDownloadFile&) = delete;
+    CurlDownloadFile(const char* fname, FILE* str) :
+        filename(fname),
+        stream(str) {}
+    ~CurlDownloadFile() {
+        if (stream) {
+            fclose(stream);
+        }
+        if (!success) {
+            // Non-throwing overload: a destructor must not let an exception escape
+            std::error_code ec;
+            std::filesystem::remove(filename, ec);
+        }
+    }
+    bool success = false;
+};
+
+// Balances a successful curl_global_init() by calling curl_global_cleanup() on scope exit.
+struct CurlGlobalGuard {
+    CurlGlobalGuard() = default;
+    CurlGlobalGuard(const CurlGlobalGuard&) = delete;
+    CurlGlobalGuard& operator=(const CurlGlobalGuard&) = delete;
+    ~CurlGlobalGuard() { curl_global_cleanup(); }
+};
+
+// libcurl write callback storing the received bytes in the CurlDownloadFile passed as `stream`.
+// Returns the number of bytes consumed; any other value makes libcurl abort the transfer.
+static size_t file_write_callback(void* buffer, size_t size, size_t nmemb, void* stream) {
+    CurlDownloadFile* out = static_cast<CurlDownloadFile*>(stream);
+    if (!out->stream) {
+        out->stream = fopen(out->filename, "wb");
+        if (!out->stream) {
+            fprintf(stderr, "failure, cannot open file to write: %s\n",
+                out->filename);
+            return 0;
+        }
+    }
+    // Element size 1 makes fwrite's return value a byte count, which is what libcurl compares
+    return fwrite(buffer, 1, size * nmemb, out->stream);
+}
+
+#define CHECK_CURL_CALL(call)                                                                         \
+    do {                                                                                              \
+        CURLcode curlCode = (call);                                                                   \
+        if (curlCode != CURLE_OK) {                                                                   \
+            SPDLOG_ERROR("curl error: {}. Error code: {}", curl_easy_strerror(curlCode), (int)curlCode); \
+            return StatusCode::INTERNAL_ERROR;                                                        \
+        }                                                                                             \
+    } while (0)
+
+// State shared between invocations of progress_callback for one transfer.
+struct ProgressData {
+    // Initialized up front so the callback never reads an indeterminate timestamp
+    time_t started_download = time(nullptr);
+    time_t last_print_time = time(nullptr);
+    bool fullDownloadPrinted = false;
+};
+
+// libcurl transfer-info callback: prints the progress report at most once per second and
+// once more when the download completes. Always returns 0 so the transfer continues.
+static int progress_callback(void* clientp,
+    curl_off_t dltotal,
+    curl_off_t dlnow,
+    curl_off_t ultotal,
+    curl_off_t ulnow) {
+    ProgressData* pcs = static_cast<ProgressData*>(clientp);
+    if (dlnow == 0) {
+        pcs->started_download = time(nullptr);
+        pcs->last_print_time = time(nullptr);
+    }
+    time_t currentTime = time(nullptr);
+    bool shouldPrintDueToTime = (currentTime - pcs->last_print_time >= 1);
+    if ((dltotal == dlnow) && dltotal < 10000) {
+        // Usually with first messages we don't get the full size, so while dltotal is below
+        // 10000 we assume it is not final; otherwise we would print a 100% progress bar
+        return 0;
+    }
+    // called multiple times, so we want to print progress bar only once reached 100%
+    if (pcs->fullDownloadPrinted) {
+        return 0;
+    }
+    if (!shouldPrintDueToTime && (dltotal != dlnow)) {
+        // we don't want to skip printing the 100% bar but we don't want to spam stdout either
+        return 0;
+    }
+    pcs->fullDownloadPrinted = (dltotal == dlnow);
+    pcs->last_print_time = currentTime;
+    print_progress(dlnow, dltotal, (dlnow == 0), currentTime - pcs->started_download);
+    std::cout.flush();
+    return 0;
+}
+
+using CurlHeaderList = std::unique_ptr<curl_slist, decltype(&curl_slist_free_all)>;
+
+// Adds an "Authorization: Bearer <token>" request header when `token` is not empty.
+// `headers` takes ownership of the header list; it must stay alive until the transfer is done.
+static Status setBearerToken(CURL* curl, const std::string& token, CurlHeaderList& headers) {
+    if (token.empty()) {
+        return StatusCode::OK;
+    }
+    const std::string authHeader = "Authorization: Bearer " + token;
+    curl_slist* list = curl_slist_append(nullptr, authHeader.c_str());
+    if (!list) {
+        SPDLOG_ERROR("Failed to allocate cURL header list.");
+        return StatusCode::INTERNAL_ERROR;
+    }
+    // Hand the list to the guard before the next early return so it cannot leak
+    headers.reset(list);
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_HTTPHEADER, list));
+    return StatusCode::OK;
+}
+
+// Downloads `url` to `filePath` without sending any authorization header.
+Status downloadFileWithCurl(const std::string& url, const std::string& filePath) {
+    return downloadFileWithCurl(url, filePath, "");
+}
+
+// Downloads `url` to `filePath`, printing progress to stdout.
+// A non-empty `authTokenHF` is sent as a bearer token.
+// Returns INTERNAL_ERROR when a libcurl call fails and PATH_INVALID when the server answers
+// with anything but HTTP 200; in both cases the partially written file is removed.
+Status downloadFileWithCurl(const std::string& url, const std::string& filePath, const std::string& authTokenHF) {
+    // agent string required to avoid 403 Forbidden error on modelscope
+    const std::string agentString = std::string(PROJECT_NAME) + "/" + std::string(PROJECT_VERSION);
+
+    CHECK_CURL_CALL(curl_global_init(CURL_GLOBAL_DEFAULT));
+    CurlGlobalGuard globalCurlGuard;
+    CURL* curl = curl_easy_init();
+    if (!curl) {
+        SPDLOG_ERROR("Failed to initialize cURL.");
+        return StatusCode::INTERNAL_ERROR;
+    }
+    auto handleGuard = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>(curl, curl_easy_cleanup);
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_URL, url.c_str()));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, file_write_callback));
+    CurlDownloadFile downloadFile{filePath.c_str(), nullptr};
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEDATA, &downloadFile));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USERAGENT, agentString.c_str()));
+    CurlHeaderList headers(nullptr, curl_slist_free_all);
+    auto status = setBearerToken(curl, authTokenHF, headers);
+    if (!status.ok()) {
+        return status;
+    }
+    ProgressData progressData;
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, progress_callback));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &progressData));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 30L));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USE_SSL, CURLUSESSL_ALL));
+    CHECK_CURL_CALL(curl_easy_perform(curl));
+    // CURLINFO_RESPONSE_CODE stores a long; a narrower variable would be overrun on LP64 targets
+    long http_code = 0;
+    CHECK_CURL_CALL(curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code));
+    SPDLOG_TRACE("HTTP response code: {}", http_code);
+    if (http_code != 200) {
+        SPDLOG_ERROR("Failed to download file from URL: {} HTTP response code: {}", url, http_code);
+        return StatusCode::PATH_INVALID;
+    }
+    downloadFile.success = true;
+    return StatusCode::OK;
+}
+
+// libcurl write callback appending the received bytes to the std::string passed as `userData`.
+static size_t string_write_callback(void* buffer, size_t size, size_t nmemb, void* userData) {
+    auto* body = static_cast<std::string*>(userData);
+    try {
+        body->append(static_cast<char*>(buffer), size * nmemb);
+    } catch (...) {
+        // An exception must not unwind through libcurl's C frames; a short count aborts the transfer
+        return 0;
+    }
+    return size * nmemb;
+}
+
+// Performs an HTTP GET on `url` and appends the response body to `responseBody`.
+// A non-empty `authToken` is sent as a bearer token.
+// Returns INTERNAL_ERROR when a libcurl call fails and PATH_INVALID when the server answers
+// with anything but HTTP 200 (`responseBody` may then hold the error payload).
+Status fetchUrlToString(const std::string& url, const std::string& authToken, std::string& responseBody) {
+    const std::string agentString = std::string(PROJECT_NAME) + "/" + std::string(PROJECT_VERSION);
+
+    CHECK_CURL_CALL(curl_global_init(CURL_GLOBAL_DEFAULT));
+    CurlGlobalGuard globalCurlGuard;
+    CURL* curl = curl_easy_init();
+    if (!curl) {
+        SPDLOG_ERROR("Failed to initialize cURL.");
+        return StatusCode::INTERNAL_ERROR;
+    }
+    auto handleGuard = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>(curl, curl_easy_cleanup);
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_URL, url.c_str()));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, string_write_callback));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseBody));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USERAGENT, agentString.c_str()));
+    CurlHeaderList headers(nullptr, curl_slist_free_all);
+    auto status = setBearerToken(curl, authToken, headers);
+    if (!status.ok()) {
+        return status;
+    }
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 30L));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_TIMEOUT, 60L));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L));
+    CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USE_SSL, CURLUSESSL_ALL));
+    CHECK_CURL_CALL(curl_easy_perform(curl));
+    // CURLINFO_RESPONSE_CODE stores a long; a narrower variable would be overrun on LP64 targets
+    long httpCode = 0;
+    CHECK_CURL_CALL(curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &httpCode));
+    if (httpCode != 200) {
+        SPDLOG_ERROR("HTTP request to {} failed with code: {}", url, httpCode);
+        return StatusCode::PATH_INVALID;
+    }
+    return StatusCode::OK;
+}
+
+#undef CHECK_CURL_CALL
+
+}  // namespace ovms
diff --git a/src/pull_module/curl_downloader.hpp b/src/pull_module/curl_downloader.hpp
new file mode 100644
index 0000000000..12a9ab39c5
--- /dev/null
+++ b/src/pull_module/curl_downloader.hpp
@@ -0,0 +1,26 @@
+#pragma once
+//*****************************************************************************
+// Copyright 2026 Intel
Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#include + +namespace ovms { +class Status; + +Status downloadFileWithCurl(const std::string& url, const std::string& filePath); +Status downloadFileWithCurl(const std::string& url, const std::string& filePath, const std::string& authTokenHF); +Status fetchUrlToString(const std::string& url, const std::string& authToken, std::string& responseBody); + +} // namespace ovms diff --git a/src/pull_module/gguf_downloader.cpp b/src/pull_module/gguf_downloader.cpp index dd951e6147..0e9d3d19ed 100644 --- a/src/pull_module/gguf_downloader.cpp +++ b/src/pull_module/gguf_downloader.cpp @@ -15,23 +15,18 @@ //***************************************************************************** #include "gguf_downloader.hpp" -#include -#include #include +#include +#include #include -#include -#include - #include "../capi_frontend/server_settings.hpp" #include "src/filesystem/filesystem.hpp" #include "src/filesystem/localfilesystem.hpp" #include "../logging.hpp" -#include "../stringutils.hpp" #include "../status.hpp" -#include "../version.hpp" - -#include +#include "../stringutils.hpp" +#include "curl_downloader.hpp" namespace ovms { @@ -128,177 +123,6 @@ Status GGUFDownloader::downloadModel() { return StatusCode::OK; } -static const char* sizeUnits[] = {"B", "KB", "MB", "GB", "TB", NULL}; -static void print_download_speed_info(size_t 
received_size, size_t elapsed_time) { - double recv_len = (double)received_size; - uint64_t elapsed = (uint64_t)elapsed_time; - double rate; - rate = elapsed ? recv_len / elapsed : received_size; - - size_t rate_unit_idx = 0; - while (rate > 1000 && sizeUnits[rate_unit_idx + 1]) { - rate /= 1000.0; - rate_unit_idx++; - } - printf(" [%.2f %s/s] ", rate, sizeUnits[rate_unit_idx]); -} - -void print_progress(size_t count, size_t max, bool first_run, size_t elapsed_time) { - float progress = (float)count / max; - if (!first_run && progress < 0.01 && count > 0) - return; - - const int bar_width = 50; - int bar_length = progress * bar_width; - - printf("\rProgress: ["); - int i; - for (i = 0; i < bar_length; ++i) { - printf("#"); - } - for (i = bar_length; i < bar_width; ++i) { - printf(" "); - } - size_t totalSizeUnitId = 0; - double totalSize = max; - while (totalSize > 1000 && sizeUnits[totalSizeUnitId + 1]) { - totalSize /= 1000.0; - totalSizeUnitId++; - } - printf("] %.2f%% of %.2f %s", progress * 100, totalSize, sizeUnits[totalSizeUnitId]); - print_download_speed_info(count, elapsed_time); - if (progress == 1.0) - printf("\n"); - fflush(stdout); -} - -struct FtpFile { - const char* filename; - FILE* stream; - FtpFile() = delete; - FtpFile(const FtpFile&) = delete; - FtpFile& operator=(const FtpFile&) = delete; - FtpFile(const char* fname, FILE* str) : - filename(fname), - stream(str) {} - ~FtpFile() { - if (stream) { - fclose(stream); - } - if (!success) { - std::filesystem::remove(filename); - } - } - bool success = false; -}; - -void fileClose(FILE* file) { - if (file) { - fclose(file); - } -} - -static size_t file_write_callback(void* buffer, size_t size, size_t nmemb, void* stream) { - struct FtpFile* out = (struct FtpFile*)stream; - if (!out->stream) { - out->stream = fopen(out->filename, "wb"); - if (!out->stream) { - fprintf(stderr, "failure, cannot open file to write: %s\n", - out->filename); - return 0; - } - } - return fwrite(buffer, size, nmemb, 
out->stream); -} - -#define CHECK_CURL_CALL(call) \ - do { \ - CURLcode curlCode = call; \ - if (curlCode != CURLE_OK) { \ - SPDLOG_ERROR("curl error: {}. Error code: {}", curl_easy_strerror(curlCode), (int)curlCode); \ - return StatusCode::INTERNAL_ERROR; \ - } \ - } while (0) - -struct ProgressData { - time_t started_download; - time_t last_print_time; - bool fullDownloadPrinted = false; -}; -int progress_callback(void* clientp, - curl_off_t dltotal, - curl_off_t dlnow, - curl_off_t ultotal, - curl_off_t ulnow) { - ProgressData* pcs = reinterpret_cast(clientp); - if (dlnow == 0) { - pcs->started_download = time(NULL); - pcs->last_print_time = time(NULL); - } - time_t currentTime = time(NULL); - bool shouldPrintDueToTime = (currentTime - pcs->last_print_time >= 1); - if ((dltotal == dlnow) && dltotal < 10000) { - // Usually with first messages we don't get the full size and we don't want to print progress bar - // so we assume that until dltotal is less than 1000 we don't have full size - // otherwise we would print 100% progress bar - return 0; - } - // called multiple times, so we want to print progress bar only once reached 100% - if (pcs->fullDownloadPrinted) { - return 0; - } - if (!shouldPrintDueToTime && (dltotal != dlnow)) { - // we dont want to skip printing progress bar for the 100% but we don't want to spam stdout either - return 0; - } - pcs->fullDownloadPrinted = (dltotal == dlnow); - pcs->last_print_time = currentTime; - print_progress(dlnow, dltotal, (dlnow == 0), currentTime - pcs->started_download); - std::cout.flush(); - return 0; -} - -static Status downloadSingleFileWithCurl(const std::string& filePath, const std::string& url) { - // agent string required to avoid 403 Forbidden error on modelscope - std::string agentString = std::string(PROJECT_NAME) + "/" + std::string(PROJECT_VERSION); - - CURL* curl = nullptr; - CHECK_CURL_CALL(curl_global_init(CURL_GLOBAL_DEFAULT)); - auto globalCurlGuard = std::unique_ptr( - nullptr, [](void*) { 
curl_global_cleanup(); }); - curl = curl_easy_init(); - if (!curl) { - SPDLOG_ERROR("Failed to initialize cURL."); - return StatusCode::INTERNAL_ERROR; - } - auto handleGuard = std::unique_ptr(curl, curl_easy_cleanup); - // set impl options - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_URL, url.c_str())); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, file_write_callback)); - struct FtpFile ftpFile = {filePath.c_str(), NULL}; - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_WRITEDATA, &ftpFile)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USERAGENT, agentString.c_str())); - // progress bar options - ProgressData progressData; - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, progress_callback)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &progressData)); - // other options - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L)); - CHECK_CURL_CALL(curl_easy_setopt(curl, CURLOPT_USE_SSL, CURLUSESSL_ALL)); - CHECK_CURL_CALL(curl_easy_perform(curl)); - int32_t http_code = 0; - CHECK_CURL_CALL(curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code)); - SPDLOG_TRACE("HTTP response code: {}", http_code); - if (http_code != 200) { - SPDLOG_ERROR("Failed to download file from URL: {} HTTP response code: {}", url, http_code); - return StatusCode::PATH_INVALID; - } - ftpFile.success = true; - return StatusCode::OK; -} - std::variant> GGUFDownloader::createGGUFFilenamesToDownload(const std::string& ggufFilename) { std::vector filesToDownload; // we need to check if ggufFilename is of multipart type (contains 00001-of-N string) @@ -364,7 +188,7 @@ Status GGUFDownloader::downloadWithCurl(const std::string& hfEndpoint, const std // construct filepath auto filePath = FileSystem::joinPath({downloadPath, file}); SPDLOG_DEBUG("Downloading part {}/{} 
filename: {} url:{}", partNo, filesToDownload.size(), file, url);
-        auto status = downloadSingleFileWithCurl(filePath, url);
+        auto status = downloadFileWithCurl(url, filePath);
         if (!status.ok()) {
             return status;
         }
diff --git a/src/pull_module/hf_pull_model_module.cpp b/src/pull_module/hf_pull_model_module.cpp
index 24b39686cc..08c3b97b98 100644
--- a/src/pull_module/hf_pull_model_module.cpp
+++ b/src/pull_module/hf_pull_model_module.cpp
@@ -16,13 +16,19 @@
 #include "hf_pull_model_module.hpp"
 
 #include 
+#include 
 #include 
 #include 
 #include 
+#include 
+
+#include 
 
 #include "../config.hpp"
+#include "src/filesystem/filesystem.hpp"
 #include "libgit2.hpp"
 #include "optimum_export.hpp"
+#include "curl_downloader.hpp"
 #include "gguf_downloader.hpp"
 #include "../graph_export/graph_export.hpp"
 #include "../logging.hpp"
@@ -110,7 +116,124 @@ Status HfPullModelModule::start(const ovms::Config& config) {
     return StatusCode::OK;
 }
 
-Status HfPullModelModule::clone() const {
+// Fills in resolvedSafetensorsFile for every HF_REPO LoRA adapter that has neither an explicit
+// nor an already resolved file name, by listing the repository through the HF API.
+// Succeeds only when the repository holds exactly one .safetensors file; zero or several
+// yield PATH_INVALID and an unparsable response yields INTERNAL_ERROR.
+// Does nothing when the graph settings are not image generation settings.
+Status HfPullModelModule::resolveHfLoraFilenames() {
+    if (!std::holds_alternative<ImageGenerationGraphSettingsImpl>(this->hfSettings.graphSettings)) {
+        return StatusCode::OK;
+    }
+    auto& graphSettings = std::get<ImageGenerationGraphSettingsImpl>(this->hfSettings.graphSettings);
+    for (auto& adapter : graphSettings.loraAdapters) {
+        if (adapter.sourceType != LoraSourceType::HF_REPO) {
+            continue;
+        }
+        if (adapter.safetensorsFile.has_value()) {
+            continue;
+        }
+        if (adapter.resolvedSafetensorsFile.has_value()) {
+            continue;
+        }
+        // Query HF API to find the .safetensors file in the LoRA repo
+        std::string apiUrl = this->GetHfEndpoint() + "api/models/" + adapter.sourceLora;
+        SPDLOG_DEBUG("Querying HF API for LoRA adapter files: {}", apiUrl);
+        std::string responseBody;
+        std::string hfToken = this->GetHfToken();
+        auto status = fetchUrlToString(apiUrl, hfToken, responseBody);
+        if (!status.ok()) {
+            SPDLOG_ERROR("Failed to query HF API for LoRA adapter: {}", adapter.sourceLora);
+            return status;
+        }
+        // Parse JSON response to find .safetensors files in siblings array
+        // Example: { "siblings": [{"rfilename": "file1.safetensors"}, ...] }
+        try {
+            auto json = nlohmann::json::parse(responseBody);
+            std::vector<std::string> safetensorsFiles;
+            if (json.contains("siblings") && json["siblings"].is_array()) {
+                for (const auto& sibling : json["siblings"]) {
+                    if (sibling.contains("rfilename") && sibling["rfilename"].is_string()) {
+                        const std::string& filename = sibling["rfilename"].get_ref<const std::string&>();
+                        if (endsWith(filename, ".safetensors")) {
+                            safetensorsFiles.push_back(filename);
+                        }
+                    }
+                }
+            }
+            if (safetensorsFiles.empty()) {
+                SPDLOG_ERROR("No .safetensors files found via HF API for LoRA adapter: {}", adapter.sourceLora);
+                return StatusCode::PATH_INVALID;
+            }
+            if (safetensorsFiles.size() > 1) {
+                SPDLOG_ERROR("Multiple .safetensors files found for LoRA adapter: {}. Use @filename to specify.", adapter.sourceLora);
+                return StatusCode::PATH_INVALID;
+            }
+            adapter.resolvedSafetensorsFile = safetensorsFiles[0];
+            SPDLOG_DEBUG("Resolved LoRA safetensors file for {}: {}", adapter.sourceLora, adapter.resolvedSafetensorsFile.value());
+        } catch (const nlohmann::json::exception& e) {
+            SPDLOG_ERROR("Failed to parse HF API JSON response for LoRA adapter {}: {}", adapter.sourceLora, e.what());
+            return StatusCode::INTERNAL_ERROR;
+        }
+    }
+    return StatusCode::OK;
+}
+
+// Downloads the .safetensors file of every HF_REPO and DIRECT_URL LoRA adapter into
+// <graphDirectory>/loras/<repo or alias>/; LOCAL_FILE adapters are only reported.
+// An existing file is kept unless overwriteModels is set. The download goes to a ".tmp" file
+// that is renamed into place on success, so an interrupted transfer leaves no final file.
+// Does nothing when the graph settings are not image generation settings.
+Status HfPullModelModule::pullLoraAdapters(const std::string& graphDirectory) {
+    if (!std::holds_alternative<ImageGenerationGraphSettingsImpl>(this->hfSettings.graphSettings)) {
+        return StatusCode::OK;
+    }
+    auto status = this->resolveHfLoraFilenames();
+    if (!status.ok()) {
+        return status;
+    }
+    const auto& graphSettings = std::get<ImageGenerationGraphSettingsImpl>(this->hfSettings.graphSettings);
+    for (const auto& adapter : graphSettings.loraAdapters) {
+        if (adapter.sourceType == LoraSourceType::LOCAL_FILE) {
+            std::cout << "LoRA adapter: " << adapter.alias << " using local file: " << adapter.sourceLora << std::endl;
+            continue;
+        }
+        const bool isHfRepo = adapter.sourceType == LoraSourceType::HF_REPO;
+        if (!isHfRepo && adapter.sourceType != LoraSourceType::DIRECT_URL) {
+            SPDLOG_ERROR("Unknown LoRA source type for adapter: {}", adapter.alias);
+            return StatusCode::INTERNAL_ERROR;
+        }
+        // Checked explicitly: calling value() on an empty optional would throw out of this function
+        const auto& safetensorsFile = adapter.effectiveSafetensorsFile();
+        if (!safetensorsFile.has_value()) {
+            SPDLOG_ERROR("No .safetensors file name available for LoRA adapter: {}", adapter.alias);
+            return StatusCode::PATH_INVALID;
+        }
+        const std::string& fileName = safetensorsFile.value();
+        // The name may originate from the HF API response, so it must not leave the adapter directory
+        const std::filesystem::path relativeFile(fileName);
+        bool escapesDirectory = relativeFile.is_absolute();
+        for (const auto& component : relativeFile) {
+            escapesDirectory = escapesDirectory || component == "..";
+        }
+        if (escapesDirectory) {
+            SPDLOG_ERROR("Invalid .safetensors file name: {} for LoRA adapter: {}", fileName, adapter.alias);
+            return StatusCode::PATH_INVALID;
+        }
+        std::string loraDownloadPath;
+        std::string loraUrl;
+        std::string authTokenHF;
+        if (isHfRepo) {
+            loraDownloadPath = FileSystem::joinPath({graphDirectory, "loras", adapter.sourceLora});
+            loraUrl = this->GetHfEndpoint() + adapter.sourceLora + "/resolve/main/" + fileName;
+            authTokenHF = this->GetHfToken();
+        } else {
+            // Direct URLs are fetched without the HF token
+            loraDownloadPath = FileSystem::joinPath({graphDirectory, "loras", adapter.alias});
+            loraUrl = adapter.sourceLora;
+        }
+        auto loraFilePath = FileSystem::joinPath({loraDownloadPath, fileName});
+        // error_code overloads throughout: this function reports failures through Status
+        std::error_code ec;
+        if (!this->hfSettings.overwriteModels && std::filesystem::exists(loraFilePath, ec)) {
+            std::cout << "LoRA adapter: " << adapter.alias << " already exists, skipping download." << std::endl;
+            continue;
+        }
+        // Create the directory of the file itself: it is loraDownloadPath for plain names and
+        // also covers names that contain subdirectories. An existing directory is not an error.
+        const std::string loraFileDirectory = std::filesystem::path(loraFilePath).parent_path().string();
+        std::filesystem::create_directories(loraFileDirectory, ec);
+        if (ec) {
+            SPDLOG_ERROR("Failed to create LoRA directory: {}: {}", loraFileDirectory, ec.message());
+            return StatusCode::DIRECTORY_NOT_CREATED;
+        }
+        auto loraTmpFilePath = loraFilePath + ".tmp";
+        std::filesystem::remove(loraTmpFilePath, ec);
+        status = downloadFileWithCurl(loraUrl, loraTmpFilePath, authTokenHF);
+        if (!status.ok()) {
+            SPDLOG_ERROR("Failed to download LoRA adapter: {} from: {}", adapter.alias, loraUrl);
+            std::filesystem::remove(loraTmpFilePath, ec);
+            return status;
+        }
+        std::filesystem::rename(loraTmpFilePath, loraFilePath, ec);
+        if (ec) {
+            SPDLOG_ERROR("Failed to rename LoRA temp file: {} -> {}: {}", loraTmpFilePath, loraFilePath, ec.message());
+            std::filesystem::remove(loraTmpFilePath, ec);
+            return StatusCode::INTERNAL_ERROR;
+        }
+        std::cout << "LoRA adapter: " << adapter.alias << " downloaded to: " << loraDownloadPath << std::endl;
+    }
+    return StatusCode::OK;
+}
+
+Status HfPullModelModule::clone() {
     std::string graphDirectory = "";
std::unique_ptr downloader; std::variant> guardOrError; @@ -150,6 +273,12 @@ Status HfPullModelModule::clone() const { std::cout << "Draft model: " << GraphExport::getDraftModelDirectoryName(graphSettings.draftModelDirName.value()) << " downloaded to: " << GraphExport::getDraftModelDirectoryPath(graphDirectory, graphSettings.draftModelDirName.value()) << std::endl; } + // Image gen with LoRA adapters case - resolve filenames and download safetensors files + status = this->pullLoraAdapters(graphDirectory); + if (!status.ok()) { + return status; + } + GraphExport graphExporter; status = graphExporter.createServableConfig(graphDirectory, this->hfSettings, true); // when downloading from HF we always create config file, but when using local model with --task we create config in memory without writing to file if (!status.ok()) { diff --git a/src/pull_module/hf_pull_model_module.hpp b/src/pull_module/hf_pull_model_module.hpp index be42887b39..296300b474 100644 --- a/src/pull_module/hf_pull_model_module.hpp +++ b/src/pull_module/hf_pull_model_module.hpp @@ -35,10 +35,14 @@ class HfPullModelModule : public Module { Status start(const ovms::Config& config) override; void shutdown() override; - Status clone() const; + Status clone(); static const std::string GIT_SERVER_CONNECT_TIMEOUT_ENV; static const std::string GIT_SERVER_TIMEOUT_ENV; static const std::string GIT_SSL_CERT_LOCATIONS_ENV; + +protected: + Status resolveHfLoraFilenames(); + Status pullLoraAdapters(const std::string& graphDirectory); }; std::variant> createLibGitGuard(); diff --git a/src/servable_name_checker.hpp b/src/servable_name_checker.hpp index dd672ff608..2b8e9baf86 100644 --- a/src/servable_name_checker.hpp +++ b/src/servable_name_checker.hpp @@ -17,6 +17,7 @@ #include #include +#include namespace ovms { @@ -41,6 +42,7 @@ class ServableNameChecker { public: virtual ~ServableNameChecker() = default; virtual bool servableExists(const std::string& name, ServableQueryType check = ServableQueryType::All) 
const = 0; + virtual bool aliasesConflict(const std::vector& aliases, const std::string& ownGraphName) const = 0; }; } // namespace ovms diff --git a/src/server.cpp b/src/server.cpp index d2fc34a1a5..626186d0ce 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -409,7 +409,7 @@ Status Server::startModules(ovms::Config& config) { if (config.getServerSettings().serverMode == HF_PULL_MODE) { INSERT_MODULE(HF_MODEL_PULL_MODULE_NAME, it); START_MODULE(it); - auto hfModule = dynamic_cast(it->second.get()); + auto hfModule = dynamic_cast(it->second.get()); status = hfModule->clone(); return status; } @@ -443,10 +443,7 @@ Status Server::startModules(ovms::Config& config) { if (config.getServerSettings().serverMode == HF_PULL_AND_START_MODE) { INSERT_MODULE(HF_MODEL_PULL_MODULE_NAME, it); START_MODULE(it); - if (!status.ok()) { - return status; - } - auto hfModule = dynamic_cast(it->second.get()); + auto hfModule = dynamic_cast(it->second.get()); status = hfModule->clone(); // Return only on clone error; otherwise start the rest of modules if (!status.ok()) diff --git a/src/stringutils.cpp b/src/stringutils.cpp index 1f369902cc..b772f45c40 100644 --- a/src/stringutils.cpp +++ b/src/stringutils.cpp @@ -306,4 +306,21 @@ void escapeSpecialCharacters(std::string& text) { text = std::move(escaped); } +// TODO move to image gen cli parser file? 
+// Tells whether `path` is spelled as a local filesystem path:
+//   - POSIX absolute:          "/dir/file"
+//   - explicitly relative:     "./file", ".\file", "../file", "..\file"
+//   - Windows drive-qualified: "C:\file", "C:/file"
+// Anything else (empty string, bare names such as "org/repo", URLs) returns false.
+bool isLocalFilePath(const std::string& path) {
+    if (path.empty()) {
+        return false;
+    }
+    if (path[0] == '/') {
+        return true;
+    }
+    // rfind(prefix, 0) == 0 is a starts-with test that needs no temporary substring
+    for (const char* prefix : {"./", ".\\", "../", "..\\"}) {
+        if (path.rfind(prefix, 0) == 0) {
+            return true;
+        }
+    }
+    // Drive letter, colon and a separator; the cast keeps std::isalpha defined for negative char values
+    if (path.size() >= 3 && std::isalpha(static_cast<unsigned char>(path[0])) && path[1] == ':' && (path[2] == '\\' || path[2] == '/')) {
+        return true;
+    }
+    return false;
+}
+
 }  // namespace ovms
diff --git a/src/stringutils.hpp b/src/stringutils.hpp
index ac812702b9..ade808d0ba 100644
--- a/src/stringutils.hpp
+++ b/src/stringutils.hpp
@@ -131,4 +131,6 @@ bool stringsOverlap(const std::string& lhs, const std::string& rhs);
 
 void escapeSpecialCharacters(std::string& text);
 
+bool isLocalFilePath(const std::string& path);
+
 }  // namespace ovms
diff --git a/src/test/graph_export_test.cpp b/src/test/graph_export_test.cpp
index 0bbd684646..ed79061a63 100644
--- a/src/test/graph_export_test.cpp
+++ b/src/test/graph_export_test.cpp
@@ -13,6 +13,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
//***************************************************************************** +#include +#include #include #include @@ -536,12 +538,6 @@ class GraphCreationTest : public TestWithTempDir { TestWithTempDir::TearDown(); } - // Removes # OpenVINO Model Server REPLACE_PROJECT_VERSION comment added for debug purpose in graph export at the begging of graph.pbtxt - // This string differs per build and setup - std::string removeVersionString(std::string input) { - return input.erase(0, input.find("\n") + 1); - } - std::string getVersionString() { std::stringstream expected; expected << "# File created with: " << PROJECT_NAME << " " << PROJECT_VERSION << std::endl; diff --git a/src/test/light_test_utils.hpp b/src/test/light_test_utils.hpp index 0a60fddcf7..64084f6ab1 100644 --- a/src/test/light_test_utils.hpp +++ b/src/test/light_test_utils.hpp @@ -18,3 +18,8 @@ #include std::string GetFileContents(const std::string& filePath); bool createConfigFileWithContent(const std::string& content, std::string filename = "/tmp/ovms_config_file.json"); + +// Removes the version comment line from the beginning of graph.pbtxt content +inline std::string removeVersionString(std::string input) { + return input.erase(0, input.find("\n") + 1); +} diff --git a/src/test/lora_graph_export_test.cpp b/src/test/lora_graph_export_test.cpp new file mode 100644 index 0000000000..023209ae58 --- /dev/null +++ b/src/test/lora_graph_export_test.cpp @@ -0,0 +1,888 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#include +#include +#include +#include + +#include +#include + +#include "light_test_utils.hpp" +#include "test_with_temp_dir.hpp" +#include "../capi_frontend/server_settings.hpp" +#include "../graph_export/graph_export.hpp" +#include "../graph_export/image_generation_graph_cli_parser.hpp" +#include "src/filesystem/filesystem.hpp" +#include "../status.hpp" + +class LoraGraphCreationTest : public TestWithTempDir {}; + +// ===================== LoRA Graph Export Tests ===================== + +const std::string expectedImageGenWithOneLora = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + lora_adapters { alias: "pokemon" path: "loras/juliensimon/sd-pokemon-lora/pytorch_lora_weights.safetensors" mode: DYNAMIC } + } + } +} + +)"; + +const std::string expectedImageGenWithTwoLoras = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + 
node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "GPU" + max_resolution: "1024x1024" + lora_adapters { alias: "pokemon" path: "loras/juliensimon/sd-pokemon-lora/model.safetensors" mode: DYNAMIC } + lora_adapters { alias: "anime-style" path: "loras/org2/anime-lora/weights.safetensors" mode: DYNAMIC } + } + } +} + +)"; + +TEST_F(LoraGraphCreationTest, imageGenerationWithOneLora) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "juliensimon/sd-pokemon-lora", "pytorch_lora_weights.safetensors"}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithOneLora, removeVersionString(graphContents)) << graphContents; +} + +TEST_F(LoraGraphCreationTest, imageGenerationWithTwoLoras) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + hfSettings.exportSettings.targetDevice = "GPU"; + imageGenerationGraphSettings.maxResolution = "1024x1024"; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "juliensimon/sd-pokemon-lora", "model.safetensors"}); + imageGenerationGraphSettings.loraAdapters.push_back({"anime-style", "org2/anime-lora", "weights.safetensors"}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter 
= std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithTwoLoras, removeVersionString(graphContents)) << graphContents; +} + +const std::string expectedImageGenerationGraphContentsDefault = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + } + } +} + +)"; + +TEST_F(LoraGraphCreationTest, imageGenerationNoLorasRemainsUnchanged) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenerationGraphContentsDefault, removeVersionString(graphContents)) << graphContents; +} + +// ===================== LoRA CLI-to-Settings Tests ===================== + +TEST(ImageGenCLILoraParsingTest, SingleLoraWithAlias) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=juliensimon/sd-pokemon-lora"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, 
hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "juliensimon/sd-pokemon-lora"); + EXPECT_FALSE(graphSettings.loraAdapters[0].safetensorsFile.has_value()); +} + +TEST(ImageGenCLILoraParsingTest, MissingAliasThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "juliensimon/sd-pokemon-lora"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, SingleLoraWithAliasAndFilename) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=juliensimon/sd-pokemon-lora@custom_lora.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "juliensimon/sd-pokemon-lora"); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile.value(), "custom_lora.safetensors"); +} + +TEST(ImageGenCLILoraParsingTest, MultipleLoras) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org1/repo1,anime=org2/repo2@weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 2); + 
EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "org1/repo1"); + EXPECT_FALSE(graphSettings.loraAdapters[0].safetensorsFile.has_value()); + EXPECT_EQ(graphSettings.loraAdapters[1].alias, "anime"); + EXPECT_EQ(graphSettings.loraAdapters[1].sourceLora, "org2/repo2"); + EXPECT_EQ(graphSettings.loraAdapters[1].safetensorsFile.value(), "weights.safetensors"); +} + +TEST(ImageGenCLILoraParsingTest, EmptySourceLorasProducesNoAdapters) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = ""; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 0); +} + +TEST(ImageGenCLILoraParsingTest, InvalidEmptyAlias) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "=org/repo"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, InvalidEmptyFilenameAfterAt) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/repo@"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, MissingAliasWithFilenameThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "org1/repo1@special.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, 
hfSettings, "test_model"), std::invalid_argument); +} + +// ===================== LoRA Source Type Tests ===================== + +TEST(ImageGenCLILoraParsingTest, DirectUrlWithAlias) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=https://huggingface.co/juliensimon/sd-pokemon-lora/resolve/main/pytorch_lora_weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "https://huggingface.co/juliensimon/sd-pokemon-lora/resolve/main/pytorch_lora_weights.safetensors"); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile.value(), "pytorch_lora_weights.safetensors"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::DIRECT_URL); +} + +TEST(ImageGenCLILoraParsingTest, DirectUrlHttpWithAlias) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=http://example.com/weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::DIRECT_URL); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile.value(), "weights.safetensors"); +} + +TEST(ImageGenCLILoraParsingTest, DirectUrlMissingAliasThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = 
"https://example.com/weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, DirectUrlNotSafetensorsThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=https://example.com/model.bin"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +class ImageGenCLILoraParsingWithTempDir : public TestWithTempDir {}; + +TEST_F(ImageGenCLILoraParsingWithTempDir, LocalFileWithAlias) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + std::string tmpFile = ovms::FileSystem::joinPath({this->directoryPath, "test_weights.safetensors"}); + { + std::ofstream f(tmpFile); + f << "test"; + } + hfSettings.sourceLoras = "pokemon=" + tmpFile; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, tmpFile); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile.value(), "test_weights.safetensors"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::LOCAL_FILE); +} + +TEST_F(ImageGenCLILoraParsingWithTempDir, MixedSourceTypes) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + std::string tmpFile = ovms::FileSystem::joinPath({this->directoryPath, "local.safetensors"}); + { + std::ofstream f(tmpFile); + f << "test"; + } + hfSettings.sourceLoras = 
"hf=org/repo,url=https://example.com/remote.safetensors,local=" + tmpFile; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 3); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::HF_REPO); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "hf"); + EXPECT_EQ(graphSettings.loraAdapters[1].sourceType, ovms::LoraSourceType::DIRECT_URL); + EXPECT_EQ(graphSettings.loraAdapters[1].alias, "url"); + EXPECT_EQ(graphSettings.loraAdapters[1].safetensorsFile.value(), "remote.safetensors"); + EXPECT_EQ(graphSettings.loraAdapters[2].sourceType, ovms::LoraSourceType::LOCAL_FILE); + EXPECT_EQ(graphSettings.loraAdapters[2].alias, "local"); + EXPECT_EQ(graphSettings.loraAdapters[2].safetensorsFile.value(), "local.safetensors"); +} + +TEST(ImageGenCLILoraParsingTest, LocalFileMissingAliasThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "/tmp/some_weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, LocalFileNotSafetensorsThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=/tmp/model.bin"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, LocalFileDoesNotExistThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=/nonexistent/path/weights.safetensors"; + 
ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +// ===================== Graph Export with Different Source Types ===================== + +const std::string expectedImageGenWithUrlLora = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + lora_adapters { alias: "pokemon" path: "loras/pokemon/pytorch_lora_weights.safetensors" mode: DYNAMIC } + } + } +} + +)"; + +TEST_F(LoraGraphCreationTest, imageGenerationWithUrlLora) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "https://huggingface.co/juliensimon/sd-pokemon-lora/resolve/main/pytorch_lora_weights.safetensors", "pytorch_lora_weights.safetensors", ovms::LoraSourceType::DIRECT_URL}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithUrlLora, removeVersionString(graphContents)) << graphContents; +} + +const std::string expectedImageGenWithLocalLora = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + 
calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + lora_adapters { alias: "pokemon" path: "/path/to/weights.safetensors" mode: DYNAMIC } + } + } +} + +)"; + +TEST_F(LoraGraphCreationTest, imageGenerationWithLocalLora) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "/path/to/weights.safetensors", "weights.safetensors", ovms::LoraSourceType::LOCAL_FILE}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithLocalLora, removeVersionString(graphContents)) << graphContents; +} + +TEST_F(LoraGraphCreationTest, imageGenerationHfRepoLoraWithoutFilenameReturnsError) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + // HF_REPO adapter without @filename and without pull — safetensorsFile is nullopt + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "juliensimon/sd-pokemon-lora", std::nullopt, ovms::LoraSourceType::HF_REPO}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, 
ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID); +} + +// ===================== Composite LoRA Tests ===================== + +TEST(ImageGenCLILoraParsingTest, CompositeLoraBasic) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,anime=org/anime-lora,pokemon_anime=@pokemon+@anime"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 2); + ASSERT_EQ(graphSettings.compositeLoraAdapters.size(), 1); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].alias, "pokemon_anime"); + ASSERT_EQ(graphSettings.compositeLoraAdapters[0].components.size(), 2); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].components[0].adapterAlias, "pokemon"); + EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[0].alpha, 1.0f); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].components[1].adapterAlias, "anime"); + EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[1].alpha, 1.0f); +} + +TEST(ImageGenCLILoraParsingTest, CompositeLoraWithAlphas) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,anime=org/anime-lora,blend=@pokemon:0.7+@anime:0.5"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.compositeLoraAdapters.size(), 1); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].alias, "blend"); + ASSERT_EQ(graphSettings.compositeLoraAdapters[0].components.size(), 2); + EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[0].alpha, 0.7f); + 
EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[1].alpha, 0.5f); +} + +TEST(ImageGenCLILoraParsingTest, CompositeLoraUnknownRefThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,blend=@pokemon+@nonexistent"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, CompositeLoraInvalidComponentThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,blend=@pokemon+noatsign"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, SingleLoraWithAlpha) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora:0.75"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "org/pokemon-lora"); + ASSERT_TRUE(graphSettings.loraAdapters[0].alpha.has_value()); + EXPECT_FLOAT_EQ(graphSettings.loraAdapters[0].alpha.value(), 0.75f); +} + +TEST(ImageGenCLILoraParsingTest, SingleLoraWithAlphaOne) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora:1.0"; + ovms::ImageGenerationGraphCLIParser parser; + 
parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + ASSERT_TRUE(graphSettings.loraAdapters[0].alpha.has_value()); + EXPECT_FLOAT_EQ(graphSettings.loraAdapters[0].alpha.value(), 1.0f); +} + +TEST(ImageGenCLILoraParsingTest, InvalidAlphaThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora:abc"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, InvalidAlphaPartialFloatThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora:0.5abc"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, InvalidAlphaWithFilenameThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/repo@some.safetensors:wikingowie"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, InvalidAlphaMultipleDotsThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/repo@some.safetensors:2.5.32"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, UrlLoraWithAlpha) { + 
ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=https://huggingface.co/org/repo/resolve/main/weights.safetensors:0.5"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "https://huggingface.co/org/repo/resolve/main/weights.safetensors"); + ASSERT_TRUE(graphSettings.loraAdapters[0].alpha.has_value()); + EXPECT_FLOAT_EQ(graphSettings.loraAdapters[0].alpha.value(), 0.5f); +} + +TEST(ImageGenCLILoraParsingTest, UrlLoraWithoutAlphaPreservesDefault) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=https://huggingface.co/org/repo/resolve/main/weights.safetensors"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "https://huggingface.co/org/repo/resolve/main/weights.safetensors"); + EXPECT_FALSE(graphSettings.loraAdapters[0].alpha.has_value()); +} + +TEST(ImageGenCLILoraParsingTest, NPURejectsMultiLoraWithoutComposites) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,anime=org/anime-lora"; + hfSettings.exportSettings.targetDevice = "NPU"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, NPUAllowsCompositeAdapters) { + ovms::ServerSettingsImpl 
serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "pokemon=org/pokemon-lora,anime=org/anime-lora,blend=@pokemon:0.5+@anime:0.5"; + hfSettings.exportSettings.targetDevice = "NPU"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_NO_THROW(parser.prepare(serverSettings, hfSettings, "test_model")); +} + +const std::string expectedImageGenWithCompositeLora = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + lora_adapters { alias: "pokemon" path: "loras/org/pokemon-lora/weights.safetensors" mode: DYNAMIC } + lora_adapters { alias: "anime" path: "loras/org/anime-lora/weights.safetensors" mode: DYNAMIC } + composite_lora_adapters { + alias: "blend" + components { adapter_alias: "pokemon" alpha: 0.7 } + components { adapter_alias: "anime" alpha: 0.5 } + } + } + } +} + +)"; + +TEST_F(LoraGraphCreationTest, imageGenerationWithCompositeLora) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "org/pokemon-lora", "weights.safetensors", ovms::LoraSourceType::HF_REPO}); + imageGenerationGraphSettings.loraAdapters.push_back({"anime", "org/anime-lora", "weights.safetensors", ovms::LoraSourceType::HF_REPO}); + imageGenerationGraphSettings.compositeLoraAdapters.push_back({"blend", {{"pokemon", 0.7f}, {"anime", 0.5f}}}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = 
ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenWithCompositeLora, removeVersionString(graphContents)) << graphContents; +} + +// ===================== LoRA Alias Validation Tests ===================== + +TEST(ImageGenCLILoraParsingTest, InvalidAliasWithSpacesThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "my pokemon=org/repo"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, InvalidAliasWithSlashThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "my/pokemon=org/repo"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, InvalidAliasWithSpecialCharsThrows) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "poke@mon=org/repo"; + ovms::ImageGenerationGraphCLIParser parser; + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST(ImageGenCLILoraParsingTest, ValidAliasWithHyphensUnderscoresDots) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "my-lora_v1.0=org/repo"; + ovms::ImageGenerationGraphCLIParser parser; + 
parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "my-lora_v1.0"); +} + +// ===================== LoRA Local File Path Tests ===================== + +TEST_F(ImageGenCLILoraParsingWithTempDir, LocalFileAbsoluteUnixPath) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + std::string tmpFile = ovms::FileSystem::joinPath({this->directoryPath, "model.safetensors"}); + { + std::ofstream f(tmpFile); + f << "test"; + } + hfSettings.sourceLoras = "pokemon=" + tmpFile; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::LOCAL_FILE); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile.value(), "model.safetensors"); +} + +#ifdef _WIN32 +TEST_F(ImageGenCLILoraParsingWithTempDir, LocalFileWindowsAbsolutePath) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + std::string tmpFile = ovms::FileSystem::joinPath({this->directoryPath, "model.safetensors"}); + { + std::ofstream f(tmpFile); + f << "test"; + } + // On Windows, directoryPath uses native path with backslashes + hfSettings.sourceLoras = "pokemon=" + tmpFile; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::LOCAL_FILE); +} + +TEST_F(ImageGenCLILoraParsingWithTempDir, LocalFileWindowsRelativeDotBackslash) { + 
ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + // Create file at CWD-relative path + std::string tmpFile = ovms::FileSystem::joinPath({this->directoryPath, "model.safetensors"}); + { + std::ofstream f(tmpFile); + f << "test"; + } + // Use .\ relative path (Windows-style) + hfSettings.sourceLoras = "pokemon=.\\" + std::filesystem::path(tmpFile).filename().string(); + ovms::ImageGenerationGraphCLIParser parser; + // This will throw because relative path won't resolve to existing file from CWD, + // but it should at least be detected as LOCAL_FILE source type (i.e. not HF_REPO) + EXPECT_THROW(parser.prepare(serverSettings, hfSettings, "test_model"), std::invalid_argument); +} + +TEST_F(ImageGenCLILoraParsingWithTempDir, LocalFileWindowsAbsolutePathWithAlpha) { + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + std::string tmpFile = ovms::FileSystem::joinPath({this->directoryPath, "model.safetensors"}); + { + std::ofstream f(tmpFile); + f << "test"; + } + // Windows path with alpha: C:\path\to\model.safetensors:0.6 + // The drive letter colon (C:) must not be confused with alpha separator + hfSettings.sourceLoras = "pokemon=" + tmpFile + ":0.6"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "test_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(graphSettings.loraAdapters.size(), 1); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::LOCAL_FILE); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile.value(), "model.safetensors"); + ASSERT_TRUE(graphSettings.loraAdapters[0].alpha.has_value()); + EXPECT_FLOAT_EQ(graphSettings.loraAdapters[0].alpha.value(), 0.6f); +} +#endif + +// ===================== Full Composite LoRA End-to-End CLI Test ===================== + +TEST(ImageGenCLILoraParsingTest, 
FullCompositeWithAlphasAndTwoLoras) { + // Two individual LoRAs + one composite referencing both with explicit alphas + ovms::ServerSettingsImpl serverSettings; + serverSettings.serverMode = ovms::HF_PULL_MODE; + ovms::HFSettingsImpl hfSettings; + hfSettings.sourceLoras = "landscape=civitai/landscapes-lora@Fantastic_Landscape.safetensors," + "portrait=org/portrait-lora," + "scenic_blend=@landscape:0.3+@portrait:0.8"; + ovms::ImageGenerationGraphCLIParser parser; + parser.prepare(serverSettings, hfSettings, "sd_model"); + auto& graphSettings = std::get(hfSettings.graphSettings); + + // Verify individual LoRAs + ASSERT_EQ(graphSettings.loraAdapters.size(), 2); + EXPECT_EQ(graphSettings.loraAdapters[0].alias, "landscape"); + EXPECT_EQ(graphSettings.loraAdapters[0].sourceLora, "civitai/landscapes-lora"); + EXPECT_EQ(graphSettings.loraAdapters[0].safetensorsFile.value(), "Fantastic_Landscape.safetensors"); + EXPECT_EQ(graphSettings.loraAdapters[1].alias, "portrait"); + EXPECT_EQ(graphSettings.loraAdapters[1].sourceLora, "org/portrait-lora"); + + // Verify composite + ASSERT_EQ(graphSettings.compositeLoraAdapters.size(), 1); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].alias, "scenic_blend"); + ASSERT_EQ(graphSettings.compositeLoraAdapters[0].components.size(), 2); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].components[0].adapterAlias, "landscape"); + EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[0].alpha, 0.3f); + EXPECT_EQ(graphSettings.compositeLoraAdapters[0].components[1].adapterAlias, "portrait"); + EXPECT_FLOAT_EQ(graphSettings.compositeLoraAdapters[0].components[1].alpha, 0.8f); +} + +const std::string expectedImageGenFullComposite = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: 
"HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "CPU" + lora_adapters { alias: "landscape" path: "loras/civitai/landscapes-lora/Fantastic_Landscape.safetensors" mode: DYNAMIC } + lora_adapters { alias: "portrait" path: "loras/org/portrait-lora/weights.safetensors" mode: DYNAMIC } + composite_lora_adapters { + alias: "scenic_blend" + components { adapter_alias: "landscape" alpha: 0.3 } + components { adapter_alias: "portrait" alpha: 0.8 } + } + } + } +} + +)"; + +TEST_F(LoraGraphCreationTest, imageGenerationFullCompositeWithAlphas) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"landscape", "civitai/landscapes-lora", "Fantastic_Landscape.safetensors", ovms::LoraSourceType::HF_REPO}); + imageGenerationGraphSettings.loraAdapters.push_back({"portrait", "org/portrait-lora", "weights.safetensors", ovms::LoraSourceType::HF_REPO}); + imageGenerationGraphSettings.compositeLoraAdapters.push_back({"scenic_blend", {{"landscape", 0.3f}, {"portrait", 0.8f}}}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenFullComposite, removeVersionString(graphContents)) << graphContents; +} + +const std::string expectedImageGenNpuStatic = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + 
input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + device: "NPU" + lora_adapters { alias: "pokemon" path: "loras/org/pokemon-lora/weights.safetensors" alpha: 0.8 mode: STATIC } + } + } +} + +)"; + +TEST_F(LoraGraphCreationTest, imageGenerationNpuAutoStaticMode) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + hfSettings.exportSettings.targetDevice = "NPU"; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.loraAdapters.push_back({"pokemon", "org/pokemon-lora", "weights.safetensors", ovms::LoraSourceType::HF_REPO, 0.8f}); + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenNpuStatic, removeVersionString(graphContents)) << graphContents; +} diff --git a/src/test/ovmsconfig_test.cpp b/src/test/ovmsconfig_test.cpp index b59a2d985c..125da560ea 100644 --- a/src/test/ovmsconfig_test.cpp +++ b/src/test/ovmsconfig_test.cpp @@ -577,6 +577,60 @@ TEST_F(OvmsConfigDeathTest, negativeImageGenerationGraph_MaxNumInferenceStepsZer EXPECT_THROW(ovms::Config::instance().parse(arg_count, n_argv), std::invalid_argument); } +TEST(OvmsGraphConfigTest, negativeImageGenerationGraph_SourceLorasEmptyAlias) { + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)"some/model", + (char*)"--model_repository_path", + (char*)"/some/path", + (char*)"--task", + (char*)"image_generation", + (char*)"--source_loras", + (char*)"=org/repo", + }; + 
int arg_count = 10; + ConstructorEnabledConfig config; + EXPECT_THROW(config.parse(arg_count, n_argv), std::invalid_argument); +} + +TEST(OvmsGraphConfigTest, negativeImageGenerationGraph_SourceLorasEmptyRepo) { + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)"some/model", + (char*)"--model_repository_path", + (char*)"/some/path", + (char*)"--task", + (char*)"image_generation", + (char*)"--source_loras", + (char*)"alias=", + }; + int arg_count = 10; + ConstructorEnabledConfig config; + EXPECT_THROW(config.parse(arg_count, n_argv), std::invalid_argument); +} + +TEST(OvmsGraphConfigTest, negativeImageGenerationGraph_SourceLorasEmptyFilenameAfterAt) { + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)"some/model", + (char*)"--model_repository_path", + (char*)"/some/path", + (char*)"--task", + (char*)"image_generation", + (char*)"--source_loras", + (char*)"pokemon=org/repo@", + }; + int arg_count = 10; + ConstructorEnabledConfig config; + EXPECT_THROW(config.parse(arg_count, n_argv), std::invalid_argument); +} + TEST_F(OvmsConfigDeathTest, hfBadEmbeddingsGraphParameter) { char* n_argv[] = { "ovms", @@ -1772,6 +1826,39 @@ TEST(OvmsGraphConfigTest, positiveAllChangedImageGeneration) { ASSERT_EQ(exportSettings.pluginConfig.manualString.value(), "{\"SOME_KEY\":\"SOME_VALUE\"}"); } +TEST(OvmsGraphConfigTest, positiveImageGenerationWithSourceLoras) { + std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov"; + std::string downloadPath = "test/repository"; + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)modelName.c_str(), + (char*)"--model_repository_path", + (char*)downloadPath.c_str(), + (char*)"--task", + (char*)"image_generation", + (char*)"--source_loras=pokemon=juliensimon/sd-pokemon-lora@weights.safetensors,anime=org/anime-lora", + }; + + int arg_count = 9; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); 
+ + auto& hfSettings = config.getServerSettings().hfSettings; + ASSERT_EQ(hfSettings.task, ovms::IMAGE_GENERATION_GRAPH); + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters.size(), 2); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[0].alias, "pokemon"); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[0].sourceLora, "juliensimon/sd-pokemon-lora"); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[0].safetensorsFile.value(), "weights.safetensors"); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[0].sourceType, ovms::LoraSourceType::HF_REPO); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[1].alias, "anime"); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[1].sourceLora, "org/anime-lora"); + ASSERT_FALSE(imageGenerationGraphSettings.loraAdapters[1].safetensorsFile.has_value()); + ASSERT_EQ(imageGenerationGraphSettings.loraAdapters[1].sourceType, ovms::LoraSourceType::HF_REPO); +} + TEST(OvmsGraphConfigTest, positiveDefaultImageGeneration) { std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov"; std::string downloadPath = "test/repository"; diff --git a/src/test/pull_hf_model_test.cpp b/src/test/pull_hf_model_test.cpp index e1ab3f9ff3..e8478103fc 100644 --- a/src/test/pull_hf_model_test.cpp +++ b/src/test/pull_hf_model_test.cpp @@ -36,6 +36,7 @@ #include #include +#include #include "src/utils/env_guard.hpp" #include "src/test/light_test_utils.hpp" @@ -48,6 +49,7 @@ #include "src/pull_module/optimum_export.hpp" #include "src/servables_config_manager_module/listmodels.hpp" #include "src/modelextensions.hpp" +#include "src/capi_frontend/server_settings.hpp" #include "../module.hpp" #include "../server.hpp" @@ -93,12 +95,6 @@ class HfPull : public TestWithTempDir { RemoveReadonlyFileAttributeFromDir(this->directoryPath); TestWithTempDir::TearDown(); } - - // Removes # OpenVINO Model Server REPLACE_PROJECT_VERSION 
comment added for debug purpose in graph export at the begging of graph.pbtxt - // This string differs per build and setup - std::string removeVersionString(std::string input) { - return input.erase(0, input.find("\n") + 1); - } }; class HfPullCache : public HfPull { @@ -384,9 +380,7 @@ TEST_F(HfPullCache, Resume) { // Fails because we want clean and it has the graph.pbtxt after download ASSERT_EQ(hfDownloader->CheckRepositoryStatus(true).getCode(), ovms::StatusCode::HF_GIT_STATUS_UNCLEAN); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); + exit(0); }, ::testing::ExitedWithCode(0), ""); std::error_code ec; ec.clear(); @@ -1243,9 +1237,7 @@ TEST(HfDownloaderClassTest, RepositoryStatusCheckErrors) { // Fails without libgit init ASSERT_EQ(hfDownloader->CheckRepositoryStatus(true).getCode(), ovms::StatusCode::HF_GIT_LIBGIT2_NOT_INITIALIZED); ASSERT_EQ(hfDownloader->CheckRepositoryStatus(false).getCode(), ovms::StatusCode::HF_GIT_LIBGIT2_NOT_INITIALIZED); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); + exit(0); }, ::testing::ExitedWithCode(0), ""); EXPECT_EXIT({ std::unique_ptr hfDownloader = std::make_unique(modelName, ovms::IModelDownloader::getGraphDirectory(downloadPath, modelName), hfEndpoint, hfToken, httpProxy, false); @@ -1263,9 +1255,7 @@ TEST(HfDownloaderClassTest, RepositoryStatusCheckErrors) { std::unique_ptr existingHfDownloader = std::make_unique(modelName, downloadPath, hfEndpoint, hfToken, httpProxy, false); ASSERT_EQ(existingHfDownloader->CheckRepositoryStatus(true).getCode(), ovms::StatusCode::HF_GIT_STATUS_FAILED); ASSERT_EQ(existingHfDownloader->CheckRepositoryStatus(false).getCode(), ovms::StatusCode::HF_GIT_STATUS_FAILED); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); + exit(0); }, ::testing::ExitedWithCode(0), ""); } TEST(HfDownloaderClassTest, CloneCancellationFollowsServerShutdownRequest) { @@ -1575,9 +1565,7 @@ TEST(Libgt2InitGuardTest, LfsFilterCaptureDefaultResumeOptions) { } EXPECT_THAT(output, ::testing::HasSubstr("[INFO] 
LFS resume: attempts=5 interval=10 s")); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); + exit(0); }, ::testing::ExitedWithCode(0), ""); } TEST(Libgt2InitGuardTest, LfsFilterCaptureNonDefaultResumeOptions) { @@ -1600,9 +1588,7 @@ TEST(Libgt2InitGuardTest, LfsFilterCaptureNonDefaultResumeOptions) { } EXPECT_THAT(output, ::testing::HasSubstr("[INFO] LFS resume: attempts=3 interval=20 s")); - exit(0); - }, - ::testing::ExitedWithCode(0), ""); + exit(0); }, ::testing::ExitedWithCode(0), ""); } TEST_F(HfDownloaderHfEnvTest, Methods) { @@ -1822,3 +1808,360 @@ TEST(ServerModulesBehaviorTests, PullAndStartModeErrorAndExpectFailAndCheckOther ASSERT_NE(server.getModule(ovms::SERVABLE_MANAGER_MODULE_NAME), nullptr); // expected to be started ASSERT_EQ(server.getModule(ovms::SERVABLES_CONFIG_MANAGER_MODULE_NAME), nullptr); } + +// ===================== LoRA Pull Module Tests ===================== + +class TestHfPullModelModuleForLora : public ovms::HfPullModelModule { +public: + ovms::HFSettingsImpl& getHfSettings() { return this->hfSettings; } + ovms::Status testResolveHfLoraFilenames() { return this->resolveHfLoraFilenames(); } + ovms::Status testPullLoraAdapters(const std::string& graphDirectory) { return this->pullLoraAdapters(graphDirectory); } +}; + +class HfPullModelModuleLoraTest : public TestWithTempDir {}; + +TEST_F(HfPullModelModuleLoraTest, ResolveHfLoraFilenames) { + SKIP_AND_EXIT_IF_NOT_RUNNING_UNSTABLE(); + const char* hfToken = std::getenv("HF_TOKEN"); + if (!hfToken || std::string(hfToken).empty()) { + GTEST_SKIP() << "Skipping: HF_TOKEN not set (required for HF API resolution)"; + } + TestHfPullModelModuleForLora module; + auto& settings = module.getHfSettings(); + settings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl graphSettings; + ovms::LoraAdapterSettings adapter; + adapter.alias = "pokemon"; + adapter.sourceLora = "juliensimon/sd-pokemon-lora"; + adapter.sourceType = ovms::LoraSourceType::HF_REPO; + 
graphSettings.loraAdapters.push_back(adapter); + settings.graphSettings = graphSettings; + + auto status = module.testResolveHfLoraFilenames(); + ASSERT_TRUE(status.ok()) << status.string(); + + const auto& resolved = std::get(settings.graphSettings); + ASSERT_EQ(resolved.loraAdapters.size(), 1); + EXPECT_FALSE(resolved.loraAdapters[0].safetensorsFile.has_value()); + EXPECT_EQ(resolved.loraAdapters[0].resolvedSafetensorsFile.value(), "pytorch_lora_weights.safetensors"); + EXPECT_EQ(resolved.loraAdapters[0].effectiveSafetensorsFile().value(), "pytorch_lora_weights.safetensors"); +} + +TEST_F(HfPullModelModuleLoraTest, PullLoraAdaptersFromHfRepo) { + SKIP_AND_EXIT_IF_NOT_RUNNING_UNSTABLE(); + const char* hfToken = std::getenv("HF_TOKEN"); + if (!hfToken || std::string(hfToken).empty()) { + GTEST_SKIP() << "Skipping: HF_TOKEN not set (required for HF download)"; + } + TestHfPullModelModuleForLora module; + auto& settings = module.getHfSettings(); + settings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl graphSettings; + ovms::LoraAdapterSettings adapter; + adapter.alias = "pokemon"; + adapter.sourceLora = "juliensimon/sd-pokemon-lora"; + adapter.safetensorsFile = "pytorch_lora_weights.safetensors"; // explicit filename — skips HF API resolve + adapter.sourceType = ovms::LoraSourceType::HF_REPO; + graphSettings.loraAdapters.push_back(adapter); + settings.graphSettings = graphSettings; + + auto status = module.testPullLoraAdapters(this->directoryPath); + ASSERT_TRUE(status.ok()) << status.string(); + + auto loraFilePath = ovms::FileSystem::joinPath({this->directoryPath, "loras", "juliensimon/sd-pokemon-lora", "pytorch_lora_weights.safetensors"}); + ASSERT_TRUE(std::filesystem::exists(loraFilePath)) << loraFilePath; + EXPECT_GT(std::filesystem::file_size(loraFilePath), 0); +} + +TEST_F(HfPullModelModuleLoraTest, PullLoraAdaptersSkipsLocalFile) { + TestHfPullModelModuleForLora module; + auto& settings = module.getHfSettings(); + 
settings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl graphSettings; + ovms::LoraAdapterSettings adapter; + adapter.alias = "local_lora"; + adapter.sourceLora = "/some/local/path/model.safetensors"; + adapter.safetensorsFile = "model.safetensors"; + adapter.sourceType = ovms::LoraSourceType::LOCAL_FILE; + graphSettings.loraAdapters.push_back(adapter); + settings.graphSettings = graphSettings; + + auto status = module.testPullLoraAdapters(this->directoryPath); + ASSERT_TRUE(status.ok()) << status.string(); + // No files should have been downloaded to the temp directory + EXPECT_TRUE(std::filesystem::is_empty(this->directoryPath)); +} + +TEST_F(HfPullModelModuleLoraTest, PullLoraAdaptersNonImageGenGraphIsNoOp) { + TestHfPullModelModuleForLora module; + auto& settings = module.getHfSettings(); + settings.task = ovms::TEXT_GENERATION_GRAPH; + settings.graphSettings = ovms::TextGenGraphSettingsImpl{}; + + auto status = module.testPullLoraAdapters(this->directoryPath); + ASSERT_TRUE(status.ok()) << status.string(); +} + +class HfDownloaderPullHfModel : public HfPull {}; + +// Full-flow test: download SD model + LoRA via --pull mode, verify files and graph.pbtxt. +// This exercises: CLI parsing -> source_loras -> HF resolution -> LoRA download -> graph.pbtxt generation. +// Runtime clone()+LoRA behavior is guaranteed by the GenAI API: clone() "reuses underlying models" +// which share the AdapterController. Adapters are selected per-request via generate() properties. 
+TEST_F(HfDownloaderPullHfModel, DownloadImageGenModelWithLoRA) { + SKIP_AND_EXIT_IF_NOT_RUNNING_UNSTABLE(); + const char* hfToken = std::getenv("HF_TOKEN"); + if (!hfToken || std::string(hfToken).empty()) { + GTEST_SKIP() << "Skipping: HF_TOKEN not set (required for HF LoRA download)"; + } + this->filesToPrintInCaseOfFailure.emplace_back("graph.pbtxt"); + std::string modelName = "OpenVINO/stable-diffusion-v1-5-int8-ov"; + std::string downloadPath = ovms::FileSystem::joinPath({this->directoryPath, "repository"}); + std::string task = "image_generation"; + std::string sourceLoras = "pokemon=juliensimon/sd-pokemon-lora@pytorch_lora_weights.safetensors"; + ::SetUpServerForDownloadWithLoras(this->t, this->server, modelName, downloadPath, task, sourceLoras); + + std::string basePath = ovms::FileSystem::joinPath({downloadPath, "OpenVINO", "stable-diffusion-v1-5-int8-ov"}); + std::string graphPath = ovms::FileSystem::appendSlash(basePath) + "graph.pbtxt"; + + // Verify model was downloaded + ASSERT_TRUE(std::filesystem::exists(basePath)) << basePath; + ASSERT_TRUE(std::filesystem::exists(graphPath)) << graphPath; + + // Verify LoRA adapter was downloaded + std::string loraDir = ovms::FileSystem::joinPath({basePath, "loras", "juliensimon", "sd-pokemon-lora"}); + auto loraFiles = searchFilesRecursively(loraDir, {"pytorch_lora_weights.safetensors"}); + ASSERT_FALSE(loraFiles.empty()) << "LoRA .safetensors not found in: " << loraDir; + + // Verify graph.pbtxt contains the LoRA adapter entry + std::string graphContents = GetFileContents(graphPath); + EXPECT_NE(graphContents.find("lora_adapters"), std::string::npos) << "graph.pbtxt should contain lora_adapters"; + EXPECT_NE(graphContents.find("pokemon"), std::string::npos) << "graph.pbtxt should reference pokemon alias"; +} + +// ===================== Full Image Generation with Pull + LoRA Integration Test ===================== +// Single test that: +// 1. 
Pulls SDXL-int8 model from HuggingFace + 2 LoRA adapters from direct HF URLs +// 2. Verifies downloaded files and graph.pbtxt +// 3. Starts serving from the pulled directory (second server launch — no re-download) +// 4. Makes REST requests: base model, individual LoRA, composite LoRA +// 5. Saves generated images to disk for manual inspection +// +// Model directory persists at: /tmp/ovms_test_sdxl_lora/ +// Output images saved to: /tmp/ovms_test_sdxl_lora_output/ +// +// LoRA adapters (all SDXL-compatible, from openvino_notebooks/multilora-image-generation): +// - xray: DoctorDiffusion/doctor-diffusion-s-xray-xl-lora / DD-xray-v1.safetensors (weight 0.8) +// - chalkboard: Norod78/sdxl-chalkboarddrawing-lora / SDXL_ChalkBoardDrawing_LoRA_r8.safetensors (weight 0.45) +// - combo: composite of @xray:0.8+@chalkboard:0.45 +// +// Additional LoRAs available (commented out, can be swapped in): +// - point: alvdansen/the-point / araminta_k_the_point.safetensors (weight 0.6) +// - ukiyoe: KappaNeuro/ukiyo-e-art / Ukiyo-e Art.safetensors (weight 0.8) +// - vector: DoctorDiffusion/doctor-diffusion-s-controllable-vector-art-xl-lora / DD-vector-v2.safetensors (weight 0.8) +// +// Manual reproduction (run inside docker container): +// # Pull: +// ./bazel-bin/src/ovms --pull --source_model OpenVINO/stable-diffusion-xl-base-1.0-int8-ov --model_repository_path /tmp/ovms_test_sdxl_lora --task image_generation --source_loras "xray=https://huggingface.co/DoctorDiffusion/doctor-diffusion-s-xray-xl-lora/resolve/main/DD-xray-v1.safetensors,chalkboard=https://huggingface.co/Norod78/sdxl-chalkboarddrawing-lora/resolve/main/SDXL_ChalkBoardDrawing_LoRA_r8.safetensors,combo=@xray:0.8+@chalkboard:0.45" +// +// # Serve: +// ./bazel-bin/src/ovms --source_model OpenVINO/stable-diffusion-xl-base-1.0-int8-ov --model_repository_path /tmp/ovms_test_sdxl_lora --task image_generation --source_loras 
"xray=/tmp/ovms_test_sdxl_lora/OpenVINO/stable-diffusion-xl-base-1.0-int8-ov/loras/xray/DD-xray-v1.safetensors,chalkboard=/tmp/ovms_test_sdxl_lora/OpenVINO/stable-diffusion-xl-base-1.0-int8-ov/loras/chalkboard/SDXL_ChalkBoardDrawing_LoRA_r8.safetensors,combo=@xray:0.8+@chalkboard:0.45" --rest_port 8080 +// +// # Generate (curl): +// curl -s http://localhost:8080/v3/images/generations -H "Content-Type: application/json" -d '{"model":"xray","prompt":"xray a castle on a hill","size":"256x256","num_inference_steps":4}' | python3 -c "import sys,json,base64; d=json.load(sys.stdin); open('/tmp/xray.png','wb').write(base64.b64decode(d['data'][0]['b64_json']))" +// curl -s http://localhost:8080/v3/images/generations -H "Content-Type: application/json" -d '{"model":"chalkboard","prompt":"A colorful chalkboard drawing of a castle","size":"256x256","num_inference_steps":4}' | python3 -c "import sys,json,base64; d=json.load(sys.stdin); open('/tmp/chalkboard.png','wb').write(base64.b64decode(d['data'][0]['b64_json']))" +// curl -s http://localhost:8080/v3/images/generations -H "Content-Type: application/json" -d '{"model":"combo","prompt":"xray chalkboard castle","size":"256x256","num_inference_steps":4}' | python3 -c "import sys,json,base64; d=json.load(sys.stdin); open('/tmp/combo.png','wb').write(base64.b64decode(d['data'][0]['b64_json']))" +#ifndef _WIN32 + +// LoRA direct download URLs +static const std::string LORA_XRAY_URL = "https://huggingface.co/DoctorDiffusion/doctor-diffusion-s-xray-xl-lora/resolve/main/DD-xray-v1.safetensors"; +static const std::string LORA_CHALKBOARD_URL = "https://huggingface.co/Norod78/sdxl-chalkboarddrawing-lora/resolve/main/SDXL_ChalkBoardDrawing_LoRA_r8.safetensors"; +// static const std::string LORA_POINT_URL = "https://huggingface.co/alvdansen/the-point/resolve/main/araminta_k_the_point.safetensors"; +// static const std::string LORA_UKIYOE_URL = "https://huggingface.co/KappaNeuro/ukiyo-e-art/resolve/main/Ukiyo-e%20Art.safetensors"; +// 
static const std::string LORA_VECTOR_URL = "https://huggingface.co/DoctorDiffusion/doctor-diffusion-s-controllable-vector-art-xl-lora/resolve/main/DD-vector-v2.safetensors"; + +static const std::string SDXL_MODEL_NAME = "OpenVINO/stable-diffusion-xl-base-1.0-int8-ov"; +static const std::string SDXL_DOWNLOAD_PATH = "/tmp/ovms_test_sdxl_lora"; +static const std::string SDXL_OUTPUT_PATH = "/tmp/ovms_test_sdxl_lora_output"; + +// Helper: extract b64_json from response body and save as PNG file +static void saveGeneratedImage(const std::string& responseBody, const std::string& outputPath) { + // Find b64_json value in JSON response + std::string marker = "\"b64_json\":\""; + auto pos = responseBody.find(marker); + if (pos == std::string::npos) + return; + pos += marker.size(); + auto endPos = responseBody.find("\"", pos); + if (endPos == std::string::npos) + return; + std::string b64 = responseBody.substr(pos, endPos - pos); + + // Decode base64 — simple decoder for test purposes + static const std::string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string decoded; + decoded.reserve(b64.size() * 3 / 4); + std::vector T(256, -1); + for (int i = 0; i < 64; i++) + T[chars[i]] = i; + + int val = 0, valb = -8; + for (unsigned char c : b64) { + if (T[c] == -1) + break; + val = (val << 6) + T[c]; + valb += 6; + if (valb >= 0) { + decoded.push_back(char((val >> valb) & 0xFF)); + valb -= 8; + } + } + + std::ofstream out(outputPath, std::ios::binary); + out.write(decoded.data(), decoded.size()); + std::cout << "Saved generated image (" << decoded.size() << " bytes) to: " << outputPath << std::endl; +} + +TEST(HfPullImageGenWithLora, PullServeAndGenerateWithLoras) { + SKIP_AND_EXIT_IF_NOT_RUNNING_UNSTABLE(); + + ovms::Server& server = ovms::Server::instance(); + std::unique_ptr t; + std::string downloadPath = SDXL_DOWNLOAD_PATH; + std::string modelName = SDXL_MODEL_NAME; + std::string task = "image_generation"; + + // Prepare output directory 
for generated images + std::filesystem::create_directories(SDXL_OUTPUT_PATH); + + // ==================== PART 1: Pull model + LoRAs ==================== + std::string sourceLoras = + "xray=" + LORA_XRAY_URL + "," + "chalkboard=" + + LORA_CHALKBOARD_URL + "," + "combo=@xray:0.8+@chalkboard:0.45"; + // Alternative LoRAs (swap in as needed): + // "point=" + LORA_POINT_URL + "," + // "ukiyoe=" + LORA_UKIYOE_URL + "," + // "vector=" + LORA_VECTOR_URL + "," + + ::SetUpServerForDownloadWithLoras(t, server, modelName, downloadPath, task, sourceLoras, + EXIT_SUCCESS, 8 * SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); + + // Server exits after pull — join and reset + server.setShutdownRequest(1); + t->join(); + t.reset(); + server.setShutdownRequest(0); + + // Verify model was downloaded + std::string modelBasePath = ovms::FileSystem::joinPath({downloadPath, "OpenVINO", "stable-diffusion-xl-base-1.0-int8-ov"}); + ASSERT_TRUE(std::filesystem::exists(modelBasePath)) << "Model not downloaded to: " << modelBasePath; + + std::string graphPath = ovms::FileSystem::appendSlash(modelBasePath) + "graph.pbtxt"; + ASSERT_TRUE(std::filesystem::exists(graphPath)) << "graph.pbtxt not found: " << graphPath; + + // Verify graph.pbtxt references all LoRA aliases + std::string graphContents = GetFileContents(graphPath); + EXPECT_NE(graphContents.find("lora_adapters"), std::string::npos) << "graph.pbtxt should contain lora_adapters"; + EXPECT_NE(graphContents.find("xray"), std::string::npos) << "graph.pbtxt should reference xray alias"; + EXPECT_NE(graphContents.find("chalkboard"), std::string::npos) << "graph.pbtxt should reference chalkboard alias"; + EXPECT_NE(graphContents.find("combo"), std::string::npos) << "graph.pbtxt should reference combo composite alias"; + + // Verify LoRA files were downloaded + std::string lorasDir = ovms::FileSystem::joinPath({modelBasePath, "loras"}); + std::string xrayLoraPath = ovms::FileSystem::joinPath({lorasDir, "xray", "DD-xray-v1.safetensors"}); + 
std::string chalkboardLoraPath = ovms::FileSystem::joinPath({lorasDir, "chalkboard", "SDXL_ChalkBoardDrawing_LoRA_r8.safetensors"}); + ASSERT_TRUE(std::filesystem::exists(xrayLoraPath)) << "X-ray LoRA not found at: " << xrayLoraPath; + ASSERT_TRUE(std::filesystem::exists(chalkboardLoraPath)) << "Chalkboard LoRA not found at: " << chalkboardLoraPath; + + std::cout << "=== PULL COMPLETE ===" << std::endl; + std::cout << "Model path: " << modelBasePath << std::endl; + std::cout << "Graph: " << graphPath << std::endl; + std::cout << "X-ray LoRA: " << xrayLoraPath << std::endl; + std::cout << "Chalkboard LoRA: " << chalkboardLoraPath << std::endl; + + // ==================== PART 2: Serve from pulled directory + generate ==================== + // Re-configure with local file paths for the second server launch + std::string sourceLorasLocal = + "xray=" + xrayLoraPath + "," + "chalkboard=" + + chalkboardLoraPath + "," + "combo=@xray:0.8+@chalkboard:0.45"; + + std::string restPort = "9233"; + ::SetUpServerForDownloadAndStartWithLoras(t, server, + modelName, downloadPath, task, sourceLorasLocal, restPort, 8 * SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); + + std::cout << "=== SERVER STARTED === REST port: " << restPort << std::endl; + + auto cli = std::make_unique(std::string("http://localhost:") + restPort); + cli->set_read_timeout(600); // SDXL image generation is slow on CPU + + auto healthRes = cli->Get("/v2/health/live"); + ASSERT_TRUE(healthRes) << "Failed to reach server health endpoint"; + ASSERT_EQ(healthRes->status, 200) << "Server not healthy"; + + // --- Generate: base model --- + std::string baseRequestBody = R"({ + "model": ")" + SDXL_MODEL_NAME + + R"(", + "prompt": "a simple red circle on white background", + "size": "256x256", + "num_inference_steps": 4 + })"; + auto baseRes = cli->Post("/v3/images/generations", baseRequestBody, "application/json"); + ASSERT_TRUE(baseRes) << "Base model request failed"; + ASSERT_EQ(baseRes->status, 200) << "Base model failed: 
" << baseRes->status << " body: " << baseRes->body.substr(0, 500); + EXPECT_NE(baseRes->body.find("\"b64_json\""), std::string::npos); + saveGeneratedImage(baseRes->body, SDXL_OUTPUT_PATH + "/base_model.png"); + + // --- Generate: X-ray LoRA --- + std::string xrayRequestBody = R"({ + "model": "xray", + "prompt": "xray a castle on a hill, detailed architecture", + "size": "256x256", + "num_inference_steps": 4 + })"; + auto xrayRes = cli->Post("/v3/images/generations", xrayRequestBody, "application/json"); + ASSERT_TRUE(xrayRes) << "X-ray LoRA request failed"; + ASSERT_EQ(xrayRes->status, 200) << "X-ray failed: " << xrayRes->status << " body: " << xrayRes->body.substr(0, 500); + EXPECT_NE(xrayRes->body.find("\"b64_json\""), std::string::npos); + saveGeneratedImage(xrayRes->body, SDXL_OUTPUT_PATH + "/xray_lora.png"); + + // --- Generate: Chalkboard LoRA --- + std::string chalkboardRequestBody = R"({ + "model": "chalkboard", + "prompt": "A colorful chalkboard drawing of a castle on a hill", + "size": "256x256", + "num_inference_steps": 4 + })"; + auto chalkboardRes = cli->Post("/v3/images/generations", chalkboardRequestBody, "application/json"); + ASSERT_TRUE(chalkboardRes) << "Chalkboard LoRA request failed"; + ASSERT_EQ(chalkboardRes->status, 200) << "Chalkboard failed: " << chalkboardRes->status << " body: " << chalkboardRes->body.substr(0, 500); + EXPECT_NE(chalkboardRes->body.find("\"b64_json\""), std::string::npos); + saveGeneratedImage(chalkboardRes->body, SDXL_OUTPUT_PATH + "/chalkboard_lora.png"); + + // --- Generate: Composite LoRA (combo = xray:0.8 + chalkboard:0.45) --- + std::string comboRequestBody = R"({ + "model": "combo", + "prompt": "xray A colorful chalkboard drawing of a castle on a hill, detailed architecture", + "size": "256x256", + "num_inference_steps": 4 + })"; + auto comboRes = cli->Post("/v3/images/generations", comboRequestBody, "application/json"); + ASSERT_TRUE(comboRes) << "Composite LoRA request failed"; + ASSERT_EQ(comboRes->status, 
200) << "Composite failed: " << comboRes->status << " body: " << comboRes->body.substr(0, 500); + EXPECT_NE(comboRes->body.find("\"b64_json\""), std::string::npos); + saveGeneratedImage(comboRes->body, SDXL_OUTPUT_PATH + "/combo_lora.png"); + + std::cout << "=== ALL IMAGES GENERATED ===" << std::endl; + std::cout << "Output directory: " << SDXL_OUTPUT_PATH << std::endl; + + // Shutdown + server.setShutdownRequest(1); + t->join(); + t.reset(); + server.setShutdownRequest(0); +} +#endif // !_WIN32 diff --git a/src/test/test_utils.cpp b/src/test/test_utils.cpp index a9e22659f9..69a612e403 100644 --- a/src/test/test_utils.cpp +++ b/src/test/test_utils.cpp @@ -794,6 +794,27 @@ void SetUpServerForDownloadWithDraft(std::unique_ptr& t, ovms::Serv EnsureServerModelDownloadFinishedWithTimeout(server, timeoutSeconds); } +void SetUpServerForDownloadWithLoras(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, std::string& source_loras, int expected_code, int timeoutSeconds) { + server.setShutdownRequest(0); + char* argv[] = {(char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)source_model.c_str(), + (char*)"--model_repository_path", + (char*)download_path.c_str(), + (char*)"--task", + (char*)task.c_str(), + (char*)"--source_loras", + (char*)source_loras.c_str()}; + + int argc = 10; + t.reset(new std::thread([&argc, &argv, &server, expected_code]() { + EXPECT_EQ(expected_code, server.start(argc, argv)); + })); + + EnsureServerModelDownloadFinishedWithTimeout(server, timeoutSeconds); +} + void SetUpServerForDownloadAndStart(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, int timeoutSeconds) { server.setShutdownRequest(0); std::string port = "9133"; @@ -816,6 +837,58 @@ void SetUpServerForDownloadAndStart(std::unique_ptr& t, ovms::Serve EnsureServerStartedWithTimeout(server, timeoutSeconds); } +void 
SetUpServerForDownloadAndStart(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, std::string& restPort, int timeoutSeconds) { + server.setShutdownRequest(0); + std::string port = "9133"; + randomizeAndEnsureFree(port); + randomizeAndEnsureFree(restPort); + char* argv[] = {(char*)"ovms", + (char*)"--port", + (char*)port.c_str(), + (char*)"--rest_port", + (char*)restPort.c_str(), + (char*)"--source_model", + (char*)source_model.c_str(), + (char*)"--model_repository_path", + (char*)download_path.c_str(), + (char*)"--task", + (char*)task.c_str()}; + + int argc = 11; + t.reset(new std::thread([&argc, &argv, &server]() { + EXPECT_EQ(EXIT_SUCCESS, server.start(argc, argv)); + })); + + EnsureServerStartedWithTimeout(server, timeoutSeconds); +} + +void SetUpServerForDownloadAndStartWithLoras(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, std::string& source_loras, std::string& restPort, int timeoutSeconds) { + server.setShutdownRequest(0); + std::string port = "9133"; + randomizeAndEnsureFree(port); + randomizeAndEnsureFree(restPort); + char* argv[] = {(char*)"ovms", + (char*)"--port", + (char*)port.c_str(), + (char*)"--rest_port", + (char*)restPort.c_str(), + (char*)"--source_model", + (char*)source_model.c_str(), + (char*)"--model_repository_path", + (char*)download_path.c_str(), + (char*)"--task", + (char*)task.c_str(), + (char*)"--source_loras", + (char*)source_loras.c_str()}; + + int argc = 13; + t.reset(new std::thread([&argc, &argv, &server]() { + EXPECT_EQ(EXIT_SUCCESS, server.start(argc, argv)); + })); + + EnsureServerStartedWithTimeout(server, timeoutSeconds); +} + void SetUpServerForDownloadAndStartGGUF(std::unique_ptr& t, ovms::Server& server, std::string& ggufFilename, std::string& sourceModel, std::string& downloadPath, std::string& task, int timeoutSeconds) { server.setShutdownRequest(0); std::string port = "9133"; diff --git 
a/src/test/test_utils.hpp b/src/test/test_utils.hpp index 8fb0885cda..9dfa2c38b1 100644 --- a/src/test/test_utils.hpp +++ b/src/test/test_utils.hpp @@ -786,11 +786,21 @@ void SetUpServerForDownloadWithDraft(std::unique_ptr& t, ovms::Serv * --source_model Qwen/Qwen3-8B-GGUF --model_repository_path /models --gguf_filename Qwen3-8B-Q4_K_M.gguf */ void SetUpServerForDownloadAndStartGGUF(std::unique_ptr& t, ovms::Server& server, std::string& ggufFilename, std::string& sourceModel, std::string& downloadPath, std::string& task, int timeoutSeconds = 4 * SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); +/* + * starts loading OVMS on separate thread but waits until it is shut down or model is downloaded and checks if model is downloaded in ovms + * --pull --source_model org/model --model_repository_path /models --task image_generation --source_loras alias=org/repo + */ +void SetUpServerForDownloadWithLoras(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, std::string& source_loras, int expected_code = EXIT_SUCCESS, int timeoutSeconds = 4 * SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); /* * starts loading OVMS on separate thread but waits until it is shutdowned or model is downloaded and check if model is started in ovms * --source_model OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov --model_repository_path /models */ void SetUpServerForDownloadAndStart(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, int timeoutSeconds = SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); +void SetUpServerForDownloadAndStart(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, std::string& restPort, int timeoutSeconds = SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); +/* + * starts loading OVMS on separate thread with LoRA adapters, waits until model is downloaded and serving + */ +void 
SetUpServerForDownloadAndStartWithLoras(std::unique_ptr& t, ovms::Server& server, std::string& source_model, std::string& download_path, std::string& task, std::string& source_loras, std::string& restPort, int timeoutSeconds = 4 * SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); /* * starts loading OVMS on separate thread but waits until it is ready */ diff --git a/src/test/text2image_test.cpp b/src/test/text2image_test.cpp index 3d61d3e3cd..db60ea8c1d 100644 --- a/src/test/text2image_test.cpp +++ b/src/test/text2image_test.cpp @@ -1496,5 +1496,485 @@ TEST(Text2ImageTest, ResponseFromOvTensorBatch3) { uint16_t n = 3; testResponseFromOvTensor(n); } +// ===================== LoRA Proto Parsing Tests ===================== + +TEST(ImageGenCalculatorOptionsTest, LoraAdaptersAbsolutePath) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "/absolute/path/to/lora.safetensors" } + lora_adapters { alias: "anime" path: "/another/path/weights.safetensors" alpha: 0.5 } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_EQ(imageGenArgs.loraAdapters.size(), 2); + EXPECT_EQ(imageGenArgs.loraAdapters[0].alias, 
"pokemon"); + EXPECT_EQ(imageGenArgs.loraAdapters[0].path, "/absolute/path/to/lora.safetensors"); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[0].alpha, 1.0f); + EXPECT_EQ(imageGenArgs.loraAdapters[1].alias, "anime"); + EXPECT_EQ(imageGenArgs.loraAdapters[1].path, "/another/path/weights.safetensors"); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[1].alpha, 0.5f); +} + +TEST(ImageGenCalculatorOptionsTest, LoraAdaptersRelativePath) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "loras/org/repo/model.safetensors" } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = "/ovms/graph_dir"; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_EQ(imageGenArgs.loraAdapters.size(), 1); + EXPECT_EQ(imageGenArgs.loraAdapters[0].alias, "pokemon"); + EXPECT_EQ(imageGenArgs.loraAdapters[0].path, "/ovms/graph_dir/loras/org/repo/model.safetensors"); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[0].alpha, 1.0f); +} + +TEST(ImageGenCalculatorOptionsTest, NoLoraAdaptersProducesEmptyVector) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + 
name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + } + } + )pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_TRUE(imageGenArgs.loraAdapters.empty()); +} + +TEST(ImageGenCalculatorOptionsTest, CompositeLoraAdaptersFromPbtxt) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "/path/to/pokemon.safetensors" } + lora_adapters { alias: "anime" path: "/path/to/anime.safetensors" } + composite_lora_adapters { + alias: "blend" + components { adapter_alias: "pokemon" alpha: 0.7 } + components { adapter_alias: "anime" alpha: 0.5 } + } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + 
ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_EQ(imageGenArgs.loraAdapters.size(), 2); + EXPECT_EQ(imageGenArgs.loraAdapters[0].alias, "pokemon"); + // Composite alpha is applied to adapter (individual was default 1.0) + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[0].alpha, 0.7f); + EXPECT_EQ(imageGenArgs.loraAdapters[1].alias, "anime"); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[1].alpha, 0.5f); + ASSERT_EQ(imageGenArgs.compositeLoraAdapters.size(), 1); + auto it = imageGenArgs.compositeLoraAdapters.find("blend"); + ASSERT_NE(it, imageGenArgs.compositeLoraAdapters.end()); + ASSERT_EQ(it->second.size(), 2); + EXPECT_EQ(it->second[0].first, "pokemon"); + EXPECT_FLOAT_EQ(it->second[0].second, 0.7f); + EXPECT_EQ(it->second[1].first, "anime"); + EXPECT_FLOAT_EQ(it->second[1].second, 0.5f); +} + +TEST(ImageGenCalculatorOptionsTest, AlphaOnlyAtIndividualLevelIsValid) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "/path/to/pokemon.safetensors" alpha: 0.8 } + lora_adapters { alias: "anime" path: "/path/to/anime.safetensors" alpha: 0.6 } + composite_lora_adapters { + alias: "blend" + components { adapter_alias: "pokemon" } + components { adapter_alias: "anime" } + } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto 
imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_EQ(imageGenArgs.loraAdapters.size(), 2); + // Individual alphas preserved (composite has default 1.0) + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[0].alpha, 0.8f); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[1].alpha, 0.6f); +} + +TEST(ImageGenCalculatorOptionsTest, AlphaOnlyAtCompositeLevelIsValid) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "/path/to/pokemon.safetensors" } + lora_adapters { alias: "anime" path: "/path/to/anime.safetensors" } + composite_lora_adapters { + alias: "blend" + components { adapter_alias: "pokemon" alpha: 0.7 } + components { adapter_alias: "anime" alpha: 0.4 } + } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_EQ(imageGenArgs.loraAdapters.size(), 2); + // Composite alpha applied to adapters (individual was default 1.0) + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[0].alpha, 0.7f); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[1].alpha, 0.4f); +} + 
+TEST(ImageGenCalculatorOptionsTest, ExplicitAlpha1AtIndividualLevelAllowsCompositeOverride) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "/path/to/pokemon.safetensors" alpha: 1.0 } + lora_adapters { alias: "anime" path: "/path/to/anime.safetensors" alpha: 1.0 } + composite_lora_adapters { + alias: "blend" + components { adapter_alias: "pokemon" alpha: 0.7 } + components { adapter_alias: "anime" alpha: 0.4 } + } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + // alpha=1.0 is the default, so composite alpha should override + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_EQ(imageGenArgs.loraAdapters.size(), 2); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[0].alpha, 0.7f); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[1].alpha, 0.4f); +} + +TEST(ImageGenCalculatorOptionsTest, ExplicitAlpha1AtCompositeLevelKeepsIndividualAlpha) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: 
"IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "/path/to/pokemon.safetensors" alpha: 0.8 } + lora_adapters { alias: "anime" path: "/path/to/anime.safetensors" alpha: 0.6 } + composite_lora_adapters { + alias: "blend" + components { adapter_alias: "pokemon" alpha: 1.0 } + components { adapter_alias: "anime" alpha: 1.0 } + } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + // composite alpha=1.0 is default, so individual alpha is kept + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + auto imageGenArgs = std::get(imageGenArgsOrStatus); + ASSERT_EQ(imageGenArgs.loraAdapters.size(), 2); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[0].alpha, 0.8f); + EXPECT_FLOAT_EQ(imageGenArgs.loraAdapters[1].alpha, 0.6f); +} + +TEST(ImageGenCalculatorOptionsTest, AlphaAtBothLevelsReturnsError) { +#ifdef _WIN32 + const std::string dummyLocation = dummy_model_location; +#else + const std::string dummyLocation = "/ovms/src/test/dummy"; +#endif + std::ostringstream oss; + oss << R"pb( + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")pb" + << dummyLocation; + oss << R"(")"; + oss << R"pb( + lora_adapters { alias: "pokemon" path: "/path/to/pokemon.safetensors" alpha: 0.8 } + lora_adapters { alias: "anime" path: "/path/to/anime.safetensors" } + composite_lora_adapters { + alias: 
"blend" + components { adapter_alias: "pokemon" alpha: 0.5 } + components { adapter_alias: "anime" alpha: 0.4 } + } + } + } +)pb"; + auto nodePbtxt = oss.str(); + auto node = mediapipe::ParseTextProtoOrDie(nodePbtxt); + const std::string graphPath = ""; + auto nodeOptions = node.node_options(0); + auto imageGenArgsOrStatus = prepareImageGenPipelineArgs(nodeOptions, graphPath); + // Should return error because "pokemon" has alpha at both levels + ASSERT_TRUE(std::holds_alternative(imageGenArgsOrStatus)); + EXPECT_EQ(std::get(imageGenArgsOrStatus).getCode(), ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID); +} + // TODO: // -> test for all unhandled OpenAI fields define what to do - ignore/error imageVariation + +TEST(Text2ImageTest, parseLoraAlphasOverrideValidObject) { + rapidjson::Document doc; + doc.Parse(R"({ + "prompt": "test", + "lora_alphas": { + "pokemon": 0.7, + "anime": 0.4 + } + })"); + auto result = ovms::parseLoraAlphasOverride(doc); + ASSERT_EQ(result.size(), 2); + EXPECT_FLOAT_EQ(result["pokemon"], 0.7f); + EXPECT_FLOAT_EQ(result["anime"], 0.4f); +} + +TEST(Text2ImageTest, parseLoraAlphasOverrideMissingField) { + rapidjson::Document doc; + doc.Parse(R"({"prompt": "test"})"); + auto result = ovms::parseLoraAlphasOverride(doc); + EXPECT_TRUE(result.empty()); +} + +TEST(Text2ImageTest, parseLoraAlphasOverrideNotAnObject) { + rapidjson::Document doc; + doc.Parse(R"({"prompt": "test", "lora_alphas": "invalid"})"); + auto result = ovms::parseLoraAlphasOverride(doc); + EXPECT_TRUE(result.empty()); +} + +TEST(Text2ImageTest, parseLoraAlphasOverrideNonNumericValuesIgnored) { + rapidjson::Document doc; + doc.Parse(R"({ + "prompt": "test", + "lora_alphas": { + "pokemon": 0.7, + "anime": "not_a_number", + "style": true + } + })"); + auto result = ovms::parseLoraAlphasOverride(doc); + ASSERT_EQ(result.size(), 1); + EXPECT_FLOAT_EQ(result["pokemon"], 0.7f); +} + +TEST(Text2ImageTest, parseLoraAlphasOverrideEmptyObject) { + rapidjson::Document doc; + 
doc.Parse(R"({"prompt": "test", "lora_alphas": {}})"); + auto result = ovms::parseLoraAlphasOverride(doc); + EXPECT_TRUE(result.empty()); +} + +TEST(Text2ImageTest, parseLoraAlphasOverrideNegativeAndZeroAlpha) { + rapidjson::Document doc; + doc.Parse(R"({ + "prompt": "test", + "lora_alphas": { + "pokemon": -0.5, + "anime": 0.0 + } + })"); + auto result = ovms::parseLoraAlphasOverride(doc); + ASSERT_EQ(result.size(), 2); + EXPECT_FLOAT_EQ(result["pokemon"], -0.5f); + EXPECT_FLOAT_EQ(result["anime"], 0.0f); +} + +TEST(Text2ImageTest, validateLoraAlphasRejectedWhenNoDynamicAdapters) { + std::unordered_map loraAlphas = {{"pokemon", 0.5f}}; + auto status = ovms::validateLoraAlphasAllowed(false, loraAlphas); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(std::string(status.message()), ::testing::HasSubstr("lora_alphas is not supported")); +} + +TEST(Text2ImageTest, validateLoraAlphasAllowedWithDynamicAdapters) { + std::unordered_map loraAlphas = {{"pokemon", 0.5f}}; + auto status = ovms::validateLoraAlphasAllowed(true, loraAlphas); + EXPECT_TRUE(status.ok()); +} + +TEST(Text2ImageTest, validateLoraAlphasEmptyPassesWithoutDynamicAdapters) { + std::unordered_map loraAlphas; + auto status = ovms::validateLoraAlphasAllowed(false, loraAlphas); + EXPECT_TRUE(status.ok()); +} + +TEST(Text2ImageTest, getImageGenerationRequestOptionsRejectsLoraAlphasWithoutDynamicAdapters) { + rapidjson::Document doc; + doc.Parse(R"({ + "prompt": "test prompt", + "model": "pokemon", + "lora_alphas": {"pokemon": 0.5} + })"); + auto result = ovms::getImageGenerationRequestOptions(doc, DEFAULTIMAGE_GEN_ARGS, false); + ASSERT_TRUE(std::holds_alternative(result)); + auto& err = std::get(result); + EXPECT_EQ(err.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(std::string(err.message()), ::testing::HasSubstr("lora_alphas is not supported")); +} + +TEST(Text2ImageTest, 
getImageGenerationRequestOptionsAllowsLoraAlphasWithDynamicAdapters) { + rapidjson::Document doc; + doc.Parse(R"({ + "prompt": "test prompt", + "model": "pokemon", + "lora_alphas": {"pokemon": 0.5} + })"); + auto result = ovms::getImageGenerationRequestOptions(doc, DEFAULTIMAGE_GEN_ARGS, true); + ASSERT_TRUE(std::holds_alternative(result)); +}