diff --git a/src/BUILD b/src/BUILD index 47510cd54d..f857da408f 100644 --- a/src/BUILD +++ b/src/BUILD @@ -363,6 +363,7 @@ ovms_cc_library( srcs = ["cli_parser.cpp"], deps = [ "@com_github_jarro2783_cxxopts//:cxxopts", + "@com_github_tencent_rapidjson//:rapidjson", "libovms_server_settings", "libovms_version", "//src/filesystem:libovmsfilesystem", @@ -373,6 +374,8 @@ ovms_cc_library( "//src/graph_export:t2s_graph_cli_parser", "//src/graph_export:s2t_graph_cli_parser", "//src/graph_export:image_generation_graph_cli_parser", + "//src/pull_module:curl_downloader", + "//src/pull_module:hf_env_vars", ], visibility = ["//visibility:public",], ) @@ -2158,8 +2161,6 @@ cc_binary( linkstatic = True, ) - - cc_binary( name = "optimum-cli", srcs = [ @@ -2273,6 +2274,7 @@ cc_test( "test/tensor_conversion_test.cpp", "test/tensorinfo_test.cpp", "test/tensorutils_test.cpp", + "test/task_determine_test.cpp", "test/test_http_utils.hpp", "test/tfs_rest_parser_binary_inputs_test.cpp", "test/tfs_rest_parser_column_test.cpp", @@ -2467,7 +2469,7 @@ cc_test( "//src:libcustom_node_image_transformation.so", "//src:libcustom_node_add_one.so", "//src:libcustom_node_horizontal_ocr.so", - ], + ] + glob(["test/models_config_json/**"]), deps = [ "optimum-cli", "//src:ovms_lib", diff --git a/src/cli_parser.cpp b/src/cli_parser.cpp index e05109d905..ccb045ad3a 100644 --- a/src/cli_parser.cpp +++ b/src/cli_parser.cpp @@ -15,15 +15,31 @@ //***************************************************************************** #include "cli_parser.hpp" +#include #include #include +#include #include #include +#include #include #include #include +#ifdef _MSC_VER +#pragma warning(push) +// Suppress warning originating from third-party RapidJSON internals on MSVC. +#pragma warning(disable : 6313) +#endif +#include +#include +#include +#ifdef _MSC_VER +#pragma warning(pop) +#endif + #include "capi_frontend/server_settings.hpp" +#include "logging.hpp" #include "graph_export/graph_cli_parser.hpp" #include "graph_export/rerank_graph_cli_parser.hpp" #include "graph_export/embeddings_graph_cli_parser.hpp" @@ -33,6 +49,8 @@ #include "ovms_exit_codes.hpp" #include "filesystem/filesystem.hpp" #include "filesystem/localfilesystem.hpp" +#include "pull_module/hf_env_vars.hpp" +#include "pull_module/curl_downloader.hpp" #include "stringutils.hpp" #include "version.hpp" @@ -40,6 +58,197 @@ namespace ovms { constexpr const char* CONFIG_MANAGEMENT_HELP_GROUP{"config management"}; constexpr const char* API_KEY_ENV_VAR{"API_KEY"}; +constexpr const char* MODEL_CONFIG_FILENAME{"config.json"}; +constexpr const char* MODEL_INDEX_FILENAME{"model_index.json"}; + +namespace { + +const std::map architectureToTask = { + {"BertForSequenceClassification", "rerank"}, + {"BertModel", "embeddings"}, + {"CLIPTextModel", "image_generation"}, + {"FluxTransformer2DModel", "image_generation"}, + {"InternVLChatModel", "text_generation"}, + {"JinaBertModel", "embeddings"}, + {"MPNetModel", "embeddings"}, + {"ParlerTTSForConditionalGeneration", "text2speech"}, + {"Qwen2ForSequenceClassification", "rerank"}, + {"Qwen2Model", "embeddings"}, + {"Qwen3ASRForConditionalGeneration", "speech2text"}, + {"RobertaForSequenceClassification", "rerank"}, + {"RobertaModel", "embeddings"}, + {"SD3Transformer2DModel", "image_generation"}, + {"SeamlessM4TModel", "speech2text"}, + {"SeamlessM4Tv2Model", "speech2text"}, + {"SpeechT5ForTextToSpeech", "text2speech"}, + {"T5EncoderModel", "embeddings"}, + {"UNet2DConditionModel", "image_generation"}, + {"WhisperForConditionalGeneration", "speech2text"}, + {"XLMRobertaForSequenceClassification", "rerank"}, + {"XLMRobertaModel", "embeddings"}, +}; + +// architecture: {default task, {task, pattern}} +const std::map>>> questionableArchitectureTaskKeywords = { + {"Qwen3ForCausalLM", {"text_generation", {{"rerank", "rerank"}, {"embeddings", "embed"}}}}, +}; + +std::string getEnvOrDefault(const char* envName, const std::string& defaultValue = "") { + const char* envValue = std::getenv(envName); + if (envValue == nullptr) { + return defaultValue; + } + return envValue; +} + +std::string ensureTrailingSlash(std::string path) { + if (path.empty() || path.back() == '/') { + return path; + } + path.push_back('/'); + return path; +} + +std::string getTaskForArchitecture(const std::string& architecture) { + const auto exactMatch = architectureToTask.find(architecture); + if (exactMatch != architectureToTask.end()) { + return exactMatch->second; + } + if (architecture == "WhisperForConditionalGeneration" || architecture.rfind("SeamlessM4T", 0) == 0) { + return "speech2text"; + } + if (endsWith(architecture, "ForTextToSpeech")) { + return "text2speech"; + } + if (endsWith(architecture, "ForSequenceClassification")) { + return "rerank"; + } + if (endsWith(architecture, "Transformer2DModel") || architecture == "UNet2DConditionModel" || architecture == "AutoencoderKL") { + return "image_generation"; + } + if (endsWith(architecture, "ForCausalLM") || endsWith(architecture, "ForConditionalGeneration")) { + return "text_generation"; + } + if (endsWith(architecture, "EncoderModel") || endsWith(architecture, "Model")) { + return "embeddings"; + } + return ""; +} + +std::string getTaskForQuestionableArchitecture(const std::string& architecture, const std::string& normalizedModelIdentifier) { + const auto architectureRules = questionableArchitectureTaskKeywords.find(architecture); + if (architectureRules == questionableArchitectureTaskKeywords.end()) { + return ""; + } + const auto& [defaultTask, patternRules] = architectureRules->second; + for (const auto& [task, keyword] : patternRules) { + if (normalizedModelIdentifier.find(keyword) != std::string::npos) { + return task; + } + } + return defaultTask; +} + +std::string determineTaskFromArchitectures(const rapidjson::Value& architecturesNode, const std::string& modelIdentifier) { + if (!architecturesNode.IsArray() || architecturesNode.Empty()) { + throw std::logic_error("config.json does not contain a non-empty architectures array"); + } + const std::string normalizedModelIdentifier = toLower(modelIdentifier); + std::optional resolvedTask; + for (const auto& architecture : architecturesNode.GetArray()) { + if (!architecture.IsString()) { + continue; + } + const std::string architectureName = architecture.GetString(); + std::string task = getTaskForQuestionableArchitecture(architectureName, normalizedModelIdentifier); + if (task.empty() && questionableArchitectureTaskKeywords.find(architectureName) == questionableArchitectureTaskKeywords.end()) { + task = getTaskForArchitecture(architectureName); + } + if (task.empty()) { + continue; + } + if (!resolvedTask.has_value()) { + resolvedTask = task; + continue; + } + if (resolvedTask.value() != task) { + throw std::logic_error("config.json architectures map to multiple default tasks"); + } + } + if (!resolvedTask.has_value()) { + throw std::logic_error("config.json architectures do not map to a supported default task"); + } + return resolvedTask.value(); +} + +std::string determineTaskFromNullArchitectures(const rapidjson::Document& configJson, const std::string& configSourceDescription) { + // Check for special field patterns when architectures is null + if (configJson.HasMember("n_mels")) { + return "text2speech"; + } + throw std::logic_error(configSourceDescription + " has null architectures and does not contain recognized special fields for task detection"); +} + +std::string determineTaskFromConfigStream(std::istream& configStream, const std::string& configSourceDescription, const std::string& modelIdentifier) { + rapidjson::Document configJson; + rapidjson::IStreamWrapper wrapper(configStream); + configJson.ParseStream(wrapper); + if (configJson.HasParseError()) { + throw std::logic_error("failed to parse " + configSourceDescription + ": " + std::string(rapidjson::GetParseError_En(configJson.GetParseError()))); + } + if (!configJson.HasMember("architectures")) { + throw std::logic_error(configSourceDescription + " does not contain architectures field"); + } + const auto& architecturesNode = configJson["architectures"]; + if (architecturesNode.IsNull()) { + return determineTaskFromNullArchitectures(configJson, configSourceDescription); + } + return determineTaskFromArchitectures(architecturesNode, modelIdentifier); +} + +std::string determineTaskFromConfigContents(const std::string& configContents, const std::string& configSourceDescription, const std::string& modelIdentifier) { + rapidjson::Document configJson; + configJson.Parse(configContents.c_str()); + if (configJson.HasParseError()) { + throw std::logic_error("failed to parse " + configSourceDescription + ": " + std::string(rapidjson::GetParseError_En(configJson.GetParseError()))); + } + if (!configJson.HasMember("architectures")) { + throw std::logic_error(configSourceDescription + " does not contain architectures field"); + } + const auto& architecturesNode = configJson["architectures"]; + if (architecturesNode.IsNull()) { + return determineTaskFromNullArchitectures(configJson, configSourceDescription); + } + return determineTaskFromArchitectures(architecturesNode, modelIdentifier); +} + +std::string determineTaskFromModelIndex(std::istream& indexStream, const std::string& indexSourceDescription) { + rapidjson::Document indexJson; + rapidjson::IStreamWrapper wrapper(indexStream); + indexJson.ParseStream(wrapper); + if (indexJson.HasParseError()) { + throw std::logic_error("failed to parse " + indexSourceDescription + ": " + std::string(rapidjson::GetParseError_En(indexJson.GetParseError()))); + } + if (!indexJson.HasMember("_class_name") || !indexJson["_class_name"].IsString()) { + throw std::logic_error(indexSourceDescription + " does not contain a valid _class_name field"); + } + const std::string className = indexJson["_class_name"].GetString(); + if (className.find("StableDiffusion") != std::string::npos || className.find("Flux") != std::string::npos) { + return "image_generation"; + } + throw std::logic_error(indexSourceDescription + " _class_name '" + className + "' does not map to a supported default task"); +} + +bool graphPbtxtExists(const std::string& modelPath) { + const auto graphPath = std::filesystem::path(modelPath) / "graph.pbtxt"; + return std::filesystem::exists(graphPath); +} + +bool hasTaskSpecificParameters(const std::vector& unmatchedOptions) { + return !unmatchedOptions.empty(); +} + +} // namespace std::string getConfigPath(const std::string& configPath) { bool isDir = false; @@ -53,6 +262,65 @@ std::string getConfigPath(const std::string& configPath) { return configPath; } +std::string CLIParser::determineDefaultTaskParameter(const std::optional& modelPath, const std::optional& sourceModel, const std::optional& modelRepositoryPath) { + if (modelPath.has_value() && !modelPath->empty()) { + const auto configPath = std::filesystem::path(*modelPath) / MODEL_CONFIG_FILENAME; + std::ifstream configFile(configPath); + if (configFile.is_open()) { + return determineTaskFromConfigStream(configFile, configPath.string(), *modelPath); + } + const auto indexPath = std::filesystem::path(*modelPath) / MODEL_INDEX_FILENAME; + std::ifstream indexFile(indexPath); + if (indexFile.is_open()) { + return determineTaskFromModelIndex(indexFile, indexPath.string()); + } + throw std::logic_error("failed to open model config file: " + configPath.string() + " or " + indexPath.string()); + } + + if (!sourceModel.has_value() || sourceModel->empty()) { + throw std::logic_error("cannot determine default --task without model_path or source_model"); + } + + if (modelRepositoryPath.has_value() && !modelRepositoryPath->empty()) { + const auto localModelDirectory = std::filesystem::path(*modelRepositoryPath) / *sourceModel; + if (std::filesystem::exists(localModelDirectory)) { + const auto configPath = localModelDirectory / MODEL_CONFIG_FILENAME; + std::ifstream configFile(configPath); + if (configFile.is_open()) { + return determineTaskFromConfigStream(configFile, configPath.string(), *sourceModel); + } + const auto indexPath = localModelDirectory / MODEL_INDEX_FILENAME; + std::ifstream indexFile(indexPath); + if (indexFile.is_open()) { + return determineTaskFromModelIndex(indexFile, indexPath.string()); + } + throw std::logic_error("failed to open model config file: " + configPath.string() + " or " + indexPath.string()); + } + } + + std::string responseBody; + const std::string hfEndpoint = ensureTrailingSlash(getEnvOrDefault(HF_ENDPOINT_ENV_VAR, DEFAULT_HF_ENDPOINT)); + const std::string configUrl = hfEndpoint + *sourceModel + "/resolve/main/" + MODEL_CONFIG_FILENAME; + const auto status = fetchUrlToString(configUrl, getEnvOrDefault(HF_TOKEN_ENV_VAR), responseBody); + if (!status.ok()) { + throw std::logic_error("failed to download model config file from: " + configUrl); + } + return determineTaskFromConfigContents(responseBody, configUrl, *sourceModel); +} + +std::string CLIParser::getEffectiveTaskParameter() const { + if (result->count("task")) { + const auto task = result->operator[]("task").as(); + SPDLOG_DEBUG("Effective task parameter specified by user: {}", task); + return task; + } + if (inferredTaskParameter.has_value()) { + SPDLOG_DEBUG("Effective task parameter using inferred default: {}", inferredTaskParameter.value()); + return inferredTaskParameter.value(); + } + throw std::logic_error("error parsing options - --task parameter wasn't passed"); +} + std::variant> CLIParser::parse(int argc, char** argv) { std::stringstream ss; try { @@ -299,7 +567,7 @@ std::variant> CLIParser::parse(int argc, char* options->add_options("generative task (applies to: pull hf model, single model)") ("task", - "Specifies the generative task for the local model. It should be followed by task specific parameters. Supported tasks: text_generation, embeddings, rerank, image_generation, text2speech, speech2text. It creates the pipeline graph in memory based on the provided task-specific options.", + "Specifies the generative task for the local model. If not provided, default task value is inferred from model config.json architectures. It should be followed by task specific parameters. Supported tasks: text_generation, embeddings, rerank, image_generation, text2speech, speech2text. It creates the pipeline graph in memory based on the provided task-specific options.", cxxopts::value(), "TASK"); configOptions->custom_help(""); @@ -335,12 +603,73 @@ std::variant> CLIParser::parse(int argc, char* result = std::make_unique(options->parse(argc, argv)); + const bool isConfigManagementFlow = + result->count("add_to_config") || result->count("remove_from_config") || result->count("list_models"); + if (!result->count("task") && + !result->count("pull") && + !result->count("source_model") && + result->count("model_path") && + !isConfigManagementFlow && + !result->count("help") && + !result->count("version")) { + const std::optional modelPath = std::make_optional(result->operator[]("model_path").as()); + const auto configPath = std::filesystem::path(*modelPath) / MODEL_CONFIG_FILENAME; + const auto indexPath = std::filesystem::path(*modelPath) / MODEL_INDEX_FILENAME; + if (std::filesystem::exists(configPath) || std::filesystem::exists(indexPath)) { + // Check if task-specific parameters are provided or if graph.pbtxt is missing + bool hasUnmatchedOptions = ::ovms::hasTaskSpecificParameters(result->unmatched()); + bool graphExists = ::ovms::graphPbtxtExists(*modelPath); + + // Infer task if: + // 1. Task-specific parameters are provided (unmatched options), OR + // 2. graph.pbtxt doesn't exist (need to create in-memory graph) + // Otherwise, if graph.pbtxt exists and no task parameters, use the filesystem graph + if (hasUnmatchedOptions || !graphExists) { + try { + inferredTaskParameter = determineDefaultTaskParameter(modelPath, std::nullopt, std::nullopt); + } catch (const std::exception& e) { + SPDLOG_DEBUG("Default task inference skipped for model_path '{}': {}", modelPath.value_or(""), e.what()); + } + } + } + } + // HF pull mode or pull and start mode or starting from local folder with graph created in memory if (isHFPullOrPullAndStart(this->result) || isInMemoryGraphMode(this->result)) { std::vector unmatchedOptions; GraphExportType task; - if (result->count("task")) { - task = stringToEnum(result->operator[]("task").as()); + std::string taskValue; + if (!result->count("task") && !result->count("help") && !result->count("version")) { + const std::optional modelPath = result->count("model_path") ? std::make_optional(result->operator[]("model_path").as()) : std::nullopt; + const std::optional sourceModel = result->count("source_model") ? std::make_optional(result->operator[]("source_model").as()) : std::nullopt; + const std::optional modelRepositoryPath = result->count("model_repository_path") ? std::make_optional(result->operator[]("model_repository_path").as()) : std::nullopt; + + // For source_model (HF pull mode), always infer the task + // For model_path in in-memory graph mode, check if task should be inferred based on parameters and graph.pbtxt + bool shouldInferTask = false; + if (sourceModel.has_value() && !sourceModel->empty()) { + // Always infer task when pulling from HuggingFace + shouldInferTask = true; + } else if (modelPath.has_value() && !modelPath->empty()) { + // For local model_path, infer task if: + // 1. Unmatched options (task-specific parameters) are present, OR + // 2. graph.pbtxt doesn't exist (need to create in-memory graph) + bool hasUnmatchedOptions = ::ovms::hasTaskSpecificParameters(result->unmatched()); + bool graphExists = ::ovms::graphPbtxtExists(*modelPath); + shouldInferTask = hasUnmatchedOptions || !graphExists; + } + + if (shouldInferTask) { + try { + inferredTaskParameter = determineDefaultTaskParameter(modelPath, sourceModel, modelRepositoryPath); + } catch (const std::exception& e) { + SPDLOG_DEBUG("Default task inference skipped for source_model '{}': {}", sourceModel.value_or(""), e.what()); + } + } + } + taskValue = getEffectiveTaskParameter(); + task = stringToEnum(taskValue); + if (task != UNKNOWN_GRAPH) { switch (task) { case TEXT_GENERATION_GRAPH: { GraphCLIParser cliParser; @@ -379,12 +708,12 @@ std::variant> CLIParser::parse(int argc, char* break; } case UNKNOWN_GRAPH: { - ss << "error parsing options - --task parameter unsupported value: " + result->operator[]("task").as(); + ss << "error parsing options - --task parameter unsupported value: " + taskValue; return std::make_pair(OVMS_EX_USAGE, ss.str()); } } } else { - ss << "error parsing options - --task parameter wasn't passed"; + ss << "error parsing options - --task parameter unsupported value: " + taskValue; return std::make_pair(OVMS_EX_USAGE, ss.str()); } @@ -671,11 +1000,14 @@ bool CLIParser::isHFPullOrPullAndStart(const std::unique_ptrcount("pull") || result->count("task")); + return (result->count("pull") || result->count("task") || result->count("source_model")); } bool CLIParser::isInMemoryGraphMode(const std::unique_ptr& result) { - return (result->count("task") && !result->count("source_model") && !result->count("pull")); + if (result->count("source_model") || result->count("pull")) { + return false; + } + return result->count("task") || inferredTaskParameter.has_value(); } void CLIParser::prepareGraph(ServerSettingsImpl& serverSettings, HFSettingsImpl& hfSettings, const std::string& modelName) { @@ -731,10 +1063,16 @@ void CLIParser::prepareGraph(ServerSettingsImpl& serverSettings, HFSettingsImpl& // When --task is used with --model_path but without --pull/--source_model, // use model_path as the model location (no HF download needed) if (!result->count("pull") && !result->count("source_model") && result->count("model_path")) { - hfSettings.exportSettings.modelPath = result->operator[]("model_path").as(); + const auto configuredModelPath = std::filesystem::path(result->operator[]("model_path").as()); + hfSettings.exportSettings.modelPath = std::filesystem::absolute(configuredModelPath).lexically_normal().string(); + SPDLOG_DEBUG("Using local absolute model path for graph export: {}", hfSettings.exportSettings.modelPath); + } + const std::string taskValue = getEffectiveTaskParameter(); + if (inferredTaskParameter.has_value()) { + SPDLOG_INFO("Identified default task '{}' from model config", inferredTaskParameter.value()); } - if (result->count("task")) { - hfSettings.task = stringToEnum(result->operator[]("task").as()); + if (!taskValue.empty()) { + hfSettings.task = stringToEnum(taskValue); switch (hfSettings.task) { case TEXT_GENERATION_GRAPH: { if (std::holds_alternative(this->graphOptionsParser)) { @@ -785,7 +1123,7 @@ void CLIParser::prepareGraph(ServerSettingsImpl& serverSettings, HFSettingsImpl& break; } case UNKNOWN_GRAPH: { - throw std::logic_error("Error: --task parameter unsupported value: " + result->operator[]("task").as()); + throw std::logic_error("Error: --task parameter unsupported value: " + taskValue); break; } } diff --git a/src/cli_parser.hpp b/src/cli_parser.hpp index 11731d856c..5f4087882e 100644 --- a/src/cli_parser.hpp +++ b/src/cli_parser.hpp @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include #include @@ -38,13 +39,16 @@ class CLIParser { std::unique_ptr options; std::unique_ptr result; std::variant graphOptionsParser; + std::optional inferredTaskParameter; public: CLIParser() = default; + static std::string determineDefaultTaskParameter(const std::optional& modelPath, const std::optional& sourceModel, const std::optional& modelRepositoryPath); std::variant> parse(int argc, char** argv); void prepare(ServerSettingsImpl*, ModelsSettingsImpl*); protected: + std::string getEffectiveTaskParameter() const; void prepareServer(ServerSettingsImpl& serverSettings); void prepareModel(ModelsSettingsImpl& modelsSettings, HFSettingsImpl& hfSettings); void prepareGraph(ServerSettingsImpl& serverSettings, HFSettingsImpl& hfSettings, const std::string& modelName); diff --git a/src/llm/http_llm_calculator.cc b/src/llm/http_llm_calculator.cc index 85940806a4..9fbfb743b9 100644 --- a/src/llm/http_llm_calculator.cc +++ b/src/llm/http_llm_calculator.cc @@ -68,13 +68,13 @@ class HttpLLMCalculator : public CalculatorBase { absl::Status Close(CalculatorContext* cc) final { OVMS_PROFILE_FUNCTION(); - SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "LLMCalculator [Node: {} ] Close", cc->NodeName()); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "LLMCalculator [Node: {} ] Close", cc->NodeName()); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { OVMS_PROFILE_FUNCTION(); - SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "LLMCalculator [Node: {}] Open start", cc->NodeName()); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "LLMCalculator [Node: {}] Open start", cc->NodeName()); ovms::GenAiServableMap servableMap = cc->InputSidePackets().Tag(LLM_SESSION_SIDE_PACKET_TAG).Get(); auto it = servableMap.find(cc->NodeName()); RET_CHECK(it != servableMap.end()) << "Could not find initialized LLM node named: " << cc->NodeName(); @@ -90,7 +90,7 @@ class HttpLLMCalculator : public CalculatorBase { if (!this->executionContextHolder) { this->executionContext = servable->createExecutionContext(); } - SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "LLMCalculator [Node: {}] Open end", cc->NodeName()); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "LLMCalculator [Node: {}] Open end", cc->NodeName()); return absl::OkStatus(); } absl::Status handleGenerationError(CalculatorContext* cc, const char* errorMessage) { diff --git a/src/pull_module/BUILD b/src/pull_module/BUILD index 1b6c896f94..a7c3198fb2 100644 --- a/src/pull_module/BUILD +++ b/src/pull_module/BUILD @@ -68,6 +68,13 @@ ovms_cc_library( visibility = ["//visibility:public"], ) +ovms_cc_library( + name = "hf_env_vars", + hdrs = ["hf_env_vars.hpp"], + deps = [], + visibility = ["//visibility:public"], +) + ovms_cc_library( name = "gguf_downloader", srcs = ["gguf_downloader.cpp"], @@ -109,6 +116,7 @@ ovms_cc_library( srcs = ["hf_pull_model_module.cpp"], hdrs = ["hf_pull_model_module.hpp"], deps = [ + ":hf_env_vars", ":curl_downloader", ":libgit2", "gguf_downloader", diff --git a/src/pull_module/hf_env_vars.hpp b/src/pull_module/hf_env_vars.hpp new file mode 100644 index 0000000000..77b63f4aae --- /dev/null +++ b/src/pull_module/hf_env_vars.hpp @@ -0,0 +1,22 @@ +//**************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//**************************************************************************** +#pragma once + +namespace ovms { +inline constexpr const char* HF_TOKEN_ENV_VAR = "HF_TOKEN"; +inline constexpr const char* HF_ENDPOINT_ENV_VAR = "HF_ENDPOINT"; +inline constexpr const char* DEFAULT_HF_ENDPOINT = "https://huggingface.co"; +} // namespace ovms diff --git a/src/pull_module/hf_pull_model_module.cpp b/src/pull_module/hf_pull_model_module.cpp index 08c3b97b98..fd7e5fb0c8 100644 --- a/src/pull_module/hf_pull_model_module.cpp +++ b/src/pull_module/hf_pull_model_module.cpp @@ -30,6 +30,7 @@ #include "optimum_export.hpp" #include "curl_downloader.hpp" #include "gguf_downloader.hpp" +#include "hf_env_vars.hpp" #include "../graph_export/graph_export.hpp" #include "../logging.hpp" #include "../module_names.hpp" @@ -293,11 +294,11 @@ const std::string HfPullModelModule::GetProxy() const { } const std::string HfPullModelModule::GetHfToken() const { - return getEnvReturnOrDefaultIfNotSet("HF_TOKEN"); + return getEnvReturnOrDefaultIfNotSet(HF_TOKEN_ENV_VAR); } const std::string HfPullModelModule::GetHfEndpoint() const { - std::string hfEndpoint = getEnvReturnOrDefaultIfNotSet("HF_ENDPOINT", "https://huggingface.co"); + std::string hfEndpoint = getEnvReturnOrDefaultIfNotSet(HF_ENDPOINT_ENV_VAR, DEFAULT_HF_ENDPOINT); if (!endsWith(hfEndpoint, "/")) { hfEndpoint.append("/"); } diff --git a/src/test/models_config_json/Kokoro/config.json b/src/test/models_config_json/Kokoro/config.json new file mode 100644 index 0000000000..3b8fd7490f --- /dev/null +++ b/src/test/models_config_json/Kokoro/config.json @@ -0,0 +1,9 @@ +{ + "return_dict": true, + "output_hidden_states": false, + "dtype": null, + "chunk_size_feed_forward": 0, + "is_encoder_decoder": false, + "architectures": null, + "n_mels": 80 +} diff --git a/src/test/models_config_json/NullArch/config.json b/src/test/models_config_json/NullArch/config.json new file mode 100644 index 0000000000..a06a9cf5fa --- /dev/null +++ b/src/test/models_config_json/NullArch/config.json @@ -0,0 +1,8 @@ +{ + "return_dict": true, + "output_hidden_states": false, + "dtype": null, + "chunk_size_feed_forward": 0, + "is_encoder_decoder": false, + "architectures": null +} diff --git a/src/test/models_config_json/Qwen3-8B/config.json b/src/test/models_config_json/Qwen3-8B/config.json new file mode 100644 index 0000000000..b0eff5ab68 --- /dev/null +++ b/src/test/models_config_json/Qwen3-8B/config.json @@ -0,0 +1,4 @@ +{ + "architectures": ["Qwen3ForCausalLM"], + "model_type": "qwen3" +} diff --git a/src/test/models_config_json/Qwen3-Embedding-0.6B/config.json b/src/test/models_config_json/Qwen3-Embedding-0.6B/config.json new file mode 100644 index 0000000000..b0eff5ab68 --- /dev/null +++ b/src/test/models_config_json/Qwen3-Embedding-0.6B/config.json @@ -0,0 +1,4 @@ +{ + "architectures": ["Qwen3ForCausalLM"], + "model_type": "qwen3" +} diff --git a/src/test/models_config_json/Qwen3-Reranker-0.6B/config.json b/src/test/models_config_json/Qwen3-Reranker-0.6B/config.json new file mode 100644 index 0000000000..b0eff5ab68 --- /dev/null +++ b/src/test/models_config_json/Qwen3-Reranker-0.6B/config.json @@ -0,0 +1,4 @@ +{ + "architectures": ["Qwen3ForCausalLM"], + "model_type": "qwen3" +} diff --git a/src/test/models_config_json/bge/config.json b/src/test/models_config_json/bge/config.json new file mode 100644 index 0000000000..00032627fc --- /dev/null +++ b/src/test/models_config_json/bge/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["BertModel"], + "model_type": "bert", + "hidden_size": 768 +} diff --git a/src/test/models_config_json/bge_reranker/config.json b/src/test/models_config_json/bge_reranker/config.json new file mode 100644 index 0000000000..9b239b2f8b --- /dev/null +++ b/src/test/models_config_json/bge_reranker/config.json @@ -0,0 +1,4 @@ +{ + "architectures": ["XLMRobertaForSequenceClassification"], + "model_type": "xlm-roberta" +} diff --git a/src/test/models_config_json/cross_encoder/config.json b/src/test/models_config_json/cross_encoder/config.json new file mode 100644 index 0000000000..554a1ae230 --- /dev/null +++ b/src/test/models_config_json/cross_encoder/config.json @@ -0,0 +1,6 @@ +{ + "architectures": ["BertForSequenceClassification"], + "model_type": "bert", + "hidden_size": 768, + "num_labels": 2 +} diff --git a/src/test/models_config_json/flux/config.json b/src/test/models_config_json/flux/config.json new file mode 100644 index 0000000000..a97b667de1 --- /dev/null +++ b/src/test/models_config_json/flux/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["FluxTransformer2DModel"], + "model_type": "flux", + "hidden_size": 3072 +} diff --git a/src/test/models_config_json/flux_pipeline/model_index.json b/src/test/models_config_json/flux_pipeline/model_index.json new file mode 100644 index 0000000000..d6a329f361 --- /dev/null +++ b/src/test/models_config_json/flux_pipeline/model_index.json @@ -0,0 +1,5 @@ +{ + "_class_name": "FluxPipeline", + "_diffusers_version": "0.34.0", + "_name_or_path": "black-forest-labs/FLUX.1-dev" +} diff --git a/src/test/models_config_json/gemma4/config.json b/src/test/models_config_json/gemma4/config.json new file mode 100644 index 0000000000..aeb089b8f6 --- /dev/null +++ b/src/test/models_config_json/gemma4/config.json @@ -0,0 +1,4 @@ +{ + "architectures": ["Gemma4ForConditionalGeneration"], + "model_type": "gemma4" +} diff --git a/src/test/models_config_json/invalid_architecture/config.json b/src/test/models_config_json/invalid_architecture/config.json new file mode 100644 index 0000000000..3caa3197ae --- /dev/null +++ b/src/test/models_config_json/invalid_architecture/config.json @@ -0,0 +1,5 @@ +{ + "architectures": [ + "InvalidArchitecture" + ] +} \ No newline at end of file diff --git a/src/test/models_config_json/lfm/config.json b/src/test/models_config_json/lfm/config.json new file mode 100644 index 0000000000..debcb47add --- /dev/null +++ b/src/test/models_config_json/lfm/config.json @@ -0,0 +1,4 @@ +{ + "architectures": ["Lfm2MoeForCausalLM"], + "model_type": "lfm2_moe" +} diff --git a/src/test/models_config_json/llama/config.json b/src/test/models_config_json/llama/config.json new file mode 100644 index 0000000000..7f9ee05e2b --- /dev/null +++ b/src/test/models_config_json/llama/config.json @@ -0,0 +1,6 @@ +{ + "architectures": ["LlamaForCausalLM"], + "model_type": "llama", + "hidden_size": 4096, + "num_hidden_layers": 32 +} diff --git a/src/test/models_config_json/no_architectures/config.json b/src/test/models_config_json/no_architectures/config.json new file mode 100644 index 0000000000..7b634cb35f --- /dev/null +++ b/src/test/models_config_json/no_architectures/config.json @@ -0,0 +1,4 @@ +{ + "model_type": "bert", + "hidden_size": 768 +} diff --git a/src/test/models_config_json/parlertts/config.json b/src/test/models_config_json/parlertts/config.json new file mode 100644 index 0000000000..d54e1378e2 --- /dev/null +++ b/src/test/models_config_json/parlertts/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["ParlerTTSForConditionalGeneration"], + "model_type": "parlertts", + "hidden_size": 768 +} diff --git a/src/test/models_config_json/phi3/config.json b/src/test/models_config_json/phi3/config.json new file mode 100644 index 0000000000..b2c30065d2 --- /dev/null +++ b/src/test/models_config_json/phi3/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["Phi3ForCausalLM"], + "model_type": "phi3", + "hidden_size": 3072 +} diff --git a/src/test/models_config_json/qwen2_embedding/config.json b/src/test/models_config_json/qwen2_embedding/config.json new file mode 100644 index 0000000000..1b9705a034 --- /dev/null +++ b/src/test/models_config_json/qwen2_embedding/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["Qwen2Model"], + "model_type": "qwen2", + "hidden_size": 1024 +} diff --git a/src/test/models_config_json/qwen2_rerank/config.json b/src/test/models_config_json/qwen2_rerank/config.json new file mode 100644 index 0000000000..84f37cb4b1 --- /dev/null +++ b/src/test/models_config_json/qwen2_rerank/config.json @@ -0,0 +1,6 @@ +{ + "architectures": ["Qwen2ForSequenceClassification"], + "model_type": "qwen2", + "hidden_size": 1024, + "num_labels": 2 +} diff --git a/src/test/models_config_json/qwen3/config.json b/src/test/models_config_json/qwen3/config.json new file mode 100644 index 0000000000..690b18578d --- /dev/null +++ b/src/test/models_config_json/qwen3/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["Qwen3ForCausalLM"], + "model_type": "qwen3", + "hidden_size": 1024 +} diff --git a/src/test/models_config_json/qwen3_6/config.json b/src/test/models_config_json/qwen3_6/config.json new file mode 100644 index 0000000000..cc94f1ee62 --- /dev/null +++ b/src/test/models_config_json/qwen3_6/config.json @@ -0,0 +1,5 @@ +{ + "architectures": [ + "Qwen3_5MoeForConditionalGeneration" + ] +} \ No newline at end of file diff --git a/src/test/models_config_json/qwen3_asr/config.json b/src/test/models_config_json/qwen3_asr/config.json new file mode 100644 index 0000000000..303306bb11 --- /dev/null +++ b/src/test/models_config_json/qwen3_asr/config.json @@ -0,0 +1,6 @@ +{ + "architectures": ["Qwen3ASRForConditionalGeneration"], + "model_type": "qwen3_asr", + "return_dict": true, + "output_hidden_states": false +} diff --git a/src/test/models_config_json/qwen3_multi_arch/config.json b/src/test/models_config_json/qwen3_multi_arch/config.json new file mode 100644 index 0000000000..8813114f09 --- /dev/null +++ b/src/test/models_config_json/qwen3_multi_arch/config.json @@ -0,0 +1,4 @@ +{ + "architectures": ["Qwen3ASRForConditionalGeneration", "Qwen3ConditionalGeneration"], + "model_type": "qwen3_multi_arch" +} diff --git a/src/test/models_config_json/sdxl/model_index.json b/src/test/models_config_json/sdxl/model_index.json new file mode 100644 index 0000000000..5c1180912c --- /dev/null +++ b/src/test/models_config_json/sdxl/model_index.json @@ -0,0 +1,5 @@ +{ + "_class_name": "StableDiffusionXLPipeline", + "_diffusers_version": "0.34.0", + "_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0" +} diff --git a/src/test/models_config_json/seamlessm4t/config.json b/src/test/models_config_json/seamlessm4t/config.json new file mode 100644 index 0000000000..efbaa39b8e --- /dev/null +++ b/src/test/models_config_json/seamlessm4t/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["SeamlessM4TModel"], + "model_type": "seamless_m4t", + "hidden_size": 512 +} diff --git a/src/test/models_config_json/speecht5/config.json b/src/test/models_config_json/speecht5/config.json new file mode 100644 index 0000000000..91285cfa08 --- /dev/null +++ b/src/test/models_config_json/speecht5/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["SpeechT5ForTextToSpeech"], + "model_type": "speecht5", + "hidden_size": 256 +} diff --git a/src/test/models_config_json/stable_diffusion/config.json b/src/test/models_config_json/stable_diffusion/config.json new file mode 100644 index 0000000000..d1a305bd12 --- /dev/null +++ b/src/test/models_config_json/stable_diffusion/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["UNet2DConditionModel"], + "model_type": "unet", + "in_channels": 9 +} diff --git a/src/test/models_config_json/t5_encoder/config.json b/src/test/models_config_json/t5_encoder/config.json new file mode 100644 index 0000000000..5fd39e33b6 --- /dev/null +++ b/src/test/models_config_json/t5_encoder/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["T5EncoderModel"], + "model_type": "t5", + "d_model": 512 +} diff --git a/src/test/models_config_json/trinity/config.json b/src/test/models_config_json/trinity/config.json new file mode 100644 index 0000000000..25572300d2 --- /dev/null +++ b/src/test/models_config_json/trinity/config.json @@ -0,0 +1,4 @@ +{ + "architectures": ["AfmoeForCausalLM"], + "model_type": "afmoe" +} diff --git a/src/test/models_config_json/whisper/config.json b/src/test/models_config_json/whisper/config.json new file mode 100644 index 0000000000..c88091d6a5 --- /dev/null +++ b/src/test/models_config_json/whisper/config.json @@ -0,0 +1,5 @@ +{ + "architectures": ["WhisperForConditionalGeneration"], + "model_type": "whisper", + "hidden_size": 512 +} diff --git a/src/test/models_config_json/xlm_roberta/config.json b/src/test/models_config_json/xlm_roberta/config.json new file mode 100644 index 0000000000..a153a9be37 --- /dev/null +++ b/src/test/models_config_json/xlm_roberta/config.json @@ -0,0 +1,6 @@ +{ + "architectures": ["XLMRobertaForSequenceClassification"], + "model_type": "xlm-roberta", + "hidden_size": 768, + "num_labels": 2 +} diff --git a/src/test/ovmsconfig_test.cpp b/src/test/ovmsconfig_test.cpp index a4bcad1224..32ac004e14 100644 --- a/src/test/ovmsconfig_test.cpp +++ b/src/test/ovmsconfig_test.cpp @@ -403,7 +403,7 @@ TEST_F(OvmsConfigDeathTest, hfNoTaskParameter) { "/some/path", }; int arg_count = 6; - EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "error parsing options - --task parameter wasn't passed"); + EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "--task parameter wasn't passed"); } TEST_F(OvmsConfigDeathTest, hfBadTextGraphParameter) { @@ -881,7 +881,35 @@ TEST_F(OvmsConfigDeathTest, hfSourceModelWithoutTask) { "/some/path", }; int arg_count = 5; - EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "--source_model should be used combined with --task"); + EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "--task parameter wasn't passed"); +} + +TEST_F(OvmsConfigDeathTest, hfSourceModelWithoutTaskInvalidArchitectureLocal) { + auto currentPath = std::filesystem::current_path(); + auto repoPath = std::filesystem::weakly_canonical(currentPath / ".." / ".." / "src/test/models_config_json").string(); + char* n_argv[] = { + "ovms", + "--source_model", + "invalid_architecture", + "--model_repository_path", + (char*)repoPath.c_str(), + }; + int arg_count = 5; + EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "--task parameter wasn't passed"); +} + +TEST_F(OvmsConfigDeathTest, hfSourceModelWithoutTaskNoArchitecturesLocal) { + auto currentPath = std::filesystem::current_path(); + auto repoPath = std::filesystem::weakly_canonical(currentPath / ".." / ".." / "src/test/models_config_json").string(); + char* n_argv[] = { + "ovms", + "--source_model", + "no_architectures", + "--model_repository_path", + (char*)repoPath.c_str(), + }; + int arg_count = 5; + EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "--task parameter wasn't passed"); } TEST_F(OvmsConfigDeathTest, hfPullNoRepositoryPath) { @@ -2823,6 +2851,222 @@ TEST(OvmsConfigManipulationTest, positiveDisableModel) { ASSERT_EQ(modelSettings.configPath, configPath); } +// ===================== Inferred Default Task Tests ===================== +// Fixture providing resolved paths to pre-built test model directories. +class OvmsInferredTaskTest : public ::testing::Test { +public: + static std::string resolveTestModelPath(const std::string& modelDirName) { + const std::string relPath = std::string("src/test/models_config_json/") + modelDirName; + auto current = std::filesystem::current_path(); + auto candidate = std::filesystem::weakly_canonical(current / ".." / ".." / relPath); + if (std::filesystem::exists(candidate)) + return candidate.string(); + auto search = current; + while (search != search.parent_path()) { + auto c = search / relPath; + if (std::filesystem::exists(c)) + return c.string(); + search = search.parent_path(); + } + return candidate.string(); + } + + static std::string resolveTestModelsRepoPath() { + const std::string relPath = "src/test/models_config_json"; + auto current = std::filesystem::current_path(); + auto candidate = std::filesystem::weakly_canonical(current / ".." / ".." / relPath); + if (std::filesystem::exists(candidate)) + return candidate.string(); + auto search = current; + while (search != search.parent_path()) { + auto c = search / relPath; + if (std::filesystem::exists(c)) + return c.string(); + search = search.parent_path(); + } + return candidate.string(); + } +}; + +// Scenario 1: --source_model with no explicit --task and a locally available repo copy. +// The task must be detected from the model config.json and the server must start. +TEST_F(OvmsInferredTaskTest, positiveSourceModelInferTaskFromLocalRepo) { + const std::string repoPath = resolveTestModelsRepoPath(); + const std::filesystem::path llamaConfig = std::filesystem::path(repoPath) / "llama" / "config.json"; + if (!std::filesystem::exists(llamaConfig)) { + FAIL() << "Test prerequisite missing: " << llamaConfig.string(); + } + const std::string sourceModel = "llama"; + char* n_argv[] = { + (char*)"ovms", + (char*)"--source_model", + (char*)sourceModel.c_str(), + (char*)"--model_repository_path", + (char*)repoPath.c_str(), + (char*)"--rest_port", + (char*)"8080", + }; + int arg_count = 7; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + ASSERT_EQ(config.getServerSettings().hfSettings.task, ovms::TEXT_GENERATION_GRAPH); + ASSERT_EQ(config.getServerSettings().serverMode, ovms::HF_PULL_AND_START_MODE); +} + +// Scenario 2: --model_path points to a directory that has config.json but NO graph.pbtxt. +// The task must be inferred and the server must start in IN_MEMORY_GRAPH_MODE. +TEST_F(OvmsInferredTaskTest, positiveModelPathNoGraphPbtxtInferTask) { + const std::string modelPath = resolveTestModelPath("llama"); + const std::filesystem::path configJson = std::filesystem::path(modelPath) / "config.json"; + if (!std::filesystem::exists(configJson)) { + FAIL() << "Test prerequisite missing: " << configJson.string(); + } + // Verify the test fixture truly has no graph.pbtxt so the scenario is meaningful. + ASSERT_FALSE(std::filesystem::exists(std::filesystem::path(modelPath) / "graph.pbtxt")) + << "Unexpected graph.pbtxt in test model dir " << modelPath; + char* n_argv[] = { + (char*)"ovms", + (char*)"--model_path", + (char*)modelPath.c_str(), + (char*)"--model_name", + (char*)"llama", + (char*)"--rest_port", + (char*)"8080", + }; + int arg_count = 7; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + ASSERT_EQ(config.getServerSettings().hfSettings.task, ovms::TEXT_GENERATION_GRAPH); + ASSERT_EQ(config.getServerSettings().serverMode, ovms::IN_MEMORY_GRAPH_MODE); +} + +// Scenario 3: Questionable architecture requires additional model naming rules. +// source_model with "embed" should infer embeddings for Qwen3ForCausalLM. +TEST_F(OvmsInferredTaskTest, positiveSourceModelInferEmbeddingsForQuestionableArchitecture) { + const std::string repoPath = resolveTestModelsRepoPath(); + const std::string sourceModel = "Qwen3-Embedding-0.6B"; + const std::filesystem::path configJson = std::filesystem::path(repoPath) / sourceModel / "config.json"; + if (!std::filesystem::exists(configJson)) { + FAIL() << "Test prerequisite missing: " << configJson.string(); + } + char* n_argv[] = { + (char*)"ovms", + (char*)"--source_model", + (char*)sourceModel.c_str(), + (char*)"--model_repository_path", + (char*)repoPath.c_str(), + (char*)"--rest_port", + (char*)"8080", + }; + int arg_count = 7; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + ASSERT_EQ(config.getServerSettings().hfSettings.task, ovms::EMBEDDINGS_GRAPH); + ASSERT_EQ(config.getServerSettings().serverMode, ovms::HF_PULL_AND_START_MODE); +} + +// Scenario 4: model_path with "rerank" in path should infer rerank for Qwen3ForCausalLM. +TEST_F(OvmsInferredTaskTest, positiveModelPathInferRerankForQuestionableArchitecture) { + const std::string modelPath = resolveTestModelPath("Qwen3-Reranker-0.6B"); + const std::filesystem::path configJson = std::filesystem::path(modelPath) / "config.json"; + if (!std::filesystem::exists(configJson)) { + FAIL() << "Test prerequisite missing: " << configJson.string(); + } + ASSERT_FALSE(std::filesystem::exists(std::filesystem::path(modelPath) / "graph.pbtxt")) + << "Unexpected graph.pbtxt in test model dir " << modelPath; + char* n_argv[] = { + (char*)"ovms", + (char*)"--model_path", + (char*)modelPath.c_str(), + (char*)"--model_name", + (char*)"qwen3-reranker", + (char*)"--rest_port", + (char*)"8080", + }; + int arg_count = 7; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + ASSERT_EQ(config.getServerSettings().hfSettings.task, ovms::RERANK_GRAPH); + ASSERT_EQ(config.getServerSettings().serverMode, ovms::IN_MEMORY_GRAPH_MODE); +} + +// Scenario 5: Questionable architecture without identifying keywords must not infer task. +TEST_F(OvmsConfigDeathTest, negativeSourceModelQuestionableArchitectureWithoutPattern) { + auto currentPath = std::filesystem::current_path(); + auto repoPath = std::filesystem::weakly_canonical(currentPath / ".." / ".." / "src/test/models_config_json").string(); + const std::string sourceModel = "Qwen3-8B"; + char* n_argv[] = { + "ovms", + "--source_model", + (char*)sourceModel.c_str(), + "--model_repository_path", + (char*)repoPath.c_str(), + }; + int arg_count = 5; + EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "--task parameter wasn't passed"); +} + +// Scenario 6: Null architectures with n_mels field should infer text2speech task. +TEST_F(OvmsInferredTaskTest, positiveSourceModelInferText2SpeechForNullArchitecturesWithNMels) { + const std::string repoPath = resolveTestModelsRepoPath(); + const std::string sourceModel = "Kokoro"; + const std::filesystem::path configJson = std::filesystem::path(repoPath) / sourceModel / "config.json"; + if (!std::filesystem::exists(configJson)) { + FAIL() << "Test prerequisite missing: " << configJson.string(); + } + char* n_argv[] = { + (char*)"ovms", + (char*)"--source_model", + (char*)sourceModel.c_str(), + (char*)"--model_repository_path", + (char*)repoPath.c_str(), + (char*)"--rest_port", + (char*)"8080", + }; + int arg_count = 7; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + ASSERT_EQ(config.getServerSettings().hfSettings.task, ovms::TEXT_TO_SPEECH_GRAPH); + ASSERT_EQ(config.getServerSettings().serverMode, ovms::HF_PULL_AND_START_MODE); +} + +// Scenario 6b: Null architectures without special fields should fail. +TEST_F(OvmsConfigDeathTest, negativeSourceModelNullArchitecturesWithoutSpecialFields) { + auto currentPath = std::filesystem::current_path(); + auto repoPath = std::filesystem::weakly_canonical(currentPath / ".." / ".." / "src/test/models_config_json").string(); + const std::string sourceModel = "NullArch"; + char* n_argv[] = { + (char*)"ovms", + (char*)"--source_model", + (char*)sourceModel.c_str(), + (char*)"--model_repository_path", + (char*)repoPath.c_str(), + }; + int arg_count = 5; + EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "--task parameter wasn't passed"); +} + +// Scenario 7: --model_path (LLM/text_generation model) with no explicit --task but with +// an embeddings-specific parameter (--pooling). Task is inferred as text_generation and +// --pooling is not a recognised text_generation option, so parsing must fail. +TEST_F(OvmsConfigDeathTest, negativeModelPathInferredTaskWithMismatchedParam) { + const std::string modelPath = OvmsInferredTaskTest::resolveTestModelPath("llama"); + if (!std::filesystem::exists(std::filesystem::path(modelPath) / "config.json")) { + FAIL() << "Test prerequisite missing: " << modelPath << "/config.json"; + } + char* n_argv[] = { + (char*)"ovms", + (char*)"--model_path", (char*)modelPath.c_str(), + (char*)"--model_name", (char*)"llama", + (char*)"--rest_port", (char*)"8080", + (char*)"--pooling", (char*)"LAST", // embeddings-only param, invalid for text_generation + }; + int arg_count = 9; + EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), + ::testing::ExitedWithCode(OVMS_EX_USAGE), + "task: text_generation - error parsing options - unmatched arguments"); +} + TEST(OvmsGraphCliParserTest, invalidToolParserNameThrowsInvalidArgument) { ovms::HFSettingsImpl hfSettings; ovms::GraphCLIParser parser; diff --git a/src/test/task_determine_test.cpp b/src/test/task_determine_test.cpp new file mode 100644 index 0000000000..e69837bd7c --- /dev/null +++ b/src/test/task_determine_test.cpp @@ -0,0 +1,116 @@ +//***************************************************************************** +// Copyright 2020-2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include +#include +#include +#include +#include + +#include + +#include "../cli_parser.hpp" + +class CLIParserDetermineTaskTest : public ::testing::TestWithParam> { +public: + static const std::vector> modelTaskMapping; + + static std::string getModelPath(const std::string& modelName) { + return "src/test/models_config_json/" + modelName; + } +}; + +const std::vector> CLIParserDetermineTaskTest::modelTaskMapping = { + {"llama", "text_generation"}, + {"qwen3", "text_generation"}, + {"phi3", "text_generation"}, + {"bge", "embeddings"}, + {"t5_encoder", "embeddings"}, + {"qwen2_embedding", "embeddings"}, + {"cross_encoder", "rerank"}, + {"xlm_roberta", "rerank"}, + {"qwen2_rerank", "rerank"}, + {"stable_diffusion", "image_generation"}, + {"flux", "image_generation"}, + {"speecht5", "text2speech"}, + {"parlertts", "text2speech"}, + {"whisper", "speech2text"}, + {"seamlessm4t", "speech2text"}, + {"qwen3_6", "text_generation"}, + {"qwen3_asr", "speech2text"}, + {"lfm", "text_generation"}, + {"trinity", "text_generation"}, + {"gemma4", "text_generation"}, + {"bge_reranker", "rerank"}, + {"sdxl", "image_generation"}, + {"flux_pipeline", "image_generation"}, + {"qwen3_multi_arch", "speech2text"}, + {"invalid_architecture", ""} // This model has an unsupported architecture and should throw an exception +}; + +INSTANTIATE_TEST_SUITE_P( + DetermineTaskFromConfigStream, + CLIParserDetermineTaskTest, + ::testing::ValuesIn(CLIParserDetermineTaskTest::modelTaskMapping), + [](const ::testing::TestParamInfo>& info) { + return info.param.first; + }); + +TEST_P(CLIParserDetermineTaskTest, DetermineTaskFromConfigStream) { + auto [modelName, expectedTask] = GetParam(); + std::string modelDirName = getModelPath(modelName); + + // Test executable is typically two levels below repository root in Bazel outputs. + auto currentPath = std::filesystem::current_path(); + std::filesystem::path modelPath = std::filesystem::weakly_canonical(currentPath / ".." / ".." / modelDirName); + + if (!std::filesystem::exists(modelPath)) { + auto searchPath = currentPath; + while (searchPath != searchPath.parent_path()) { + std::filesystem::path candidate = searchPath / modelDirName; + if (std::filesystem::exists(candidate)) { + modelPath = candidate; + break; + } + searchPath = searchPath.parent_path(); + } + } + + ASSERT_TRUE(std::filesystem::exists(modelPath)) + << "Model directory not found: " << modelDirName + << " (tried: " << (currentPath / ".." / ".." / modelDirName).string() + << ", current_path: " << currentPath.string() << ")"; + ASSERT_TRUE(std::filesystem::exists(modelPath / "config.json") || std::filesystem::exists(modelPath / "model_index.json")) + << "Neither config.json nor model_index.json found in: " << modelPath.string(); + + if (expectedTask.empty()) { + EXPECT_THROW( + ovms::CLIParser::determineDefaultTaskParameter( + std::make_optional(modelPath.string()), + std::nullopt, + std::nullopt), + std::logic_error); + return; + } + + std::string result = ovms::CLIParser::determineDefaultTaskParameter( + std::make_optional(modelPath.string()), + std::nullopt, + std::nullopt); + + EXPECT_EQ(result, expectedTask) + << "Model: " << modelName << ", Expected: " << expectedTask << ", Got: " << result; +}