diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt
index dfafa9cf8116e..e7ba23587f2ad 100644
--- a/tools/mtmd/CMakeLists.txt
+++ b/tools/mtmd/CMakeLists.txt
@@ -1,29 +1,3 @@
-# llava (legacy)
-
-add_library(llava OBJECT
-            llava.cpp
-            llava.h
-            clip.cpp
-            clip.h
-            )
-
-target_link_libraries(llava PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
-
-target_include_directories(llava PUBLIC .)
-target_include_directories(llava PUBLIC ../..)
-target_include_directories(llava PUBLIC ../../common)
-
-target_compile_features(llava PRIVATE cxx_std_17)
-
-add_library(llava_static STATIC $<TARGET_OBJECTS:llava>)
-if (BUILD_SHARED_LIBS)
-    set_target_properties(llava PROPERTIES POSITION_INDEPENDENT_CODE ON)
-    target_compile_definitions(llava PRIVATE LLAMA_SHARED LLAMA_BUILD)
-    add_library(llava_shared SHARED $<TARGET_OBJECTS:llava>)
-    target_link_libraries(llava_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
-    install(TARGETS llava_shared LIBRARY)
-endif()
-
 # mtmd
 
 add_library(mtmd OBJECT
@@ -53,12 +27,10 @@ if (BUILD_SHARED_LIBS)
 endif()
 
 if (NOT MSVC)
-    target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
     target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
 endif()
 
 if(TARGET BUILD_INFO)
-    add_dependencies(llava BUILD_INFO)
     add_dependencies(mtmd BUILD_INFO)
 endif()
 
@@ -73,10 +45,3 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
-
-set(TARGET llama-llava-clip-quantize-cli)
-add_executable(${TARGET} clip-quantize-cli.cpp)
-set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)
-install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/mtmd/README-quantize.md b/tools/mtmd/README-quantize.md
deleted file mode 100644
index b931513ab7203..0000000000000
--- a/tools/mtmd/README-quantize.md
+++ /dev/null
@@ -1,44 +0,0 @@
-# Quantizing CLIP Visual Projector
-
-This is the tool for quantizing the CLIP visual projector model. Quantization reduces the precision of the model's weights, which can significantly decrease the model size and improve inference speed, often with minimal impact on performance.
-
-## Usage
-
-To quantize a CLIP visual projector model, use the following command:
-
-```sh
-./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf <type>
-```
-
-After the quantization, the visual projector can be used freely with the existing LLAVA cli (LLAVA, Qwen2VL, etc).
-
-### Arguments
-
-- `/path/to/ggml-model-f32.gguf`: The path to the input model file in FP32 or FP16 format.
-- `/path/to/ggml-model-quantized.gguf`: The path where the quantized model will be saved.
-- `<type>`: The quantization type to apply. This should be an integer corresponding to one of the quantization types defined in the `enum ggml_type`.
-
-### Quantization Types
-
-The following quantization types are supported, based on the `enum ggml_type` definition:
-
-- `2` - `q4_0`: 4-bit quantization with a single scale value.
-- `3` - `q4_1`: 4-bit quantization with a separate scale value for each block.
-- `6` - `q5_0`: 5-bit quantization with a single scale value.
-- `7` - `q5_1`: 5-bit quantization with a separate scale value for each block.
-- `8` - `q8_0`: 8-bit quantization with a single scale value.
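For instance, selecting `q8_0` means passing `8` as the third argument; a hypothetical invocation of the (now removed) tool, with placeholder paths, would be:

```sh
# Quantize an FP32/FP16 CLIP projector to q8_0 (type 8); paths are placeholders.
./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-q8_0.gguf 8
```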
-
-### Example
-
-To quantize a model using the `q4_0` quantization type, you would run:
-
-```sh
-./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf 2
-```
-
-This command will generate a quantized model at `/path/to/ggml-model-quantized.gguf` using the `q4_0` quantization method.
-
-## Notes
-
-- Quantization can lead to a loss in model accuracy, depending on the chosen quantization type. It is recommended to evaluate the quantized model's performance on your specific task to ensure it meets your requirements.
-- The quantized model will typically be smaller in size and faster to run, making it more suitable for deployment in resource-constrained environments.
diff --git a/tools/mtmd/README.md b/tools/mtmd/README.md
index ab258ea174e84..ef31d1957cdab 100644
--- a/tools/mtmd/README.md
+++ b/tools/mtmd/README.md
@@ -41,8 +41,8 @@ Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advanta
 Multimodal projector (`mmproj`) files are specific to each model architecture.
 
-For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
-- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
+For the following models, you can use `convert_hf_to_gguf.py` with the `--mmproj` flag to get the `mmproj` file:
+- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - see the guide [here](../../docs/multimodal/gemma3.md). Note: the 1B variant does not have vision support
 - SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
 - SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
 - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
@@ -52,6 +52,8 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla
 
 For older models, please refer to the relevant guide for instructions on how to obtain or create them:
 
+NOTE: conversion scripts are located under `tools/mtmd/legacy-models`
+
 - [LLaVA](../../docs/multimodal/llava.md)
 - [MobileVLM](../../docs/multimodal/MobileVLM.md)
 - [GLM-Edge](../../docs/multimodal/glmedge.md)
@@ -59,4 +61,3 @@ For older models, please refer to the relevant guide for instructions on how to
 - [MiniCPM-V 2.6](../../docs/multimodal/minicpmv2.6.md)
 - [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
 - [IBM Granite Vision](../../docs/multimodal/granitevision.md)
-- [Google Gemma 3](../../docs/multimodal/gemma3.md)
diff --git a/tools/mtmd/android/adb_run.sh b/tools/mtmd/android/adb_run.sh
deleted file mode 100755
index a24d6787d9a05..0000000000000
--- a/tools/mtmd/android/adb_run.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/bin/bash
-
-model_dir="/Users/cxt/model/llm/mobileVLM/MobileVLM-1.7B_processed"
-projector_name="mmproj-model-f16.gguf"
-llama_name="ggml-model-q4_k.gguf"
-img_dir="/Users/cxt/model/llm"
-img_name="demo.jpg"
-prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:"
-# img_name="cat.jpeg"
-# prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWhat is in the image? 
ASSISTANT:" - -program_dir="build_64/bin" -binName="llama-mtmd-cli" -n_threads=4 - - -deviceDir="/data/local/tmp" -saveDir="output" -if [ ! -d ${saveDir} ]; then - mkdir ${saveDir} -fi - - -function android_run() { - # # copy resource into device - # adb push ${model_dir}/${projector_name} ${deviceDir}/${projector_name} - # adb push ${model_dir}/${llama_name} ${deviceDir}/${llama_name} - adb push ${img_dir}/${img_name} ${deviceDir}/${img_name} - # copy program into device - adb push ${program_dir}/${binName} ${deviceDir}/${binName} - adb shell "chmod 0777 ${deviceDir}/${binName}" - - # run - adb shell "echo cd ${deviceDir} ${deviceDir}/${binName} \ - -m ${deviceDir}/${llama_name} \ - --mmproj ${deviceDir}/${projector_name} \ - -t ${n_threads} \ - --image ${deviceDir}/${img_name} \ - -p \"${prompt}\" \ - > ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt" - adb shell "cd ${deviceDir}; pwd; ${deviceDir}/${binName} \ - -m ${deviceDir}/${llama_name} \ - --mmproj ${deviceDir}/${projector_name} \ - -t ${n_threads} \ - --image ${deviceDir}/${img_name} \ - -p \"${prompt}\" \ - >> ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt 2>&1" - adb pull ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt ${saveDir} -} - -android_run - -echo "android_run is Done!" diff --git a/tools/mtmd/android/build_64.sh b/tools/mtmd/android/build_64.sh deleted file mode 100755 index 71b6fd3f719cd..0000000000000 --- a/tools/mtmd/android/build_64.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -cmake ../../../../ \ --DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ --DCMAKE_BUILD_TYPE=Release \ --DANDROID_ABI="arm64-v8a" \ --DANDROID_PLATFORM=android-23 $1 - -make -j4 diff --git a/tools/mtmd/clip-quantize-cli.cpp b/tools/mtmd/clip-quantize-cli.cpp deleted file mode 100644 index 566506954edeb..0000000000000 --- a/tools/mtmd/clip-quantize-cli.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include "arg.h" -#include "base64.hpp" -#include "log.h" -#include "common.h" -#include "sampling.h" -#include "clip.h" -#include "llava.h" -#include "llama.h" -#include "ggml.h" - -static void print_usage(int argc, char ** argv) { - (void) argc; - - fprintf(stderr, "usage: %s /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf type\n", argv[0]); - fprintf(stderr, " type = 2 - q4_0\n"); - fprintf(stderr, " type = 3 - q4_1\n"); - fprintf(stderr, " type = 6 - q5_0\n"); - fprintf(stderr, " type = 7 - q5_1\n"); - fprintf(stderr, " type = 8 - q8_0\n"); -} - -int main(int argc, char ** argv) { - if (argc != 4) { - print_usage(argc, argv); - return 1; - } - - const std::string fname_inp = argv[1]; - const std::string fname_out = argv[2]; - - const int itype = atoi(argv[3]); - - const int64_t t_main_start_us = ggml_time_us(); - - int64_t t_quantize_us = 0; - - // load the model - { - const int64_t t_start_us = ggml_time_us(); - - if (!clip_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype)) { - fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); - return 1; - } - - t_quantize_us = ggml_time_us() - t_start_us; - } - - // report timing - { - const int64_t t_main_end_us = ggml_time_us(); - - printf("\n"); - printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us / 1000.0f); - printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); - } - - return 0; -} diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3f11c301a7212..c779c3311635a 100644 --- 
a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -3576,141 +3576,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima return true; } -bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) { - assert(itype < GGML_TYPE_COUNT); - ggml_type type = static_cast(itype); - - auto * ctx_clip = clip_init(fname_inp, clip_context_params{ - /* use_gpu */ false, - /* verbosity */ GGML_LOG_LEVEL_ERROR, - }); - - const auto & ctx_src = ctx_clip->ctx_gguf.get(); - const auto & ctx_data = ctx_clip->ctx_data.get(); - - auto * ctx_out = gguf_init_empty(); - gguf_set_kv(ctx_out, ctx_src); - gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); - gguf_set_val_u32(ctx_out, "general.file_type", itype); - - auto fout = std::ofstream(fname_out, std::ios::binary); - - const int n_tensors = gguf_get_n_tensors(ctx_src); - - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx_src, i); - ggml_tensor * cur = ggml_get_tensor(ctx_data, name); - gguf_add_tensor(ctx_out, cur); - } - - const size_t meta_size = gguf_get_meta_size(ctx_out); - for (size_t i = 0; i < meta_size; ++i) { - fout.put(0); - } - - // regexes of tensor names to be quantized - const std::vector k_names = { - ".*weight", - }; - - std::vector work(512); - std::vector conv_buf(512); - size_t total_size_org = 0; - size_t total_size_new = 0; - - for (int i = 0; i < n_tensors; ++i) { - const std::string name = gguf_get_tensor_name(ctx_src, i); - ggml_tensor * cur = ggml_get_tensor(ctx_data, name.c_str()); - - enum ggml_type new_type; - void * new_data; - size_t new_size; - - bool quantize = false; - for (const auto & s : k_names) { - if (std::regex_match(name, std::regex(s))) { - quantize = true; - break; - } - } - - // quantize only 2D tensors and bigger than block size - quantize &= (ggml_n_dims(cur) == 2) && cur->ne[0] > ggml_blck_size(type); - - if (quantize) { - new_type = type; - if (new_type >= GGML_TYPE_Q2_K && name.find("embd") != std::string::npos) { - new_type = GGML_TYPE_Q8_0; // ggml_get_rows needs non K type - // LOG_ERR("%s: quantizing %s to %s\n", __func__, name.c_str(), ggml_type_name(new_type)); - } - const size_t n_elms = ggml_nelements(cur); - float * f32_data; - - switch (cur->type) { - case GGML_TYPE_F32: - f32_data = (float *)cur->data; - break; - case GGML_TYPE_F16: - if (conv_buf.size() < n_elms) { - conv_buf.resize(n_elms); - } - for (size_t j = 0; j < n_elms; ++j) { - conv_buf[j] = ggml_fp16_to_fp32(((ggml_fp16_t *)cur->data)[j]); - } - f32_data = (float *)conv_buf.data(); - break; - default: - LOG_ERR("%s: Please use an input file in f32 or f16\n", __func__); - gguf_free(ctx_out); - return false; - } - - if (work.size() < n_elms * 4) { - work.resize(n_elms * 4); - } - new_data = work.data(); - - new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr); - } else { - new_type = cur->type; - new_data = cur->data; - new_size = ggml_nbytes(cur); - } - const size_t orig_size = ggml_nbytes(cur); - total_size_org += orig_size; - total_size_new += new_size; - gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size); - gguf_set_tensor_data(ctx_out, name.c_str(), new_data); - fout.write((const char *)new_data, new_size); - size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size; - for (size_t j = 0; j < pad; ++j) { - fout.put(0); - } - - LOG_INF("%s: n_dims = %d | 
quantize=%d | size = %f MB -> %f MB\n", name.c_str(), ggml_n_dims(cur), quantize, - orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); - } - - // go back to beginning of file and write the updated metadata - fout.seekp(0, std::ios::beg); - std::vector meta(meta_size); - gguf_get_meta_data(ctx_out, meta.data()); - fout.write((const char *)meta.data(), meta_size); - - fout.close(); - - clip_free(ctx_clip); - gguf_free(ctx_out); - - { - LOG_INF("%s: original size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0); - LOG_INF("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0); - } - - return true; -} - int clip_n_mmproj_embd(const struct clip_ctx * ctx) { switch (ctx->proj_type) { case PROJECTOR_TYPE_LDP: diff --git a/tools/mtmd/convert_image_encoder_to_gguf.py b/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py similarity index 100% rename from tools/mtmd/convert_image_encoder_to_gguf.py rename to tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py diff --git a/tools/mtmd/glmedge-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py similarity index 100% rename from tools/mtmd/glmedge-convert-image-encoder-to-gguf.py rename to tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py diff --git a/tools/mtmd/glmedge-surgery.py b/tools/mtmd/legacy-models/glmedge-surgery.py similarity index 100% rename from tools/mtmd/glmedge-surgery.py rename to tools/mtmd/legacy-models/glmedge-surgery.py diff --git a/tools/mtmd/llava_surgery.py b/tools/mtmd/legacy-models/llava_surgery.py similarity index 100% rename from tools/mtmd/llava_surgery.py rename to tools/mtmd/legacy-models/llava_surgery.py diff --git a/tools/mtmd/llava_surgery_v2.py b/tools/mtmd/legacy-models/llava_surgery_v2.py similarity index 100% rename from tools/mtmd/llava_surgery_v2.py rename to tools/mtmd/legacy-models/llava_surgery_v2.py diff --git a/tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py similarity index 100% rename from tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py rename to tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py diff --git a/tools/mtmd/minicpmv-surgery.py b/tools/mtmd/legacy-models/minicpmv-surgery.py similarity index 100% rename from tools/mtmd/minicpmv-surgery.py rename to tools/mtmd/legacy-models/minicpmv-surgery.py diff --git a/tools/mtmd/llava.cpp b/tools/mtmd/llava.cpp deleted file mode 100644 index ebef8b3c1eab6..0000000000000 --- a/tools/mtmd/llava.cpp +++ /dev/null @@ -1,591 +0,0 @@ -#include "clip.h" -#include "llava.h" - -#include "llama.h" -#include "ggml-cpp.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(LLAVA_LOG_OFF) -# define LOG_INF(...) -# define LOG_WRN(...) -# define LOG_ERR(...) -# define LOG_DBG(...) -#else // defined(LLAVA_LOG_OFF) -# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -#endif // defined(LLAVA_LOG_OFF) - -// RGB uint8 image -struct clip_image_u8 { - int nx; - int ny; - - std::vector buf; -}; - -// RGB float32 image (NHWC) -// Memory layout: RGBRGBRGB... 
-struct clip_image_f32 { - int nx; - int ny; - - std::vector buf; -}; - -struct clip_image_grid_shape { - int first; - int second; -}; - -// convenience cpp wrapper -struct clip_image_f32_batch_deleter { - void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } -}; -typedef std::unique_ptr clip_image_f32_batch_ptr; - -struct clip_image_size_deleter { - void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } -}; -typedef std::unique_ptr clip_image_size_ptr; - -/** - * Selects the best resolution from a list of possible resolutions based on the original size. - * - * @param original_size The original size of the image in the format (width, height). - * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. - * @return The best fit resolution in the format (width, height). - */ -static std::pair select_best_resolution(const std::pair& original_size, const std::vector>& possible_resolutions) { - int original_width = original_size.first; - int original_height = original_size.second; - - std::pair best_fit; - int max_effective_resolution = 0; - int min_wasted_resolution = std::numeric_limits::max(); - - for (const auto& resolution : possible_resolutions) { - int width = resolution.first; - int height = resolution.second; - float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); - int downscaled_width = static_cast(original_width * scale); - int downscaled_height = static_cast(original_height * scale); - int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); - int wasted_resolution = (width * height) - effective_resolution; - // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); - if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { - max_effective_resolution = effective_resolution; - min_wasted_resolution = wasted_resolution; - best_fit = resolution; - } - } - - return best_fit; -} - -/** - * @brief Get the anyres image grid shape object - * - * @param image_size - * @param grid_pinpoints - * @param image_patch_size - * @return - */ -static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair & image_size, const std::vector> & grid_pinpoints, int image_patch_size) { - /** - Conversion from gguf flat array to vector: - std::vector> possible_resolutions; - for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { - possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); - } - */ - auto best_resolution = select_best_resolution(image_size, grid_pinpoints); - return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size}; -} - -// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out) -static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out, clip_image_f32 * img_input) { - struct { - struct ggml_context * ctx; - } model; - - const int32_t image_size = clip_get_image_size(ctx_clip); - const int32_t patch_size = clip_get_patch_size(ctx_clip); - - int32_t 
num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) - - int num_patches_width = grid_shape.first; // grid 1-4 - int num_patches_height = grid_shape.second; // grid 1-4 - - const size_t num_images = num_patches_width * num_patches_height + 1; - - // TODO: size calculation is not calculated - it's only tens of MB - size_t ctx_size = 0; - - { - ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features - ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32); - } - - struct ggml_init_params params { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ false, // NOTE: this should be false when using the legacy API - }; - - // Python reference code for full unpad: - /* - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat(( - image_feature, - self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1) - ), dim=-1) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - */ - // We now have two options: unpad or no unpad. Unpad removes tokens for faster llm eval. - // In terms of result quality it appears to make no difference, so we'll start with the easier approach given 5D tensors are not supported in ggml yet. - // Without unpad we have to split the sub-image embeddings into patches of 24 features each and permute them. - // Once all images are processed to prepended the base_image_features without any changes. 
- - // Pytorch reference simplified, modified for ggml compatibility - confirmed identical output in python (for a 2x2 grid image (676x676 scaling)) - /* - image_feature = image_feature.view(2, 2, 24, 24, 4096) - image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() - image_feature = image_feature.view(2, 24, 2, 24, 4096) - image_feature = image_feature.flatten(0, 3) - - // Reshape to 4D tensor by merging the last two dimensions - image_feature = image_feature.view(2, 2, 24, 24*4096) - image_feature = image_feature.permute(0, 2, 1, 3).contiguous() - image_feature = image_feature.view(-1, 4096) - */ - - model.ctx = ggml_init(params); - - struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_output_tokens(ctx_clip, img_input), num_images - 1); // example: 4096 x 576 x 4 - // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false); - // fill it with the image embeddings, ignoring the base - for (size_t i = 1; i < num_images; i++) { - size_t offset = (i-1) * clip_embd_nbytes(ctx_clip); - memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip)); - } - - struct ggml_cgraph * gf = ggml_new_graph(model.ctx); - size_t size_ele = ggml_type_size(GGML_TYPE_F32); - - struct ggml_tensor *image_features_patchview = ggml_view_4d(model.ctx, image_features, - num_patches_per_side * clip_n_mmproj_embd(ctx_clip), - num_patches_per_side, - num_patches_width, - num_patches_height, - size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip), - size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side, - size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side * num_patches_width, 0); - // ggml_tensor_printf(image_features_patchview,"image_features_patchview",__LINE__,false,false); - struct ggml_tensor *permuted_cont = ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3)); - /** - At the end of each row we have to add the row_end embeddings, which are the same as the newline embeddings - image_feature = torch.cat(( - image_feature, - self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) - ), dim=-1) - * - */ - - // ggml_tensor_printf(permuted_cont,"permuted_cont",__LINE__,false,false); - struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0); - // ggml_tensor_printf(flatten,"flatten",__LINE__,false,false); - ggml_build_forward_expand(gf, flatten); - - ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; - GGML_ASSERT(backend != nullptr && "failed to initialize CPU backend"); - ggml_backend_graph_compute(backend.get(), gf); - - struct ggml_tensor* result = ggml_graph_node(gf, -1); - - memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context - // append without newline tokens (default behavior in llava_arch when not using unpad ): - memcpy(image_embd_out + clip_n_output_tokens(ctx_clip, img_input) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches - *n_img_pos_out = static_cast(result->ne[1]+clip_n_output_tokens(ctx_clip, img_input)); - - // Debug: Test single segments - // Current findings: sending base image, sending a 
segment embedding all works similar to python - // However, permuted embeddings do not work yet (stride issue?) - // memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as context - // memcpy(image_embd_out, (float*)prepared_cont->data, clip_embd_nbytes(ctx_clip)); // main image as context - // *n_img_pos_out=576; - - ggml_free(model.ctx); - return true; -} - -static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { - int width = image->nx; - int height = image->ny; - int num_patches = (height / patch_size) * (width / patch_size); - clip_image_f32 * patch = clip_image_f32_init(); - patch->nx = patch_size * num_patches; - patch->ny = patch_size; - patch->buf.resize(3 * patch->nx * patch->ny); - - int patch_index = 0; - - for (int i = 0; i < height; i += patch_size) { - for (int j = 0; j < width; j += patch_size) { - for (int pi = 0; pi < patch_size; ++pi) { - for (int pj = 0; pj < patch_size; ++pj) { - int input_index = ((i + pi) * width + (j + pj)) * 3; - int output_index = (pi * patch_size * num_patches + patch_index * patch_size + pj) * 3; - patch->buf[output_index] = image->buf[input_index]; - patch->buf[output_index+1] = image->buf[input_index+1]; - patch->buf[output_index+2] = image->buf[input_index+2]; - } - } - patch_index++; - } - } - return patch; -} - -static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { - // std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 - clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init()); - if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) { - LOG_ERR("%s: unable to preprocess image\n", __func__); - return false; - } - - const int64_t t_img_enc_start_us = ggml_time_us(); - - const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); - - const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get()); - - if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) { - std::vector image_embd_v; - image_embd_v.resize(n_imgs); - clip_image_size load_image_size; - - for (size_t i = 0; i < n_imgs; i++) { - const int64_t t_img_enc_step_start_us = ggml_time_us(); - int nx = clip_image_f32_batch_nx(img_res_v.get(), i); - int ny = clip_image_f32_batch_ny(img_res_v.get(), i); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny)); - int patch_size = 14; - load_image_size.width = nx; - load_image_size.height = ny; - clip_add_load_image_size(ctx_clip, &load_image_size); - - bool encoded = false; - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); - if (clip_is_qwen2vl(ctx_clip)) { - encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); - } - else { - encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]); - } - - if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); - return false; - } - const int64_t t_img_enc_steop_batch_us = ggml_time_us(); - LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); - } - const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); - - int n_img_pos_out = 0; - for 
(size_t i = 0; i < image_embd_v.size(); i++) { - int nx = clip_image_f32_batch_nx(img_res_v.get(), i); - int ny = clip_image_f32_batch_ny(img_res_v.get(), i); - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); - std::memcpy( - image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), - image_embd_v[i], - clip_embd_nbytes_by_img(ctx_clip, nx, ny)); - n_img_pos_out += clip_n_output_tokens(ctx_clip, img_res); - } - *n_img_pos = n_img_pos_out; - for (size_t i = 0; i < image_embd_v.size(); i++) { - free(image_embd_v[i]); - } - image_embd_v.clear(); - load_image_size.width = img->nx; - load_image_size.height = img->ny; - clip_add_load_image_size(ctx_clip, &load_image_size); - LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height); - } - else if (clip_is_glm(ctx_clip)){ - struct clip_image_size * load_image_size = clip_image_size_init(); - load_image_size->width = clip_image_f32_batch_nx(img_res_v.get(), 0); - load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0); - clip_add_load_image_size(ctx_clip, load_image_size); - - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); - bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); - int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2); - *n_img_pos = (pos * pos + 2); - if (!encoded){ - LOG_ERR("Unable to encode image \n"); - return false; - } - } - else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { - // flat / default llava-1.5 type embedding - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); - *n_img_pos = clip_n_output_tokens(ctx_clip, img_res); - bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096 - if (!encoded) { - LOG_ERR("Unable to encode image\n"); - - return false; - } - } - else { - // spatial_unpad llava-1.6 type embedding - // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working - std::vector image_embd_v; - image_embd_v.resize(n_imgs); - for (size_t i = 0; i < n_imgs; i++) { - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 - const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside - if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); - return false; - } - } - const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); - - const int32_t * image_grid = clip_image_grid(ctx_clip); - const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip); - - std::vector> grid_pinpoints; - for (size_t i = 0; i < num_gridpoints; i += 2) { - grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); - } - - const int32_t image_size = clip_get_image_size(ctx_clip); - - struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size); - - int n_img_pos_out; - clip_image_f32 * img_input = clip_image_f32_get_img(img_res_v.get(), 0); - clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out, img_input); - *n_img_pos = 
n_img_pos_out; - - for (size_t i = 0; i < image_embd_v.size(); i++) { - free(image_embd_v[i]); - } - image_embd_v.clear(); - - // debug image/segment/normalization content: - // clip_image_u8 * tmp = clip_image_u8_init(); - // clip_image_convert_f32_to_u8(*image_feature, *tmp); - // clip_image_save_to_bmp(*tmp, "image_feature.bmp"); - } - - LOG_INF("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); - - const int64_t t_img_enc_end_us = ggml_time_us(); - float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; - - LOG_INF("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos); - - return true; -} - -bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { - // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama)); - auto n_image_embd = clip_n_mmproj_embd(ctx_clip); - if (n_image_embd != n_llama_embd) { - LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); - return false; - } - return true; -} - -bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { - // Granite vision uses up to 10 patches + base patch - int num_max_patches = 11; - if (clip_is_minicpmv(ctx_clip)) { - num_max_patches = 10; - } - if (clip_is_glm(ctx_clip)) { - num_max_patches = 1; - } - float * image_embd; - if (clip_is_qwen2vl(ctx_clip)) { - // qwen2vl don't split image into chunks, so `num_max_patches` is not needed. - image_embd = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img->nx, img->ny)); - } else { - image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model - } - if (!image_embd) { - LOG_ERR("Unable to allocate memory for image embeddings\n"); - return false; - } - - int n_img_pos; - if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) { - LOG_ERR("%s: cannot encode image, aborting\n", __func__); - free(image_embd); - return false; - } - *image_embd_out = image_embd; - *n_img_pos_out = n_img_pos; - - return true; -} - -struct llava_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - -bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { - int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); - - for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { - int n_eval = image_embed->n_image_pos - i; - if (n_eval > n_batch) { - n_eval = n_batch; - } - float * 
embd = image_embed->embed+i*n_embd; - llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); - if (llama_decode(ctx_llama, llava_batch.batch)) { - LOG_ERR("%s : failed to eval\n", __func__); - return false; - } - *n_past += n_eval; - } - return true; -} - -struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length) { - clip_image_u8 * img = clip_image_u8_init(); - if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) { - clip_image_u8_free(img); - LOG_ERR("%s: can't load image from bytes, is it a valid image?", __func__); - return NULL; - } - - float* image_embed = NULL; - int n_image_pos = 0; - bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos); - if (!image_embed_result) { - clip_image_u8_free(img); - LOG_ERR("%s: couldn't embed the image\n", __func__); - return NULL; - } - - clip_image_u8_free(img); - auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed)); - result->embed = image_embed; - result->n_image_pos = n_image_pos; - return result; -} - -static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) { - auto file = fopen(path, "rb"); - if (file == NULL) { - LOG_ERR("%s: can't read file %s\n", __func__, path); - return false; - } - - fseek(file, 0, SEEK_END); - auto fileSize = ftell(file); - fseek(file, 0, SEEK_SET); - - auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data - if (buffer == NULL) { - LOG_ERR("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path); - perror("Memory allocation error"); - fclose(file); - return false; - } - errno = 0; - size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer - if (ferror(file)) { - LOG_ERR("read error: %s", strerror(errno)); - free(buffer); - fclose(file); - return false; - } - if (ret != (size_t) fileSize) { - LOG_ERR("unexpectedly reached end of file"); - free(buffer); - fclose(file); - return false; - } - fclose(file); // Close the file - - *bytesOut = buffer; - *sizeOut = fileSize; - return true; -} - -struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) { - unsigned char* image_bytes; - long image_bytes_length; - auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length); - if (!loaded) { - LOG_ERR("%s: failed to load %s\n", __func__, image_path); - return NULL; - } - - llava_image_embed *embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length); - free(image_bytes); - - return embed; -} - -void llava_image_embed_free(struct llava_image_embed * embed) { - free(embed->embed); - free(embed); -} diff --git a/tools/mtmd/llava.h b/tools/mtmd/llava.h deleted file mode 100644 index b6feb3027b2da..0000000000000 --- a/tools/mtmd/llava.h +++ /dev/null @@ -1,49 +0,0 @@ -#ifndef LLAVA_H -#define LLAVA_H - -#include "ggml.h" - -#ifdef LLAMA_SHARED -# if defined(_WIN32) && !defined(__MINGW32__) -# ifdef LLAMA_BUILD -# define LLAVA_API __declspec(dllexport) -# else -# define LLAVA_API __declspec(dllimport) -# endif -# else -# define LLAVA_API __attribute__ ((visibility ("default"))) -# endif -#else -# define LLAVA_API -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -struct clip_ctx; -struct llava_image_embed { - float * embed; - int n_image_pos; -}; - -/** sanity check for clip <-> 
llava embed size match */ -LLAVA_API bool llava_validate_embed_size(const struct llama_context * ctx_llama, const struct clip_ctx * ctx_clip); - -LLAVA_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out); - -/** build an image embed from image file bytes */ -LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); -/** build an image embed from a path to an image filename */ -LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); -/** free an embedding made with llava_image_embed_make_* */ -LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); - -/** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ -LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); - -#ifdef __cplusplus -} -#endif - -#endif
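Taken together, the removals above point users at the unified `libmtmd` workflow: generate the `mmproj` file with `convert_hf_to_gguf.py --mmproj` (as the updated `tools/mtmd/README.md` describes) and run `llama-mtmd-cli`, the target this `CMakeLists.txt` still builds. A rough sketch of that flow follows; the model paths and output names are placeholders, the exact converter invocation beyond the `--mmproj` flag is an assumption, and the CLI options mirror those used by the deleted `android/adb_run.sh`:

```sh
# Convert a HuggingFace checkpoint and emit the multimodal projector GGUF
# (--mmproj is the flag referenced in tools/mtmd/README.md; paths are placeholders).
python convert_hf_to_gguf.py /path/to/hf-model --mmproj

# Run the remaining multimodal CLI with the same options the removed
# android/adb_run.sh script used.
./build/bin/llama-mtmd-cli \
    -m /path/to/ggml-model-q4_k.gguf \
    --mmproj /path/to/mmproj-model-f16.gguf \
    --image /path/to/demo.jpg \
    -p "What is in the image?" \
    -t 4
```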