Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/flux2.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
- gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
- gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main

Expand All @@ -31,6 +33,8 @@
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
- Download Qwen3 4B
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main
Expand Down
39 changes: 33 additions & 6 deletions src/auto_encoder_kl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,11 +501,36 @@ class AutoEncoderKLModel : public GGMLBlock {
bool double_z = true;
} dd_config;

// Build a fully-qualified tensor name: "<prefix>.<name>", or just
// "<name>" when no prefix is given.
static std::string get_tensor_name(const std::string& prefix, const std::string& name) {
    if (prefix.empty()) {
        return name;
    }
    return prefix + "." + name;
}

void detect_decoder_ch(const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
int& decoder_ch) {
auto conv_in_iter = tensor_storage_map.find(get_tensor_name(prefix, "decoder.conv_in.weight"));
if (conv_in_iter != tensor_storage_map.end() && conv_in_iter->second.n_dims >= 4 && conv_in_iter->second.ne[3] > 0) {
int last_ch_mult = dd_config.ch_mult.back();
int64_t conv_in_out_channels = conv_in_iter->second.ne[3];
if (last_ch_mult > 0 && conv_in_out_channels % last_ch_mult == 0) {
decoder_ch = static_cast<int>(conv_in_out_channels / last_ch_mult);
LOG_INFO("vae decoder: ch = %d", decoder_ch);
} else {
LOG_WARN("vae decoder: failed to infer ch from %s (%" PRId64 " / %d)",
get_tensor_name(prefix, "decoder.conv_in.weight").c_str(),
conv_in_out_channels,
last_ch_mult);
}
}
}

public:
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
bool decode_only = true,
bool use_linear_projection = false,
bool use_video_decoder = false)
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
bool decode_only = true,
bool use_linear_projection = false,
bool use_video_decoder = false,
const String2TensorStorage& tensor_storage_map = {},
const std::string& prefix = "")
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (sd_version_is_dit(version)) {
if (sd_version_is_flux2(version)) {
Expand All @@ -519,7 +544,9 @@ class AutoEncoderKLModel : public GGMLBlock {
if (use_video_decoder) {
use_quant = false;
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
int decoder_ch = dd_config.ch;
detect_decoder_ch(tensor_storage_map, prefix, decoder_ch);
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(decoder_ch,
dd_config.out_ch,
dd_config.ch_mult,
dd_config.num_res_blocks,
Expand Down Expand Up @@ -662,7 +689,7 @@ struct AutoEncoderKL : public VAE {
break;
}
}
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder);
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder, tensor_storage_map, prefix);
ae.init(params_ctx, tensor_storage_map, prefix);
}

Expand Down
Loading