diff --git a/docs/flux2.md b/docs/flux2.md
index 1524478cc..11202e919 100644
--- a/docs/flux2.md
+++ b/docs/flux2.md
@@ -8,6 +8,8 @@
     - gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
 - Download vae
     - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
+- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
+    - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
 - Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
     - gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
 
@@ -31,6 +33,8 @@
     - gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
 - Download vae
     - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
+- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
+    - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
 - Download Qwen3 4b
     - safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
     - gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main
diff --git a/src/auto_encoder_kl.hpp b/src/auto_encoder_kl.hpp
index 039fb9df3..d4283959d 100644
--- a/src/auto_encoder_kl.hpp
+++ b/src/auto_encoder_kl.hpp
@@ -501,11 +501,36 @@ class AutoEncoderKLModel : public GGMLBlock {
         bool double_z = true;
     } dd_config;
 
+    static std::string get_tensor_name(const std::string& prefix, const std::string& name) {
+        return prefix.empty() ? name : prefix + "." + name;
+    }
+
+    void detect_decoder_ch(const String2TensorStorage& tensor_storage_map,
+                           const std::string& prefix,
+                           int& decoder_ch) {
+        auto conv_in_iter = tensor_storage_map.find(get_tensor_name(prefix, "decoder.conv_in.weight"));
+        if (conv_in_iter != tensor_storage_map.end() && conv_in_iter->second.n_dims >= 4 && conv_in_iter->second.ne[3] > 0) {
+            int last_ch_mult             = dd_config.ch_mult.back();
+            int64_t conv_in_out_channels = conv_in_iter->second.ne[3];
+            if (last_ch_mult > 0 && conv_in_out_channels % last_ch_mult == 0) {
+                decoder_ch = static_cast<int>(conv_in_out_channels / last_ch_mult);
+                LOG_INFO("vae decoder: ch = %d", decoder_ch);
+            } else {
+                LOG_WARN("vae decoder: failed to infer ch from %s (%" PRId64 " / %d)",
+                         get_tensor_name(prefix, "decoder.conv_in.weight").c_str(),
+                         conv_in_out_channels,
+                         last_ch_mult);
+            }
+        }
+    }
+
 public:
-    AutoEncoderKLModel(SDVersion version = VERSION_SD1,
-                       bool decode_only = true,
-                       bool use_linear_projection = false,
-                       bool use_video_decoder = false)
+    AutoEncoderKLModel(SDVersion version                              = VERSION_SD1,
+                       bool decode_only                               = true,
+                       bool use_linear_projection                     = false,
+                       bool use_video_decoder                         = false,
+                       const String2TensorStorage& tensor_storage_map = {},
+                       const std::string& prefix                      = "")
         : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
         if (sd_version_is_dit(version)) {
             if (sd_version_is_flux2(version)) {
@@ -519,7 +544,9 @@ class AutoEncoderKLModel : public GGMLBlock {
         if (use_video_decoder) {
             use_quant = false;
         }
-        blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
+        int decoder_ch = dd_config.ch;
+        detect_decoder_ch(tensor_storage_map, prefix, decoder_ch);
+        blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(decoder_ch,
                                                                    dd_config.out_ch,
                                                                    dd_config.ch_mult,
                                                                    dd_config.num_res_blocks,
@@ -662,7 +689,7 @@ struct AutoEncoderKL : public VAE {
                 break;
             }
         }
-        ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder);
+        ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder, tensor_storage_map, prefix);
         ae.init(params_ctx, tensor_storage_map, prefix);
     }
 