-
Notifications
You must be signed in to change notification settings - Fork 6.9k
[feat] JoyAI-JoyImage-Edit support #13444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Moran232
wants to merge
6
commits into
huggingface:main
Choose a base branch
from
Moran232:joyimage_edit
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
54afa65
[feat] JoyAI-JoyImage-Edit support
8459759
[fix] remove rearrange
e6e6df5
[refactor] two pass when do cfg
f557113
[refactor] remove repa, use wantimetextembeding, refactor modulate code
d397b68
[refactor] Joyimage Attention refactor
9d78e4e
remove vae tiling and autocast
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,306 @@ | ||
| import argparse | ||
| import pathlib | ||
| from typing import Any, Dict, Tuple | ||
| import torch | ||
| from accelerate import init_empty_weights | ||
| from huggingface_hub import hf_hub_download, snapshot_download | ||
| from safetensors.torch import load_file | ||
| from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration | ||
| from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler | ||
| from safetensors.torch import load_file | ||
| from diffusers import ( | ||
| AutoencoderKLWan, | ||
| JoyImageEditTransformer3DModel, | ||
| JoyImageEditPipeline, | ||
| ) | ||
| # This code is modified from convert_wan_to_diffusers.py to support input ckpt path | ||
| def convert_vae(vae_ckpt_path): | ||
| old_state_dict = torch.load(vae_ckpt_path, weights_only=True) | ||
| new_state_dict = {} | ||
|
|
||
| # Create mappings for specific components | ||
| middle_key_mapping = { | ||
| # Encoder middle block | ||
| "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", | ||
| "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", | ||
| "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", | ||
| "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", | ||
| "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", | ||
| "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", | ||
| "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", | ||
| "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", | ||
| "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", | ||
| "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", | ||
| "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", | ||
| "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", | ||
| # Decoder middle block | ||
| "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", | ||
| "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", | ||
| "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", | ||
| "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", | ||
| "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", | ||
| "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", | ||
| "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", | ||
| "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", | ||
| "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", | ||
| "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", | ||
| "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", | ||
| "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", | ||
| } | ||
|
|
||
| # Create a mapping for attention blocks | ||
| attention_mapping = { | ||
| # Encoder middle attention | ||
| "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", | ||
| "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", | ||
| "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", | ||
| "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", | ||
| "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", | ||
| # Decoder middle attention | ||
| "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", | ||
| "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", | ||
| "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", | ||
| "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", | ||
| "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", | ||
| } | ||
|
|
||
| # Create a mapping for the head components | ||
| head_mapping = { | ||
| # Encoder head | ||
| "encoder.head.0.gamma": "encoder.norm_out.gamma", | ||
| "encoder.head.2.bias": "encoder.conv_out.bias", | ||
| "encoder.head.2.weight": "encoder.conv_out.weight", | ||
| # Decoder head | ||
| "decoder.head.0.gamma": "decoder.norm_out.gamma", | ||
| "decoder.head.2.bias": "decoder.conv_out.bias", | ||
| "decoder.head.2.weight": "decoder.conv_out.weight", | ||
| } | ||
|
|
||
| # Create a mapping for the quant components | ||
| quant_mapping = { | ||
| "conv1.weight": "quant_conv.weight", | ||
| "conv1.bias": "quant_conv.bias", | ||
| "conv2.weight": "post_quant_conv.weight", | ||
| "conv2.bias": "post_quant_conv.bias", | ||
| } | ||
|
|
||
| # Process each key in the state dict | ||
| for key, value in old_state_dict.items(): | ||
| # Handle middle block keys using the mapping | ||
| if key in middle_key_mapping: | ||
| new_key = middle_key_mapping[key] | ||
| new_state_dict[new_key] = value | ||
| # Handle attention blocks using the mapping | ||
| elif key in attention_mapping: | ||
| new_key = attention_mapping[key] | ||
| new_state_dict[new_key] = value | ||
| # Handle head keys using the mapping | ||
| elif key in head_mapping: | ||
| new_key = head_mapping[key] | ||
| new_state_dict[new_key] = value | ||
| # Handle quant keys using the mapping | ||
| elif key in quant_mapping: | ||
| new_key = quant_mapping[key] | ||
| new_state_dict[new_key] = value | ||
| # Handle encoder conv1 | ||
| elif key == "encoder.conv1.weight": | ||
| new_state_dict["encoder.conv_in.weight"] = value | ||
| elif key == "encoder.conv1.bias": | ||
| new_state_dict["encoder.conv_in.bias"] = value | ||
| # Handle decoder conv1 | ||
| elif key == "decoder.conv1.weight": | ||
| new_state_dict["decoder.conv_in.weight"] = value | ||
| elif key == "decoder.conv1.bias": | ||
| new_state_dict["decoder.conv_in.bias"] = value | ||
| # Handle encoder downsamples | ||
| elif key.startswith("encoder.downsamples."): | ||
| # Convert to down_blocks | ||
| new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") | ||
|
|
||
| # Convert residual block naming but keep the original structure | ||
| if ".residual.0.gamma" in new_key: | ||
| new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") | ||
| elif ".residual.2.bias" in new_key: | ||
| new_key = new_key.replace(".residual.2.bias", ".conv1.bias") | ||
| elif ".residual.2.weight" in new_key: | ||
| new_key = new_key.replace(".residual.2.weight", ".conv1.weight") | ||
| elif ".residual.3.gamma" in new_key: | ||
| new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") | ||
| elif ".residual.6.bias" in new_key: | ||
| new_key = new_key.replace(".residual.6.bias", ".conv2.bias") | ||
| elif ".residual.6.weight" in new_key: | ||
| new_key = new_key.replace(".residual.6.weight", ".conv2.weight") | ||
| elif ".shortcut.bias" in new_key: | ||
| new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") | ||
| elif ".shortcut.weight" in new_key: | ||
| new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") | ||
|
|
||
| new_state_dict[new_key] = value | ||
|
|
||
| # Handle decoder upsamples | ||
| elif key.startswith("decoder.upsamples."): | ||
| # Convert to up_blocks | ||
| parts = key.split(".") | ||
| block_idx = int(parts[2]) | ||
|
|
||
| # Group residual blocks | ||
| if "residual" in key: | ||
| if block_idx in [0, 1, 2]: | ||
| new_block_idx = 0 | ||
| resnet_idx = block_idx | ||
| elif block_idx in [4, 5, 6]: | ||
| new_block_idx = 1 | ||
| resnet_idx = block_idx - 4 | ||
| elif block_idx in [8, 9, 10]: | ||
| new_block_idx = 2 | ||
| resnet_idx = block_idx - 8 | ||
| elif block_idx in [12, 13, 14]: | ||
| new_block_idx = 3 | ||
| resnet_idx = block_idx - 12 | ||
| else: | ||
| # Keep as is for other blocks | ||
| new_state_dict[key] = value | ||
| continue | ||
|
|
||
| # Convert residual block naming | ||
| if ".residual.0.gamma" in key: | ||
| new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" | ||
| elif ".residual.2.bias" in key: | ||
| new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" | ||
| elif ".residual.2.weight" in key: | ||
| new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" | ||
| elif ".residual.3.gamma" in key: | ||
| new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" | ||
| elif ".residual.6.bias" in key: | ||
| new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" | ||
| elif ".residual.6.weight" in key: | ||
| new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" | ||
| else: | ||
| new_key = key | ||
|
|
||
| new_state_dict[new_key] = value | ||
|
|
||
| # Handle shortcut connections | ||
| elif ".shortcut." in key: | ||
| if block_idx == 4: | ||
| new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") | ||
| new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") | ||
| else: | ||
| new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") | ||
| new_key = new_key.replace(".shortcut.", ".conv_shortcut.") | ||
|
|
||
| new_state_dict[new_key] = value | ||
|
|
||
| # Handle upsamplers | ||
| elif ".resample." in key or ".time_conv." in key: | ||
| if block_idx == 3: | ||
| new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") | ||
| elif block_idx == 7: | ||
| new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") | ||
| elif block_idx == 11: | ||
| new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") | ||
| else: | ||
| new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") | ||
|
|
||
| new_state_dict[new_key] = value | ||
| else: | ||
| new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") | ||
| new_state_dict[new_key] = value | ||
| else: | ||
| # Keep other keys unchanged | ||
| new_state_dict[key] = value | ||
|
|
||
| with init_empty_weights(): | ||
| vae = AutoencoderKLWan() | ||
| vae.load_state_dict(new_state_dict, strict=True, assign=True) | ||
| return vae | ||
|
|
||
| def get_transformer_config() -> Tuple[Dict[str, Any], ...]: | ||
| config = { | ||
| "diffusers_config": { | ||
| "hidden_size": 4096, | ||
| "in_channels": 16, | ||
| "heads_num": 32, | ||
| "mm_double_blocks_depth": 40, | ||
| "out_channels": 16, | ||
| "patch_size": [1, 2, 2], | ||
| "rope_dim_list": [16, 56, 56], | ||
| "text_states_dim": 4096, | ||
| "rope_type": "rope", | ||
| "dit_modulation_type": "wanx", | ||
| "unpatchify_new": True, | ||
| "rope_theta": 10000, | ||
| }, | ||
| } | ||
| return config | ||
| def convert_transformer(ckpt_path: str): | ||
| checkpoint = torch.load(ckpt_path, weights_only=True) | ||
| if "model" in checkpoint: | ||
| original_state_dict = checkpoint["model"] | ||
| else: | ||
| original_state_dict = checkpoint | ||
| config = get_transformer_config() | ||
| with init_empty_weights(): | ||
| transformer = JoyImageEditTransformer3DModel(**config['diffusers_config']) | ||
| transformer.load_state_dict(original_state_dict, strict=True, assign=True) | ||
| return transformer | ||
|
|
||
| def get_args(): | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
| "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" | ||
| ) | ||
| parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint") | ||
| parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint") | ||
| parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer") | ||
| parser.add_argument("--save_pipeline", action="store_true") | ||
| parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") | ||
| parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") | ||
| parser.add_argument("--flow_shift", type=float, default=7.0) | ||
| return parser.parse_args() | ||
|
|
||
| DTYPE_MAPPING = { | ||
| "fp32": torch.float32, | ||
| "fp16": torch.float16, | ||
| "bf16": torch.bfloat16, | ||
| } | ||
| if __name__ == "__main__": | ||
| args = get_args() | ||
| transformer = None | ||
| vae = None | ||
| dtype = DTYPE_MAPPING[args.dtype] | ||
|
|
||
| if args.save_pipeline: | ||
| assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None | ||
| assert args.text_encoder_path is not None | ||
| # assert args.tokenizer_path is not None | ||
| if args.transformer_ckpt_path is not None: | ||
| transformer = convert_transformer(args.transformer_ckpt_path) | ||
| transformer = transformer.to(dtype=dtype) | ||
| if not args.save_pipeline: | ||
| transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") | ||
| if args.vae_ckpt_path is not None: | ||
| vae = convert_vae(args.vae_ckpt_path) | ||
| vae = vae.to(dtype=dtype) | ||
| if not args.save_pipeline: | ||
| vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") | ||
| if args.save_pipeline: | ||
| processor = AutoProcessor.from_pretrained(args.text_encoder_path) | ||
| text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16).to("cuda") | ||
| tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path) | ||
| flow_shift = 1.5 | ||
| scheduler = FlowMatchEulerDiscreteScheduler( | ||
| num_train_timesteps=1000, shift=flow_shift | ||
| ) | ||
| transformer = transformer.to("cuda") | ||
| vae = vae.to("cuda") | ||
| pipe = JoyImageEditPipeline( | ||
| processor=processor, | ||
| transformer=transformer, | ||
| text_encoder=text_encoder, | ||
| tokenizer=tokenizer, | ||
| vae=vae, | ||
| scheduler=scheduler, | ||
| ).to("cuda") | ||
| pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") | ||
| processor.save_pretrained(f"{args.output_path}/processor") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,6 +99,7 @@ | |
| "accelerate>=0.31.0", | ||
| "compel==0.1.8", | ||
| "datasets", | ||
| "einops", | ||
| "filelock", | ||
| "flax>=0.4.1", | ||
| "ftfy", | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we're no longer using
einops, we can remove it fromsetup.py.