From 0392cd8824b1c01d4c43a21564f505b8a19cbb9a Mon Sep 17 00:00:00 2001 From: zhujian <2469395556@qq.com> Date: Tue, 26 May 2026 20:46:32 +0800 Subject: [PATCH 1/3] fix: resolve issue #13811 --- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 477697fadb64..ce75565b5e8a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1717,12 +1717,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] - cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( + cond_model_input_ids = Flux2Pipeline._prepare_image_ids([cond_model_input[0:1]]).to( device=cond_model_input.device ) - cond_model_input_ids = cond_model_input_ids.view( - cond_model_input.shape[0], -1, model_input_ids.shape[-1] + cond_model_input_ids = cond_model_input_ids.expand( + cond_model_input.shape[0], -1, -1 ) # Sample noise that we'll add to the latents From 38a31c51fc741dba780f5732392e160b0ff6db22 Mon Sep 17 00:00:00 2001 From: zhujian <2469395556@qq.com> Date: Thu, 28 May 2026 10:49:47 +0800 Subject: [PATCH 2/3] fix: format train_dreambooth_lora_flux2_img2img.py --- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index ce75565b5e8a..f11118f5cae5 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1720,9 +1720,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input_ids = Flux2Pipeline._prepare_image_ids([cond_model_input[0:1]]).to( device=cond_model_input.device ) - cond_model_input_ids = cond_model_input_ids.expand( - cond_model_input.shape[0], -1, -1 - ) + cond_model_input_ids = cond_model_input_ids.expand(cond_model_input.shape[0], -1, -1) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) From 0950595dd4b62869a950569629db4f02b55648c9 Mon Sep 17 00:00:00 2001 From: zhujian <2469395556@qq.com> Date: Thu, 28 May 2026 11:02:05 +0800 Subject: [PATCH 3/3] fix: train_dreambooth_lora_flux2_klein_img2img.py --- .../train_dreambooth_lora_flux2_klein_img2img.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 63862eed9f1e..80a55a8b4307 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -1663,13 +1663,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] - cond_model_input_ids = Flux2KleinPipeline._prepare_image_ids(cond_model_input_list).to( + cond_model_input_ids = Flux2KleinPipeline._prepare_image_ids([cond_model_input[0:1]]).to( device=cond_model_input.device ) - cond_model_input_ids = cond_model_input_ids.view( - cond_model_input.shape[0], -1, model_input_ids.shape[-1] - ) + cond_model_input_ids = cond_model_input_ids.expand(cond_model_input.shape[0], -1, -1) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input)