From d02d23de2630b649a4706535a4bc76f94d5a0de9 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Thu, 9 Apr 2026 17:50:02 +0800
Subject: [PATCH 1/8] auto integrate joyimage model

---
 diffsynth/configs/model_configs.py           |  34 +-
 diffsynth/models/joyai_image_dit.py          | 696 ++++++++++++++++++
 diffsynth/models/joyai_image_text_encoder.py | 100 +++
 .../state_dict_converters/joyai_image_dit.py |  24 +
 .../joyai_image_text_encoder.py              |  20 +
 test.py                                      |   5 +
 6 files changed, 878 insertions(+), 1 deletion(-)
 create mode 100644 diffsynth/models/joyai_image_dit.py
 create mode 100644 diffsynth/models/joyai_image_text_encoder.py
 create mode 100644 diffsynth/utils/state_dict_converters/joyai_image_dit.py
 create mode 100644 diffsynth/utils/state_dict_converters/joyai_image_text_encoder.py
 create mode 100644 test.py

diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py
index 4222202ea..40604aeb7 100644
--- a/diffsynth/configs/model_configs.py
+++ b/diffsynth/configs/model_configs.py
@@ -884,4 +884,36 @@
         "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
     },
 ]
-MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series
+joyai_image_series = [
+    {
+        # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
+        "model_hash": "56592ddfd7d0249d3aa527d24161a863",
+        "model_name": "joyai_image_dit",
+        "model_class": "diffsynth.models.joyai_image_dit.Transformer3DModel",
+        "extra_kwargs": {
+            "patch_size": [1, 2, 2],
+            "in_channels": 16,
+            "out_channels": 16,
+            "hidden_size": 4096,
+            "heads_num": 32,
+            "text_states_dim": 4096,
+            "mlp_width_ratio": 4.0,
+            "mm_double_blocks_depth": 40,
+            "rope_dim_list": [16, 56, 56],
+            "rope_type": "rope",
+            "dit_modulation_type": "wanx",
+            "attn_backend": "torch_spda",
+            "theta": 256,
+        },
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_dit.JoyAIImageDiTStateDictConverter",
+    },
+    {
+        # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/")
+        "model_hash": "2d11bf14bba8b4e87477c8199a895403",
+        "model_name": "joyai_image_text_encoder",
+        "model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_text_encoder.JoyAIImageTextEncoderStateDictConverter",
+    },
+]
+
+MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series
diff --git a/diffsynth/models/joyai_image_dit.py b/diffsynth/models/joyai_image_dit.py
new file mode 100644
index 000000000..1e73fe8b6
--- /dev/null
+++ b/diffsynth/models/joyai_image_dit.py
@@ -0,0 +1,696 @@
+import math
+from typing import List, Tuple, Optional, Union, Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+) -> torch.Tensor:
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+    half_dim = embedding_dim // 2
+    exponent = -math.log(max_period) * torch.arange(
+        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+    )
+    exponent = exponent / (half_dim - downscale_freq_shift)
+    emb = torch.exp(exponent)
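+    # outer product of timesteps with the log-spaced frequencies: one phase angle per (t, freq) pair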
+    emb = timesteps[:, None].float() * emb[None, :]
+    emb = scale * emb
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+class Timesteps(nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
+        super().__init__()
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale
+
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
+        return get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
+        )
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        time_embed_dim: int,
+        act_fn: str = "silu",
+        out_dim: int = None,
+        post_act_fn: Optional[str] = None,
+        cond_proj_dim=None,
+        sample_proj_bias=True,
+    ):
+        super().__init__()
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
+        if cond_proj_dim is not None:
+            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+        else:
+            self.cond_proj = None
+        self.act = nn.SiLU()
+        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+        self.post_act = nn.SiLU() if post_act_fn == "silu" else None
+
+    def forward(self, sample, condition=None):
+        if condition is not None:
+            sample = sample + self.cond_proj(condition)
+        sample = self.linear_1(sample)
+        if self.act is not None:
+            sample = self.act(sample)
+        sample = self.linear_2(sample)
+        if self.post_act is not None:
+            sample = self.post_act(sample)
+        return sample
+
+
+class PixArtAlphaTextProjection(nn.Module):
+    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
+        super().__init__()
+        if out_features is None:
+            out_features = hidden_size
+        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+        if act_fn == "gelu_tanh":
+            self.act_1 = nn.GELU(approximate="tanh")
+        elif act_fn == "silu":
+            self.act_1 = nn.SiLU()
+        else:
+            self.act_1 = nn.GELU(approximate="tanh")
+        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
+
+    def forward(self, caption):
+        hidden_states = self.linear_1(caption)
+        hidden_states = self.act_1(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        activation_fn: str = "geglu",
+        final_dropout: bool = False,
+        inner_dim=None,
+        bias: bool = True,
+    ):
+        super().__init__()
+        if inner_dim is None:
+            inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+        if activation_fn == "gelu-approximate":
+            self.proj = nn.Linear(dim, inner_dim, bias=bias)
+            self.act = lambda x: F.gelu(x, approximate="tanh")
+        elif activation_fn == "gelu":
+            self.proj = nn.Linear(dim, inner_dim, bias=bias)
+            self.act = F.gelu
+        else:
+            self.proj = nn.Linear(dim, inner_dim, bias=bias)
+            self.act = F.gelu
+        self.drop = nn.Dropout(dropout)
+        self.out_proj = nn.Linear(inner_dim, dim_out, bias=bias)
+        if final_dropout:
+            self.final_drop = nn.Dropout(dropout)
+        else:
+            self.final_drop = None
+
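+    # Note: "proj"/"out_proj" here correspond to diffusers' "net.0.proj"/"net.2";
+    # the JoyAI DiT state-dict converter in this patch maps between the two layouts.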
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.drop(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        if self.final_drop is not None:
+            hidden_states = self.final_drop(hidden_states)
+        return hidden_states
+
+
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
+try:
+    import flash_attn
+    FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_2_AVAILABLE = False
+
+
+def flash_attn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
+    q_ = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
+    k_ = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
+    v_ = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
+    if FLASH_ATTN_3_AVAILABLE:
+        x = flash_attn_interface.flash_attn_varlen_func(
+            q_, k_, v_, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
+        )
+        if isinstance(x, tuple):
+            x = x[0]
+    else:
+        x = flash_attn.flash_attn_varlen_func(
+            q_, k_, v_, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
+        )
+    batch_size = cu_seqlens_q.shape[0] // 2
+    return x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])
+
+
+def get_cu_seqlens(text_mask, img_len):
+    batch_size = text_mask.shape[0]
+    text_len = text_mask.sum(dim=1)
+    max_len = text_mask.shape[1] + img_len
+    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
+    for i in range(batch_size):
+        s = text_len[i] + img_len
+        s1 = i * max_len + s
+        s2 = (i + 1) * max_len
+        cu_seqlens[2 * i + 1] = s1
+        cu_seqlens[2 * i + 2] = s2
+    return cu_seqlens
+
+
+def _to_tuple(x, dim=2):
+    if isinstance(x, int):
+        return (x,) * dim
+    elif len(x) == dim:
+        return x
+    else:
+        raise ValueError(f"Expected length {dim} or int, but got {x}")
+
+
+def get_meshgrid_nd(start, *args, dim=2):
+    if len(args) == 0:
+        num = _to_tuple(start, dim=dim)
+        start = (0,) * dim
+        stop = num
+    elif len(args) == 1:
+        start = _to_tuple(start, dim=dim)
+        stop = _to_tuple(args[0], dim=dim)
+        num = [stop[i] - start[i] for i in range(dim)]
+    elif len(args) == 2:
+        start = _to_tuple(start, dim=dim)
+        stop = _to_tuple(args[0], dim=dim)
+        num = _to_tuple(args[1], dim=dim)
+    else:
+        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+    axis_grid = []
+    for i in range(dim):
+        a, b, n = start[i], stop[i], num[i]
+        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
+        axis_grid.append(g)
+    grid = torch.meshgrid(*axis_grid, indexing="ij")
+    grid = torch.stack(grid, dim=0)
+    return grid
+
+
+def reshape_for_broadcast(freqs_cis, x, head_first=False):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    if isinstance(freqs_cis, tuple):
+        if head_first:
+            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1])
+            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        else:
+            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1])
+            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
+    else:
+        if head_first:
+            assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        else:
+            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+
+
+def rotate_half(x):
+    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
+    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+
+def apply_rotary_emb(xq, xk, freqs_cis, head_first=False):
+    cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
+    cos, sin = cos.to(xq.device), sin.to(xq.device)
+    xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+    xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+    return xq_out, xk_out
+
+
+def get_1d_rotary_pos_embed(dim, pos, theta=10000.0, use_real=False, theta_rescale_factor=1.0, interpolation_factor=1.0):
+    if isinstance(pos, int):
+        pos = torch.arange(pos).float()
+    if theta_rescale_factor != 1.0:
+        theta *= theta_rescale_factor ** (dim / (dim - 2))
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs = torch.outer(pos * interpolation_factor, freqs)
+    if use_real:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
+        return freqs_cos, freqs_sin
+    else:
+        return torch.polar(torch.ones_like(freqs), freqs)
+
+
+def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, use_real=False,
+                            txt_rope_size=None, theta_rescale_factor=1.0, interpolation_factor=1.0):
+    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
+    if isinstance(theta_rescale_factor, (int, float)):
+        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
+    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
+        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
+    if isinstance(interpolation_factor, (int, float)):
+        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
+    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
+        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
+    embs = []
+    for i in range(len(rope_dim_list)):
+        emb = get_1d_rotary_pos_embed(
+            rope_dim_list[i], grid[i].reshape(-1), theta,
+            use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
+            interpolation_factor=interpolation_factor[i],
+        )
+        embs.append(emb)
+    if use_real:
+        vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1))
+    else:
+        vis_emb = torch.cat(embs, dim=1)
+    if txt_rope_size is not None:
+        embs_txt = []
+        vis_max_ids = grid.view(-1).max().item()
+        grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1
+        for i in range(len(rope_dim_list)):
+            emb = get_1d_rotary_pos_embed(
+                rope_dim_list[i], grid_txt, theta,
+                use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
+                interpolation_factor=interpolation_factor[i],
+            )
+            embs_txt.append(emb)
+        if use_real:
+            txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1))
+        else:
+            txt_emb = torch.cat(embs_txt, dim=1)
+    else:
+        txt_emb = None
+    return vis_emb, txt_emb
+
+
+class ModulateWan(nn.Module):
+    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
+        super().__init__()
+        self.factor = factor
+        factory_kwargs = {"dtype": dtype, "device": device}
+        self.modulate_table = nn.Parameter(
+            torch.zeros(1, factor, hidden_size, **factory_kwargs) / hidden_size**0.5,
+            requires_grad=True
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(x.shape) != 3:
+            x = x.unsqueeze(1)
+        return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)]
+
+
+def modulate(x, shift=None, scale=None):
+    if scale is None and shift is None:
+        return x
+    elif shift is None:
+        return x * (1 + scale.unsqueeze(1))
+    elif scale is None:
+        return x + shift.unsqueeze(1)
+    else:
+        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+def apply_gate(x, gate=None, tanh=False):
+    if gate is None:
+        return x
+    if tanh:
+        return x * gate.unsqueeze(1).tanh()
+    else:
+        return x * gate.unsqueeze(1)
+
+
+def load_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None):
+    factory_kwargs = {"dtype": dtype, "device": device}
+    if modulate_type == 'wanx':
+        return ModulateWan(hidden_size, factor, **factory_kwargs)
+    raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.eps = eps
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if hasattr(self, "weight"):
+            output = output * self.weight
+        return output
+
+
+class MMDoubleStreamBlock(nn.Module):
+    """
+    A multimodal dit block with separate modulation for
+    text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
+    (Flux.1): https://github.com/black-forest-labs/flux
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        heads_num: int,
+        mlp_width_ratio: float,
+        mlp_act_type: str = "gelu_tanh",
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        dit_modulation_type: Optional[str] = "wanx",
+        attn_backend: str = 'flash_attn',
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.attn_backend = attn_backend
+        self.dit_modulation_type = dit_modulation_type
+        self.heads_num = heads_num
+        head_dim = hidden_size // heads_num
+        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+
+        self.img_mod = load_modulation(
+            modulate_type=self.dit_modulation_type,
+            hidden_size=hidden_size, factor=6, **factory_kwargs,
+        )
+        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
+        self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
+        self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
+        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
+        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+        self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
+
+        self.txt_mod = load_modulation(
+            modulate_type=self.dit_modulation_type,
+            hidden_size=hidden_size, factor=6, **factory_kwargs,
+        )
+        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
+        self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
+        self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
+        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
+        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+        self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        img: torch.Tensor,
+        txt: torch.Tensor,
+        vec: torch.Tensor,
+        vis_freqs_cis: tuple = None,
+        txt_freqs_cis: tuple = None,
+        attn_kwargs: Optional[dict] = {},
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        (
+            img_mod1_shift, img_mod1_scale, img_mod1_gate,
+            img_mod2_shift, img_mod2_scale, img_mod2_gate,
+        ) = self.img_mod(vec)
+        (
+            txt_mod1_shift, txt_mod1_scale, txt_mod1_gate,
+            txt_mod2_shift, txt_mod2_scale, txt_mod2_gate,
+        ) = self.txt_mod(vec)
+
+        img_modulated = self.img_norm1(img)
+        img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
+        img_qkv = self.img_attn_qkv(img_modulated)
+        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+        img_q = self.img_attn_q_norm(img_q).to(img_v)
+        img_k = self.img_attn_k_norm(img_k).to(img_v)
+
+        if vis_freqs_cis is not None:
+            img_qq, img_kk = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
+            img_q, img_k = img_qq, img_kk
+
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
+        txt_qkv = self.txt_attn_qkv(txt_modulated)
+        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
+        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
+
+        if txt_freqs_cis is not None:
+            raise NotImplementedError("RoPE text is not supported for inference")
+
+        q = torch.cat((img_q, txt_q), dim=1)
+        k = torch.cat((img_k, txt_k), dim=1)
+        v = torch.cat((img_v, txt_v), dim=1)
+
+        if self.attn_backend == 'flash_attn':
+            cu_seqlens_q = attn_kwargs['cu_seqlens_q']
+            cu_seqlens_kv = attn_kwargs['cu_seqlens_kv']
+            max_seqlen_q = attn_kwargs['max_seqlen_q']
+            max_seqlen_kv = attn_kwargs['max_seqlen_kv']
+            attn_out = flash_attn_varlen(
+                q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
+            )
+        else:
+            q_sdpa = rearrange(q, "b l h c -> b h l c")
+            k_sdpa = rearrange(k, "b l h c -> b h l c")
+            v_sdpa = rearrange(v, "b l h c -> b h l c")
+            attn_out_sdpa = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa)
+            attn_out = rearrange(attn_out_sdpa, "b h l c -> b l h c")
+
+        attn_out = attn_out.flatten(2, 3)
+        img_attn, txt_attn = attn_out[:, : img.shape[1]], attn_out[:, img.shape[1]:]
+
+        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
+        img = img + apply_gate(
+            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
+            gate=img_mod2_gate,
+        )
+
+        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
+        txt = txt + apply_gate(
+            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
+            gate=txt_mod2_gate,
+        )
+
+        return img, txt
+
+
+class WanTimeTextImageEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        time_freq_dim: int,
+        time_proj_dim: int,
+        text_embed_dim: int,
+        image_embed_dim: Optional[int] = None,
+        pos_embed_seq_len: Optional[int] = None,
+    ):
+        super().__init__()
+        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+        self.act_fn = nn.SiLU()
+        self.time_proj = nn.Linear(dim, time_proj_dim)
+        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+    def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
+        timestep = self.timesteps_proj(timestep)
+        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+            timestep = timestep.to(time_embedder_dtype)
+        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+        timestep_proj = self.time_proj(self.act_fn(temb))
+        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+        return temb, timestep_proj, encoder_hidden_states
+
+
+class Transformer3DModel(nn.Module):
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        patch_size: list = [1, 2, 2],
+        in_channels: int = 16,
+        out_channels: int = None,
+        hidden_size: int = 4096,
+        heads_num: int = 32,
+        text_states_dim: int = 4096,
+        mlp_width_ratio: float = 4.0,
+        mm_double_blocks_depth: int = 40,
+        rope_dim_list: List[int] = [16, 56, 56],
+        rope_type: str = 'rope',
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        dit_modulation_type: str = "wanx",
+        attn_backend: str = 'flash_attn',
+        theta: int = 256,
+    ):
+        super().__init__()
+        self.out_channels = out_channels or in_channels
+        self.patch_size = patch_size
+        self.hidden_size = hidden_size
+        self.heads_num = heads_num
+        self.rope_dim_list = rope_dim_list
+        self.dit_modulation_type = dit_modulation_type
+        self.mm_double_blocks_depth = mm_double_blocks_depth
+        self.attn_backend = attn_backend
+        self.rope_type = rope_type
+        self.theta = theta
+
+        factory_kwargs = {"device": device, "dtype": dtype}
+
+        if hidden_size % heads_num != 0:
+            raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
+
+        self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+        self.condition_embedder = WanTimeTextImageEmbedding(
+            dim=hidden_size,
+            time_freq_dim=256,
+            time_proj_dim=hidden_size * 6,
+            text_embed_dim=text_states_dim,
+        )
+
+        self.double_blocks = nn.ModuleList([
+            MMDoubleStreamBlock(
+                self.hidden_size, self.heads_num,
+                mlp_width_ratio=mlp_width_ratio,
+                dit_modulation_type=self.dit_modulation_type,
+                attn_backend=attn_backend,
+                **factory_kwargs,
+            )
+            for _ in range(mm_double_blocks_depth)
+        ])
+
+        self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size), **factory_kwargs)
+
+    def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
+        target_ndim = 3
+        if len(vis_rope_size) != target_ndim:
+            vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
+        head_dim = self.hidden_size // self.heads_num
+        rope_dim_list = self.rope_dim_list
+        if rope_dim_list is None:
+            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+        assert sum(rope_dim_list) == head_dim
+        vis_freqs, txt_freqs = get_nd_rotary_pos_embed(
+            rope_dim_list, vis_rope_size,
+            txt_rope_size=txt_rope_size if self.rope_type == 'mrope' else None,
+            theta=self.theta, use_real=True, theta_rescale_factor=1,
+        )
+        return vis_freqs, txt_freqs
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_mask: torch.Tensor = None,
+        return_dict: bool = True,
+        use_gradient_checkpointing: bool = False,
+        use_gradient_checkpointing_offload: bool = False,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        from ..core.gradient import gradient_checkpoint_forward
+
+        is_multi_item = (len(hidden_states.shape) == 6)
+        num_items = 0
+        if is_multi_item:
+            num_items = hidden_states.shape[1]
+            if num_items > 1:
+                assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
+                hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
+            hidden_states = rearrange(hidden_states, 'b n c t h w -> b c (n t) h w')
+
+        batch_size, _, ot, oh, ow = hidden_states.shape
+        tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]
+
+        if encoder_hidden_states_mask is None:
+            encoder_hidden_states_mask = torch.ones(
+                (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
+                dtype=torch.bool,
+            ).to(encoder_hidden_states.device)
+
+        img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
+        temb, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
+        if vec.shape[-1] > self.hidden_size:
+            vec = vec.unflatten(1, (6, -1))
+
+        txt_seq_len = txt.shape[1]
+        img_seq_len = img.shape[1]
+
+        vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
+            vis_rope_size=(tt, th, tw),
+            txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None,
+        )
+
+        attn_kwargs = {'thw': [tt, th, tw], 'txt_len': txt_seq_len}
+        if self.attn_backend == 'flash_attn':
+            cu_seqlens_q = get_cu_seqlens(encoder_hidden_states_mask, img_seq_len)
+            attn_kwargs.update({
+                'cu_seqlens_q': cu_seqlens_q,
+                'cu_seqlens_kv': cu_seqlens_q,
+                'max_seqlen_q': img_seq_len + txt_seq_len,
+                'max_seqlen_kv': img_seq_len + txt_seq_len,
+            })
+
+        for block in self.double_blocks:
+            img, txt = gradient_checkpoint_forward(
+                block,
+                use_gradient_checkpointing=use_gradient_checkpointing,
+                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+                img=img, txt=txt, vec=vec,
+                vis_freqs_cis=vis_freqs_cis, txt_freqs_cis=txt_freqs_cis,
+                attn_kwargs=attn_kwargs,
+            )
+
+        img_len = img.shape[1]
+        x = torch.cat((img, txt), 1)
+        img = x[:, :img_len, ...]
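+        # the cat/slice round-trip keeps only the image tokens; txt is returned unchanged alongside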
+
+        img = self.proj_out(self.norm_out(img))
+        img = self.unpatchify(img, tt, th, tw)
+
+        if is_multi_item:
+            img = rearrange(img, 'b c (n t) h w -> b n c t h w', n=num_items)
+            if num_items > 1:
+                img = torch.cat([img[:, 1:], img[:, :1]], dim=1)
+
+        return (img, txt)
+
+    def unpatchify(self, x, t, h, w):
+        c = self.out_channels
+        pt, ph, pw = self.patch_size
+        assert t * h * w == x.shape[1]
+        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
+        x = torch.einsum("nthwopqc->nctohpwq", x)
+        return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
diff --git a/diffsynth/models/joyai_image_text_encoder.py b/diffsynth/models/joyai_image_text_encoder.py
new file mode 100644
index 000000000..2035e6d26
--- /dev/null
+++ b/diffsynth/models/joyai_image_text_encoder.py
@@ -0,0 +1,100 @@
+import torch
+from typing import Optional
+
+
+class JoyAIImageTextEncoder(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
+
+        config = Qwen3VLConfig(
+            text_config={
+                "attention_bias": False,
+                "attention_dropout": 0.0,
+                "bos_token_id": 151643,
+                "eos_token_id": 151645,
+                "head_dim": 128,
+                "hidden_act": "silu",
+                "hidden_size": 4096,
+                "initializer_range": 0.02,
+                "intermediate_size": 12288,
+                "max_position_embeddings": 262144,
+                "model_type": "qwen3_vl_text",
+                "num_attention_heads": 32,
+                "num_hidden_layers": 36,
+                "num_key_value_heads": 8,
+                "rms_norm_eps": 1e-6,
+                "rope_scaling": {
+                    "mrope_interleaved": True,
+                    "mrope_section": [24, 20, 20],
+                    "rope_type": "default",
+                },
+                "rope_theta": 5000000,
+                "use_cache": True,
+                "vocab_size": 151936,
+            },
+            vision_config={
+                "deepstack_visual_indexes": [8, 16, 24],
+                "depth": 27,
+                "hidden_act": "gelu_pytorch_tanh",
+                "hidden_size": 1152,
+                "in_channels": 3,
+                "initializer_range": 0.02,
+                "intermediate_size": 4304,
+                "model_type": "qwen3_vl",
+                "num_heads": 16,
+                "num_position_embeddings": 2304,
+                "out_hidden_size": 4096,
+                "patch_size": 16,
+                "spatial_merge_size": 2,
+                "temporal_patch_size": 2,
+            },
+            image_token_id=151655,
+            video_token_id=151656,
+            vision_start_token_id=151652,
+            vision_end_token_id=151653,
+            tie_word_embeddings=False,
+        )
+
+        self.model = Qwen3VLForConditionalGeneration(config)
+        self.config = config
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values=None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            pixel_values_videos=pixel_values_videos,
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
+            second_per_grid_ts=second_per_grid_ts,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        return outputs.hidden_states
+
+
+class JoyAIImageTextEncoderStateDictConverter:
+    def from_civitai(self, state_dict):
+        return state_dict
+
+    def from_diffusers(self, state_dict):
+        return state_dict
diff --git a/diffsynth/utils/state_dict_converters/joyai_image_dit.py b/diffsynth/utils/state_dict_converters/joyai_image_dit.py
new file mode 100644
index 000000000..77921f26b
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/joyai_image_dit.py
@@ -0,0 +1,24 @@
+def JoyAIImageDiTStateDictConverter(state_dict):
+    """Convert JoyAI-Image DiT checkpoint to model state dict.
+
+    Handles:
+    1. "model." prefix stripping from checkpoint
+    2. FeedForward key mapping: diffusers uses "net.0.proj" / "net.2"
+       while DiffSynth uses "proj" / "out_proj"
+    """
+    state_dict_ = {}
+    for name, param in state_dict.items():
+        if name.startswith("model."):
+            name = name[len("model."):]
+
+        # Map diffusers FeedForward keys to DiffSynth keys
+        # Pattern: double_blocks.N.{img_mlp|txt_mlp}.net.0.proj.* -> double_blocks.N.{img_mlp|txt_mlp}.proj.*
+        new_name = name
+        if ".net.0.proj." in name:
+            new_name = name.replace(".net.0.proj.", ".proj.")
+        elif ".net.2." in name:
+            new_name = name.replace(".net.2.", ".out_proj.")
+
+        state_dict_[new_name] = param
+
+    return state_dict_
diff --git a/diffsynth/utils/state_dict_converters/joyai_image_text_encoder.py b/diffsynth/utils/state_dict_converters/joyai_image_text_encoder.py
new file mode 100644
index 000000000..7fbfb9a99
--- /dev/null
+++ b/diffsynth/utils/state_dict_converters/joyai_image_text_encoder.py
@@ -0,0 +1,20 @@
+def JoyAIImageTextEncoderStateDictConverter(state_dict):
+    """Convert HuggingFace Qwen3VL checkpoint keys to DiffSynth wrapper keys.
+
+    Mapping (checkpoint -> wrapper):
+    - lm_head.weight -> model.lm_head.weight
+    - model.language_model.* -> model.model.language_model.*
+    - model.visual.* -> model.model.visual.*
+    """
+    state_dict_ = {}
+    for key in state_dict:
+        if key == "lm_head.weight":
+            new_key = "model.lm_head.weight"
+        elif key.startswith("model.language_model."):
+            new_key = "model.model." + key[len("model."):]
+        elif key.startswith("model.visual."):
+ key[len("model."):] + else: + new_key = key + state_dict_[new_key] = state_dict[key] + return state_dict_ diff --git a/test.py b/test.py new file mode 100644 index 000000000..de2b2b0e0 --- /dev/null +++ b/test.py @@ -0,0 +1,5 @@ +from diffsynth.models.model_loader import ModelPool + +pool = ModelPool() +pool.auto_load_model("models/jd-opensource/JoyAI-Image-Edit/vae/Wan2.1_VAE.pth") +model = pool.fetch_model("wan_video_vae") \ No newline at end of file From 4a193461071884f289e1261ccb04f4fd3da4fa4a Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 10 Apr 2026 18:19:41 +0800 Subject: [PATCH 2/8] joyimage pipeline --- diffsynth/configs/model_configs.py | 5 +- diffsynth/diffusion/flow_match.py | 12 + diffsynth/models/joyai_image_common.py | 135 ++++++++ diffsynth/models/joyai_image_dit.py | 73 +--- diffsynth/models/joyai_image_text_encoder.py | 25 +- diffsynth/pipelines/joyai_image.py | 321 ++++++++++++++++++ .../model_inference/JoyAI-Image-Edit.py | 44 +++ .../JoyAI-Image-Edit.py | 52 +++ .../model_training/full/JoyAI-Image-Edit.sh | 17 + .../full/accelerate_config_zero2offload.yaml | 22 ++ .../model_training/lora/JoyAI-Image-Edit.sh | 19 ++ .../lora/accelerate_config.yaml | 22 ++ examples/joyai_image/model_training/train.py | 170 ++++++++++ .../validate_full/JoyAI-Image-Edit.py | 25 ++ .../validate_lora/JoyAI-Image-Edit.py | 25 ++ 15 files changed, 877 insertions(+), 90 deletions(-) create mode 100644 diffsynth/models/joyai_image_common.py create mode 100644 diffsynth/pipelines/joyai_image.py create mode 100644 examples/joyai_image/model_inference/JoyAI-Image-Edit.py create mode 100644 examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py create mode 100644 examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh create mode 100644 examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml create mode 100644 examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh create mode 100644 examples/joyai_image/model_training/lora/accelerate_config.yaml create mode 100644 examples/joyai_image/model_training/train.py create mode 100644 examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py create mode 100644 examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 40604aeb7..e740849c5 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -902,13 +902,12 @@ "rope_dim_list": [16, 56, 56], "rope_type": "rope", "dit_modulation_type": "wanx", - "attn_backend": "torch_spda", - "theta": 256, + "theta": 10000, }, "state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_dit.JoyAIImageDiTStateDictConverter", }, { - # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/") + # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors") "model_hash": "2d11bf14bba8b4e87477c8199a895403", "model_name": "joyai_image_text_encoder", "model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder", diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index 208fb1e0c..55a1b02f4 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -146,6 +146,18 @@ def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift timesteps[timestep_id] = timestep return sigmas, timesteps + @staticmethod 
+    def set_timesteps_joyai_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
+        sigma_min = 0.0
+        sigma_max = 1.0
+        shift = 4.0 if shift is None else shift
+        num_train_timesteps = 1000
+        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
+        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
+        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+        timesteps = sigmas * num_train_timesteps
+        return sigmas, timesteps
+
     @staticmethod
     def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
         num_train_timesteps = 1000
diff --git a/diffsynth/models/joyai_image_common.py b/diffsynth/models/joyai_image_common.py
new file mode 100644
index 000000000..501a23360
--- /dev/null
+++ b/diffsynth/models/joyai_image_common.py
@@ -0,0 +1,135 @@
+from PIL import Image
+from typing import Tuple
+import math
+import torchvision.transforms.functional as TF
+
+class BucketGroup:
+    """Manages dynamic batch grouping buckets for image inference."""
+
+    def __init__(
+        self,
+        bucket_configs: list[tuple[int, int, int, int, int]],
+        prioritize_frame_matching: bool = True,
+    ):
+        """
+        Initialize bucket group with predefined configurations.
+
+        Args:
+            bucket_configs: List of (batch_size, num_items, num_frames, height, width) tuples
+            prioritize_frame_matching: Unused, kept for API compatibility.
+        """
+        self.bucket_configs = [tuple(b) for b in bucket_configs]
+
+    def find_best_bucket(self, media_shape: tuple[int, int, int, int]) -> tuple[int, int, int, int, int]:
+        """
+        Find the best matching bucket for given media dimensions.
+
+        Args:
+            media_shape: (num_items, num_frames, height, width) of input media
+
+        Returns:
+            Best matching bucket as (batch_size, num_items, num_frames, height, width)
+        """
+        num_items, num_frames, height, width = media_shape
+        target_aspect_ratio = height / width
+
+        if num_frames != 1:
+            raise ValueError(
+                f"Only image inference (num_frames=1) is supported, got num_frames={num_frames}")
+
+        valid_buckets = [
+            b for b in self.bucket_configs
+            if b[1] == num_items and b[2] == 1
+        ]
+        if not valid_buckets:
+            raise ValueError(
+                f"No image buckets found for shape {media_shape}")
+
+        return min(
+            valid_buckets,
+            key=lambda bucket: abs(
+                (bucket[3] / bucket[4]) - target_aspect_ratio)
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"BucketGroup("
+            f"total_buckets={len(self.bucket_configs)}, "
+            f"configs={self.bucket_configs})"
+        )
+
+
+def _generate_hw_buckets(base_height=256, base_width=256, step_width=16, step_height=16, max_ratio=4.0) -> list[tuple[int, int, int, int, int]]:
+    """Generate dimension buckets based on aspect ratios."""
+    buckets = []
+    target_pixels = base_height * base_width
+
+    height = target_pixels // step_width
+    width = step_width
+
+    while height >= step_height:
+        if max(height, width) / min(height, width) <= max_ratio:
+            buckets.append((1, 1, 1, height, width))
+        if height * (width + step_width) <= target_pixels:
+            width += step_width
+        else:
+            height -= step_height
+
+    return buckets
+
+
+def generate_video_image_bucket(basesize=256, min_temporal=65, max_temporal=129, bs_img=8, bs_vid=1, bs_mimg=4, min_items=1, max_items=1):
+    """Generate bucket configs for image inference.
+
+    Returns:
+        List of (batch_size, num_items, num_frames, height, width) tuples.
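+        For example, basesize=512 yields entries such as (bs_img, 1, 1, 512, 512) for the square bucket.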
+ """ + assert basesize in [ + 256, 512, 768, 1024], f"[generate_video_image_bucket] wrong basesize {basesize}" + bucket_list = [] + + base_bucket_list = _generate_hw_buckets() + # image + for _bucket in base_bucket_list: + bucket = list(_bucket) + bucket[0] = bs_img + bucket_list.append(bucket) + # multiple images + for num_items in range(min_items, max_items + 1): + for _bucket in base_bucket_list: + bucket = list(_bucket) + bucket[0] = bs_mimg + bucket[1] = num_items + bucket_list.append(bucket) + # spatial resize + if basesize > 256: + ratio = basesize // 256 + + def resize(bucket, r): + bucket[-2] *= r + bucket[-1] *= r + return bucket + bucket_list = [resize(bucket, ratio) for bucket in bucket_list] + return bucket_list + + +def _dynamic_resize_from_bucket(image: Image, basesize: int = 512): + def resize_center_crop(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image: + w, h = img.size # PIL: (width, height) + bh, bw = target_size + scale = max(bh / h, bw / w) + resize_h, resize_w = math.ceil(h * scale), math.ceil(w * scale) + img = TF.resize(img, (resize_h, resize_w), + interpolation=TF.InterpolationMode.BILINEAR, antialias=True) + img = TF.center_crop(img, target_size) + return img + + bucket_config = generate_video_image_bucket( + basesize=basesize, min_temporal=56, max_temporal=56, bs_img=4, bs_vid=4, bs_mimg=8, min_items=2, max_items=2 + ) + bucket_group = BucketGroup(bucket_config) + img_w, img_h = image.size + bucket = bucket_group.find_best_bucket((1, 1, img_h, img_w)) + target_height, target_width = bucket[-2], bucket[-1] # (height, width) + img_proc = resize_center_crop(image, (target_height, target_width)) + return img_proc diff --git a/diffsynth/models/joyai_image_dit.py b/diffsynth/models/joyai_image_dit.py index 1e73fe8b6..23de2c03a 100644 --- a/diffsynth/models/joyai_image_dit.py +++ b/diffsynth/models/joyai_image_dit.py @@ -1,11 +1,14 @@ import math -from typing import List, Tuple, Optional, Union, Dict +from typing import List, Optional, Tuple, Union, Dict import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange +from ..core.attention import attention_forward +from ..core.gradient import gradient_checkpoint_forward + def get_timestep_embedding( timesteps: torch.Tensor, @@ -147,36 +150,6 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: return hidden_states -try: - import flash_attn_interface - FLASH_ATTN_3_AVAILABLE = True -except ModuleNotFoundError: - FLASH_ATTN_3_AVAILABLE = False - -try: - import flash_attn - FLASH_ATTN_2_AVAILABLE = True -except ModuleNotFoundError: - FLASH_ATTN_2_AVAILABLE = False - - -def flash_attn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv): - q_ = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) - k_ = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) - v_ = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) - if FLASH_ATTN_3_AVAILABLE: - x = flash_attn_interface.flash_attn_varlen_func( - q_, k_, v_, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv - ) - if isinstance(x, tuple): - x = x[0] - else: - x = flash_attn.flash_attn_varlen_func( - q_, k_, v_, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv - ) - batch_size = cu_seqlens_q.shape[0] // 2 - return x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) - def get_cu_seqlens(text_mask, img_len): batch_size = text_mask.shape[0] @@ -395,11 +368,9 @@ def __init__( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, dit_modulation_type: 
Optional[str] = "wanx", - attn_backend: str = 'flash_attn', ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.attn_backend = attn_backend self.dit_modulation_type = dit_modulation_type self.heads_num = heads_num head_dim = hidden_size // heads_num @@ -472,20 +443,11 @@ def forward( k = torch.cat((img_k, txt_k), dim=1) v = torch.cat((img_v, txt_v), dim=1) - if self.attn_backend == 'flash_attn': - cu_seqlens_q = attn_kwargs['cu_seqlens_q'] - cu_seqlens_kv = attn_kwargs['cu_seqlens_kv'] - max_seqlen_q = attn_kwargs['max_seqlen_q'] - max_seqlen_kv = attn_kwargs['max_seqlen_kv'] - attn_out = flash_attn_varlen( - q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv - ) - else: - q_sdpa = rearrange(q, "b l h c -> b h l c") - k_sdpa = rearrange(k, "b l h c -> b h l c") - v_sdpa = rearrange(v, "b l h c -> b h l c") - attn_out_sdpa = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa) - attn_out = rearrange(attn_out_sdpa, "b h l c -> b l h c") + # Use DiffSynth unified attention + attn_out = attention_forward( + q, k, v, + q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d", + ) attn_out = attn_out.flatten(2, 3) img_attn, txt_attn = attn_out[:, : img.shape[1]], attn_out[:, img.shape[1]:] @@ -551,7 +513,6 @@ def __init__( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, dit_modulation_type: str = "wanx", - attn_backend: str = 'flash_attn', theta: int = 256, ): super().__init__() @@ -562,7 +523,6 @@ def __init__( self.rope_dim_list = rope_dim_list self.dit_modulation_type = dit_modulation_type self.mm_double_blocks_depth = mm_double_blocks_depth - self.attn_backend = attn_backend self.rope_type = rope_type self.theta = theta @@ -585,7 +545,6 @@ def __init__( self.hidden_size, self.heads_num, mlp_width_ratio=mlp_width_ratio, dit_modulation_type=self.dit_modulation_type, - attn_backend=attn_backend, **factory_kwargs, ) for _ in range(mm_double_blocks_depth) @@ -620,8 +579,6 @@ def forward( use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - from ..core.gradient import gradient_checkpoint_forward - is_multi_item = (len(hidden_states.shape) == 6) num_items = 0 if is_multi_item: @@ -653,16 +610,6 @@ def forward( txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None, ) - attn_kwargs = {'thw': [tt, th, tw], 'txt_len': txt_seq_len} - if self.attn_backend == 'flash_attn': - cu_seqlens_q = get_cu_seqlens(encoder_hidden_states_mask, img_seq_len) - attn_kwargs.update({ - 'cu_seqlens_q': cu_seqlens_q, - 'cu_seqlens_kv': cu_seqlens_q, - 'max_seqlen_q': img_seq_len + txt_seq_len, - 'max_seqlen_kv': img_seq_len + txt_seq_len, - }) - for block in self.double_blocks: img, txt = gradient_checkpoint_forward( block, @@ -670,7 +617,7 @@ def forward( use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, img=img, txt=txt, vec=vec, vis_freqs_cis=vis_freqs_cis, txt_freqs_cis=txt_freqs_cis, - attn_kwargs=attn_kwargs, + attn_kwargs={}, ) img_len = img.shape[1] diff --git a/diffsynth/models/joyai_image_text_encoder.py b/diffsynth/models/joyai_image_text_encoder.py index 2035e6d26..71e606340 100644 --- a/diffsynth/models/joyai_image_text_encoder.py +++ b/diffsynth/models/joyai_image_text_encoder.py @@ -63,38 +63,15 @@ def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values=None, - 
-        inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.Tensor] = None,
-        pixel_values_videos: Optional[torch.FloatTensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
-        rope_deltas: Optional[torch.LongTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
-            pixel_values_videos=pixel_values_videos,
             image_grid_thw=image_grid_thw,
-            video_grid_thw=video_grid_thw,
-            second_per_grid_ts=second_per_grid_ts,
-            position_ids=position_ids,
             attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            cache_position=cache_position,
             **kwargs,
         )
-        return outputs.hidden_states
-
-
-class JoyAIImageTextEncoderStateDictConverter:
-    def from_civitai(self, state_dict):
-        return state_dict
-
-    def from_diffusers(self, state_dict):
-        return state_dict
+        return outputs
diff --git a/diffsynth/pipelines/joyai_image.py b/diffsynth/pipelines/joyai_image.py
new file mode 100644
index 000000000..13e6999ca
--- /dev/null
+++ b/diffsynth/pipelines/joyai_image.py
@@ -0,0 +1,321 @@
+import torch
+from PIL import Image
+from typing import List, Optional, Union
+from tqdm import tqdm
+from einops import rearrange
+
+from ..core.device.npu_compatible_device import get_device_type
+from ..diffusion import FlowMatchScheduler
+from ..core import ModelConfig
+from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
+from ..models.joyai_image_dit import Transformer3DModel
+from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder
+from ..models.joyai_image_common import _dynamic_resize_from_bucket
+from ..models.wan_video_vae import WanVideoVAE
+
+# ============================================================
+# JoyAIImagePipeline
+# ============================================================
+class JoyAIImagePipeline(BasePipeline):
+    """
+    Pipeline for JoyAI-Image model.
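+    Pairs the JoyAI DiT with a Qwen3-VL based text encoder and the Wan2.1 video VAE (used here as an image VAE).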
+ """ + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("Wan") + self.text_encoder: JoyAIImageTextEncoder = None + self.dit: Transformer3DModel = None + self.vae: WanVideoVAE = None + self.processor = None + self.in_iteration_models = ("dit",) + + self.units = [ + JoyAIImageUnit_ShapeChecker(), + JoyAIImageUnit_EditImageEmbedder(), + JoyAIImageUnit_PromptEmbedder(), + JoyAIImageUnit_NoiseInitializer(), + JoyAIImageUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_joyai_image + self.compilable_models = ["dit"] + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + processor_config: ModelConfig = None, + vram_limit: float = None, + ): + pipe = JoyAIImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + pipe.text_encoder = model_pool.fetch_model("joyai_image_text_encoder") + pipe.dit = model_pool.fetch_model("joyai_image_dit") + pipe.vae = model_pool.fetch_model("wan_video_vae") + + if processor_config is not None: + processor_config.download_if_necessary() + from transformers import AutoProcessor + pipe.processor = AutoProcessor.from_pretrained(processor_config.path) + + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + # ============================================================ + # __call__ — Orchestration only + # ============================================================ + @torch.no_grad() + def __call__( + self, + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 5.0, + input_image: Image.Image = None, + edit_images: Union[Image.Image, List[Image.Image]] = None, + edit_image_basesize: int = 1024, + denoising_strength: float = 1.0, + height: int = 1024, + width: int = 1024, + seed: int = None, + max_sequence_length: int = 4096, + num_inference_steps: int = 30, + tiled: Optional[bool] = False, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + shift: Optional[float] = 4.0, + progress_bar_cmd=tqdm, + ): + # 1. Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift) + + # 2. Three dictionaries + inputs_posi = {"prompt": prompt, "positive": True} + inputs_nega = {"negative_prompt": negative_prompt, "positive": True} + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, + "edit_images": edit_images, "edit_image_basesize": edit_image_basesize, + "denoising_strength": denoising_strength, + "height": height, + "width": width, + "seed": seed, + "max_sequence_length": max_sequence_length, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + } + + # 3. Unit chain + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner( + unit, self, inputs_shared, inputs_posi, inputs_nega + ) + + # 4. 
+        self.load_models_to_device(self.in_iteration_models)
+        models = {name: getattr(self, name) for name in self.in_iteration_models}
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+            noise_pred = self.cfg_guided_model_fn(
+                self.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+
+        # 5. VAE decode
+        self.load_models_to_device(['vae'])
+        latents = rearrange(inputs_shared["latents"], "b n c f h w -> (b n) c f h w")
+        image = self.vae.decode(latents, device=self.device)[0]
+        image = self.vae_output_to_image(image, pattern="C 1 H W")
+        self.load_models_to_device([])
+        return image
+
+
+# ============================================================
+# PipelineUnits
+# ============================================================
+class JoyAIImageUnit_ShapeChecker(PipelineUnit):
+    """Validates height/width divisible by 16."""
+    def __init__(self):
+        super().__init__(
+            input_params=("height", "width"),
+            output_params=("height", "width"),
+        )
+
+    def process(self, pipe: "JoyAIImagePipeline", height, width):
+        height, width = pipe.check_resize_height_width(height, width)
+        return {"height": height, "width": width}
+
+
+class JoyAIImageUnit_PromptEmbedder(PipelineUnit):
+    prompt_template_encode = {
+        'image':
+            "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+        'multiple_images':
+            "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n",
+        'video':
+            "<|im_start|>system\n \\nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+    }
+    prompt_template_encode_start_idx = {'image': 34, 'multiple_images': 34, 'video': 91}
+
+    def __init__(self):
+        super().__init__(
+            seperate_cfg=True,
+            input_params_posi={"prompt": "prompt", "positive": "positive"},
+            input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
+            input_params=("edit_images", "max_sequence_length"),
+            output_params=("prompt_embeds", "prompt_embeds_mask"),
+            onload_model_names=("joyai_image_text_encoder",),
+        )
+
+    def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_images, max_sequence_length):
+        pipe.load_models_to_device(self.onload_model_names)
+
+        has_image = edit_images is not None
+
+        if has_image:
+            prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_images, max_sequence_length)
+        else:
+            prompt_embeds, prompt_embeds_mask = self._encode_text_only(pipe, prompt, max_sequence_length)
+
+        return {"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}
+
+    def _encode_with_image(self, pipe, prompt, edit_images, max_sequence_length):
+        template = self.prompt_template_encode['multiple_images']
+        drop_idx = self.prompt_template_encode_start_idx['multiple_images']
+
+        image_tokens = '\n'
+        prompt = f"<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n"
+        prompt = prompt.replace('\n', '<|vision_start|><|image_pad|><|vision_end|>')
+        prompt = template.format(prompt)
+        inputs = pipe.processor(text=[prompt], images=edit_images, padding=True, return_tensors="pt").to(pipe.device)
+        encoder_hidden_states = pipe.text_encoder(**inputs, output_hidden_states=True)
+        last_hidden_states = encoder_hidden_states.hidden_states[-1]
+
+        prompt_embeds = last_hidden_states[:, drop_idx:]
+        prompt_embeds_mask = inputs['attention_mask'][:, drop_idx:]
+
+        if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
+            prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
+            prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]
+
+        return prompt_embeds, prompt_embeds_mask
+
+    def _encode_text_only(self, pipe, prompt, max_sequence_length):
+        # TODO:
+        template = "system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:\n{}assistant\n"
+        drop_idx = 34
+
+        txt = template.format(prompt)
+        txt_tokens = pipe.processor.tokenizer(
+            [txt], max_length=max_sequence_length + drop_idx,
+            padding=True, truncation=True, return_tensors="pt"
+        ).to(pipe.device)
+
+        hidden_states = pipe.text_encoder(
+            input_ids=txt_tokens.input_ids,
+            attention_mask=txt_tokens.attention_mask,
+            output_hidden_states=True,
+        )[-1]
+
+        bool_mask = txt_tokens.attention_mask.bool()
+        valid_lengths = bool_mask.sum(dim=1)
+        selected = hidden_states[bool_mask]
+        split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0)
+        split_hidden = [e[drop_idx:] for e in split_hidden]
+        attn_masks = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden]
+
+        max_seq_len = min(max_sequence_length, max(u.size(0) for u in split_hidden))
+        prompt_embeds = torch.stack([
+            torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden
+        ])
+        encoder_attention_mask = torch.stack([
+            torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_masks
+        ])
+
+        return prompt_embeds, encoder_attention_mask
+
+
+class JoyAIImageUnit_EditImageEmbedder(PipelineUnit):
+    """
+    """
+    def __init__(self):
+        super().__init__(
+            input_params=("edit_images", "tiled", "tile_size", "tile_stride", "edit_image_basesize"),
+            output_params=("ref_latents", "num_items", "is_multi_item"),
+            onload_model_names=("wan_video_vae",),
+        )
+
+    def process(self, pipe: "JoyAIImagePipeline", edit_images, tiled, tile_size, tile_stride, edit_image_basesize=1024):
+        pipe.load_models_to_device(self.onload_model_names)
+        if isinstance(edit_images, Image.Image):
+            edit_images = [edit_images]
+        assert len(edit_images) == 1, "Currently only supports single edit image for reference. Multiple edit images will be supported in the future."
+        edit_images = [_dynamic_resize_from_bucket(img, basesize=edit_image_basesize) for img in edit_images]
+
+        images = [pipe.preprocess_image(img).transpose(0, 1) for img in edit_images]
+        latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(edit_images))).to(device=pipe.device, dtype=pipe.torch_dtype)
+
+        return {"ref_latents": ref_vae, "edit_images": edit_images}
+
+
+class JoyAIImageUnit_NoiseInitializer(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("seed", "height", "width", "rand_device"),
+            output_params=("noise",),
+        )
+    def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device):
+        # e.g. with the Wan2.1 VAE (upsampling_factor 8, z_dim 16), a 1024x1024 target
+        # yields a noise tensor of shape (1, 1, 16, 1, 128, 128).
+        latent_h = height // pipe.vae.upsampling_factor
+        latent_w = width // pipe.vae.upsampling_factor
+        shape = (1, 1, pipe.vae.z_dim, 1, latent_h, latent_w)
+        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+        return {"noise": noise}
+
+
+class JoyAIImageUnit_InputImageEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
+            output_params=("latents", "input_latents"),
+            onload_model_names=("vae",),
+        )
+
+    def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
+        if input_image is None:
+            return {"latents": noise}
+        raise NotImplementedError("Input image to latents is not implemented yet. Currently only supports noise initialization when input_image is None.")
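model_fn_joyai_image below stitches the conditioning together: the reference latents from the edit image are concatenated with the noisy latents along the item axis (dim=1 of the b n c f h w layout), the DiT processes both jointly, and only the trailing items, which correspond to the noisy target, are returned. A toy sketch of that bookkeeping with random tensors (shapes chosen to match a 1024x1024 target through the Wan2.1 VAE):

import torch

b, c, f, h, w = 1, 16, 1, 128, 128
ref_latents = torch.randn(b, 1, c, f, h, w)      # encoded edit image
latents = torch.randn(b, 1, c, f, h, w)          # noisy target latents
img = torch.cat([ref_latents, latents], dim=1)   # [b, 2, c, f, h, w]
# ... the DiT forward would run here, preserving the item layout ...
out = img[:, -latents.size(1):]                  # slice the target items back out
assert out.shape == latents.shape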
+
+# ============================================================
+# model_fn — DiT forward call
+# ============================================================
+def model_fn_joyai_image(
+    dit,
+    latents,
+    timestep,
+    prompt_embeds,
+    prompt_embeds_mask,
+    ref_latents=None,
+    use_gradient_checkpointing=False,
+    use_gradient_checkpointing_offload=False,
+    **kwargs,
+):
+
+    # Prepend the reference latents (encoded edit images) along the item axis.
+    img = torch.cat([ref_latents, latents], dim=1) if ref_latents is not None else latents
+
+    img, _ = dit(
+        hidden_states=img,
+        timestep=timestep,
+        encoder_hidden_states=prompt_embeds,
+        encoder_hidden_states_mask=prompt_embeds_mask,
+        use_gradient_checkpointing=use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+    )
+
+    # Keep only the trailing items, i.e. the denoised target latents.
+    img = img[:, -latents.size(1):]
+    return img
diff --git a/examples/joyai_image/model_inference/JoyAI-Image-Edit.py b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py
new file mode 100644
index 000000000..e06f126b0
--- /dev/null
+++ b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py
@@ -0,0 +1,44 @@
+from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
+import torch
+from PIL import Image
+
+pipe = JoyAIImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(
+            model_id="jd-opensource/JoyAI-Image-Edit",
+            origin_file_pattern="transformer/transformer.pth",
+        ),
+        ModelConfig(
+            model_id="jd-opensource/JoyAI-Image-Edit",
+            origin_file_pattern="JoyAI-Image-Und/model*.safetensors",
+        ),
+        ModelConfig(
+            model_id="jd-opensource/JoyAI-Image-Edit",
+            origin_file_pattern="vae/Wan2.1_VAE.pth",
+        ),
+    ],
+    processor_config=ModelConfig(
+        model_id="jd-opensource/JoyAI-Image-Edit",
+        origin_file_pattern="JoyAI-Image-Und/",
+    ),
+)
+pipe.eval()
+# Image editing
+prompt = "Turn the plate blue"
+# Replace with your input image path
+input_image = Image.open("/mnt/nas1/zhanghong/project26/main_project/opencode/packages/joyai-image/JoyAI-Image/test_images/test_1.jpg").convert("RGB")
+
+output = pipe(
+    prompt=prompt,
+    edit_images=[input_image],
+    edit_image_basesize=1024,
+    height=1024,
+    width=1024,
+    seed=1,
+    num_inference_steps=30,
+    cfg_scale=5.0,
+)
+
+output.save("output_joyai_edit.png")
diff --git a/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py b/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py
new file mode 100644
index 000000000..98ac5f6a2
--- /dev/null
+++ b/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py
@@ -0,0 +1,52 @@
+from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
+import torch
+from PIL import Image
+
+pipe = JoyAIImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(
+            model_id="jd-opensource/JoyAI-Image-Edit",
+            origin_file_pattern="transformer/transformer.pth",
+        ),
+        ModelConfig(
+            model_id="jd-opensource/JoyAI-Image-Edit",
+            origin_file_pattern="JoyAI-Image-Und/model*.safetensors",
+        ),
+        ModelConfig(
+            model_id="jd-opensource/JoyAI-Image-Edit",
+            origin_file_pattern="vae/Wan2.1_VAE.pth",
+        ),
+    ],
+    processor_config=ModelConfig(
+        model_id="jd-opensource/JoyAI-Image-Edit",
+        origin_file_pattern="JoyAI-Image-Und/",
+    ),
+    vram_limit=0.8,
+)
+
+prompt = "Turn the plate blue"
+input_image = None # Image.open("input.jpg").convert("RGB")
+
+if input_image is not None:
+    output = pipe(
+        prompt=prompt,
+        input_image=input_image,
+        denoising_strength=1.0,
+ seed=42, + num_inference_steps=50, + cfg_scale=5.0, + ) +else: + output = pipe( + prompt=prompt, + seed=42, + num_inference_steps=50, + cfg_scale=5.0, + height=1024, + width=1024, + ) + +output.save("output_joyai_edit_low_vram.png") +print("Saved output_joyai_edit_low_vram.png") diff --git a/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh b/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh new file mode 100644 index 000000000..111649f43 --- /dev/null +++ b/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh @@ -0,0 +1,17 @@ +modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset + +accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml examples/joyai_image/model_training/train.py \ + --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \ + --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \ + --data_file_keys "image,input_image" \ + --extra_inputs "input_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/JoyAI-Image-Edit_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml b/examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml new file mode 100644 index 000000000..8a75f3d91 --- /dev/null +++ b/examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: 'cpu' + offload_param_device: 'cpu' + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh new file mode 100644 index 000000000..fdb164344 --- /dev/null +++ b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh @@ -0,0 +1,19 @@ +modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset + +accelerate launch --config_file examples/joyai_image/model_training/lora/accelerate_config.yaml examples/joyai_image/model_training/train.py \ + --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \ + --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \ + --data_file_keys "image,input_image" \ + --extra_inputs "input_image" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths 
"jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/JoyAI-Image-Edit_lora" \ + --trainable_models "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/joyai_image/model_training/lora/accelerate_config.yaml b/examples/joyai_image/model_training/lora/accelerate_config.yaml new file mode 100644 index 000000000..83280f73f --- /dev/null +++ b/examples/joyai_image/model_training/lora/accelerate_config.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/joyai_image/model_training/train.py b/examples/joyai_image/model_training/train.py new file mode 100644 index 000000000..4beedbb2e --- /dev/null +++ b/examples/joyai_image/model_training/train.py @@ -0,0 +1,170 @@ +import torch, os, argparse, accelerate +from diffsynth.core import UnifiedDataset +from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig +from diffsynth.diffusion import * +from diffsynth.core.data.operators import * +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class JoyAIImageTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + processor_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + device="cpu", + task="sft", + ): + super().__init__() + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) + processor_config = ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/") if processor_path is None else ModelConfig(processor_path) + self.pipe = JoyAIImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, processor_config=processor_config) + self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) + + # Training mode + self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, + preset_lora_path, preset_lora_model, + task=task, + ) + + # Other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.fp8_models = fp8_models + self.task = task + self.task_to_loss = { + "sft:data_process": lambda pipe, *args: args, + "direct_distill:data_process": lambda pipe, *args: args, + "sft": 
lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + } + + def get_pipeline_inputs(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + "edit_image_auto_resize": True, + } + # Handle input image for image editing + if "input_image" in data and data["input_image"] is not None: + if isinstance(data["input_image"], list): + inputs_shared.update({ + "input_image": data["input_image"], + "height": data["input_image"][0].size[1], + "width": data["input_image"][0].size[0], + }) + else: + inputs_shared.update({ + "input_image": data["input_image"], + "height": data["input_image"].size[1], + "width": data["input_image"].size[0], + }) + else: + inputs_shared.update({ + "input_image": None, + "height": 1024, + "width": 1024, + }) + inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.get_pipeline_inputs(data) + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + for unit in self.pipe.units: + inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) + loss = self.task_to_loss[self.task](self.pipe, *inputs) + return loss + + +def joyai_image_parser(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = add_general_config(parser) + parser = add_image_size_config(parser) + parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor.") + parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") + return parser + + +if __name__ == "__main__": + parser = joyai_image_parser() + args = parser.parse_args() + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ), + special_operator_map={ + "input_image": RouteByType(operator_map=[ + (str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16)), + (list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> 
ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16))), + ]), + "image": RouteByType(operator_map=[ + (str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16)), + (list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16))), + ]), + } + ) + model = JoyAIImageTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + processor_path=args.processor_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + preset_lora_path=args.preset_lora_path, + preset_lora_model=args.preset_lora_model, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + fp8_models=args.fp8_models, + offload_models=args.offload_models, + task=args.task, + device="cpu" if args.initialize_model_on_cpu else accelerator.device, + ) + model_logger = ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + ) + launcher_map = { + "sft:data_process": launch_data_process_task, + "direct_distill:data_process": launch_data_process_task, + "sft": launch_training_task, + "sft:train": launch_training_task, + "direct_distill": launch_training_task, + "direct_distill:train": launch_training_task, + } + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) diff --git a/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py b/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py new file mode 100644 index 000000000..d9f88f7fa --- /dev/null +++ b/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py @@ -0,0 +1,25 @@ +import torch +from PIL import Image +from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig +from diffsynth import load_state_dict + +pipe = JoyAIImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"), + ], + processor_config=ModelConfig( + model_id="jd-opensource/JoyAI-Image-Edit", + origin_file_pattern="JoyAI-Image-Und/", + ), +) +state_dict = load_state_dict("models/train/JoyAI-Image-Edit_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) + +prompt = "Turn the plate blue" +image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)) +image = pipe(prompt, input_image=image, seed=0, num_inference_steps=50, cfg_scale=5.0) +image.save(f"image.jpg") diff --git a/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py new file mode 100644 index 000000000..2dee12168 --- /dev/null +++ b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py @@ -0,0 +1,25 @@ +import torch +from PIL import Image +from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig +from diffsynth 
import load_state_dict + +pipe = JoyAIImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"), + ], + processor_config=ModelConfig( + model_id="jd-opensource/JoyAI-Image-Edit", + origin_file_pattern="JoyAI-Image-Und/", + ), +) +state_dict = load_state_dict("models/train/JoyAI-Image-Edit_lora/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict, strict=False) + +prompt = "Turn the plate blue" +image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)) +image = pipe(prompt, input_image=image, seed=0, num_inference_steps=50, cfg_scale=5.0) +image.save(f"image.jpg") From 3c27fd32e7ad42f37fe9563159f2c8db591077f6 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Sun, 12 Apr 2026 20:26:12 +0800 Subject: [PATCH 3/8] train --- diffsynth/pipelines/joyai_image.py | 44 ++++++++++++++----- .../model_inference/JoyAI-Image-Edit.py | 3 -- .../full/JoyAI-Image-Edit-test.sh | 30 +++++++++++++ .../full/accelerate_config_single_gpu.yaml | 10 +++++ .../lora/JoyAI-Image-Edit-full-test.sh | 34 ++++++++++++++ .../lora/JoyAI-Image-Edit-test.sh | 35 +++++++++++++++ .../lora/accelerate_config_single_gpu.yaml | 22 ++++++++++ examples/joyai_image/model_training/train.py | 14 ++++-- .../validate_lora/JoyAI-Image-Edit.py | 8 ++-- test.py | 5 --- 10 files changed, 178 insertions(+), 27 deletions(-) create mode 100644 examples/joyai_image/model_training/full/JoyAI-Image-Edit-test.sh create mode 100644 examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml create mode 100644 examples/joyai_image/model_training/lora/JoyAI-Image-Edit-full-test.sh create mode 100644 examples/joyai_image/model_training/lora/JoyAI-Image-Edit-test.sh create mode 100644 examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml delete mode 100644 test.py diff --git a/diffsynth/pipelines/joyai_image.py b/diffsynth/pipelines/joyai_image.py index 13e6999ca..d1e45cbf6 100644 --- a/diffsynth/pipelines/joyai_image.py +++ b/diffsynth/pipelines/joyai_image.py @@ -220,13 +220,15 @@ def _encode_text_only(self, pipe, prompt, max_sequence_length): input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True, - )[-1] + ).hidden_states[-1] bool_mask = txt_tokens.attention_mask.bool() - valid_lengths = bool_mask.sum(dim=1) - selected = hidden_states[bool_mask] - split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0) - split_hidden = [e[drop_idx:] for e in split_hidden] + split_hidden = [] + for i in range(hidden_states.size(0)): + mask = bool_mask[i] + valid_states = hidden_states[i][mask] + valid_states = valid_states[drop_idx:] + split_hidden.append(valid_states) attn_masks = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden] max_seq_len = min(max_sequence_length, max(u.size(0) for u in split_hidden)) @@ -242,25 +244,29 @@ def _encode_text_only(self, pipe, prompt, max_sequence_length): class JoyAIImageUnit_EditImageEmbedder(PipelineUnit): """ + Encodes edit images into reference latents using VAE. 
""" def __init__(self): super().__init__( - input_params=("edit_images", "tiled", "tile_size", "tile_stride", "edit_image_basesize"), + input_params=("edit_images", "tiled", "tile_size", "tile_stride", "edit_image_basesize", "height", "width"), output_params=("ref_latents", "num_items", "is_multi_item"), onload_model_names=("wan_video_vae",), ) - def process(self, pipe: "JoyAIImagePipeline", edit_images, tiled, tile_size, tile_stride, edit_image_basesize=1024): + def process(self, pipe: "JoyAIImagePipeline", edit_images, tiled, tile_size, tile_stride, edit_image_basesize, height, width): pipe.load_models_to_device(self.onload_model_names) + if edit_images is None: + return {} if isinstance(edit_images, Image.Image): edit_images = [edit_images] assert len(edit_images) == 1, "Currently only supports single edit image for reference. Multiple edit images will be supported in the future." - edit_images = [_dynamic_resize_from_bucket(img, basesize=edit_image_basesize) for img in edit_images] + # Resize edit images to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents + edit_images = [img.resize((width, height), Image.LANCZOS) for img in edit_images] images = [pipe.preprocess_image(img).transpose(0, 1) for img in edit_images] latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(edit_images))).to(device=pipe.device, dtype=pipe.torch_dtype) - + return {"ref_latents": ref_vae, "edit_images": edit_images} @@ -281,13 +287,29 @@ def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device): class JoyAIImageUnit_InputImageEmbedder(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + input_params=("input_image", "image", "noise", "tiled", "tile_size", "tile_stride"), output_params=("latents", "input_latents"), onload_model_names=("vae",), ) - def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride): + def process(self, pipe: JoyAIImagePipeline, input_image, image, noise, tiled, tile_size, tile_stride): + pipe.load_models_to_device(self.onload_model_names) if input_image is None: + # Training mode: VAE-encode ground truth image to get input_latents + if image is not None: + if isinstance(image, list): + image = image[0] + # Derive target image size from noise shape to ensure latent compatibility + # noise shape: [b, n, c, f, h, w] + latent_h = noise.shape[-2] + latent_w = noise.shape[-1] + img_h = latent_h * pipe.vae.upsampling_factor + img_w = latent_w * pipe.vae.upsampling_factor + image = image.resize((img_w, img_h), Image.LANCZOS) + img_tensor = [pipe.preprocess_image(image).transpose(0, 1)] + input_latents = pipe.vae.encode(img_tensor, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + input_latents = rearrange(input_latents, "(b n) c f h w -> b n c f h w", n=1) + return {"latents": noise, "input_latents": input_latents} return {"latents": noise} raise NotImplementedError("Input image to latents is not implemented yet. 
Currently only supports noise initialization when input_image is None.") diff --git a/examples/joyai_image/model_inference/JoyAI-Image-Edit.py b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py index e06f126b0..6285fd220 100644 --- a/examples/joyai_image/model_inference/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py @@ -24,10 +24,7 @@ origin_file_pattern="JoyAI-Image-Und/", ), ) -pipe.eval() -# Image editing prompt = "Turn the plate blue" -# Replace with your input image path input_image = Image.open("/mnt/nas1/zhanghong/project26/main_project/opencode/packages/joyai-image/JoyAI-Image/test_images/test_1.jpg").convert("RGB") output = pipe( diff --git a/examples/joyai_image/model_training/full/JoyAI-Image-Edit-test.sh b/examples/joyai_image/model_training/full/JoyAI-Image-Edit-test.sh new file mode 100644 index 000000000..5ddbbf315 --- /dev/null +++ b/examples/joyai_image/model_training/full/JoyAI-Image-Edit-test.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Minimal training smoke test for JoyAI-Image-Edit +# Runs only 0.01 epochs (1-2 steps) to verify training pipeline works + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +DIFFSYNTH_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" + +cd "$DIFFSYNTH_ROOT" + +export CONDA_DEFAULT_ENV=joyai-image-diffsynth +export PATH="/root/miniconda3/envs/joyai-image-diffsynth/bin:$PATH" +export CUDA_VISIBLE_DEVICES=0 +export PYTHONPATH="$DIFFSYNTH_ROOT:$PYTHONPATH" + +accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml \ + examples/joyai_image/model_training/train.py \ + --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \ + --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \ + --data_file_keys "image,input_image" \ + --extra_inputs "input_image" \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/JoyAI-Image-Edit_full_test" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --find_unused_parameters diff --git a/examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml b/examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml new file mode 100644 index 000000000..3947d488c --- /dev/null +++ b/examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml @@ -0,0 +1,10 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: 'NO' +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +use_cpu: false diff --git a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit-full-test.sh b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit-full-test.sh new file mode 100644 index 000000000..aa6615cad --- /dev/null +++ b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit-full-test.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# LoRA training for JoyAI-Image-Edit — full validation run (5 epochs, 1024x1024) + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +DIFFSYNTH_ROOT="$(cd "$SCRIPT_DIR/../../../.." 
&& pwd)" + +cd "$DIFFSYNTH_ROOT" + +export CONDA_DEFAULT_ENV=joyai-image-diffsynth +export PATH="/root/miniconda3/envs/joyai-image-diffsynth/bin:$PATH" +export CUDA_VISIBLE_DEVICES=0 +export PYTHONPATH="$DIFFSYNTH_ROOT:$PYTHONPATH" + +accelerate launch --config_file examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml \ + examples/joyai_image/model_training/train.py \ + --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \ + --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \ + --data_file_keys "image,input_image" \ + --extra_inputs "input_image" \ + --max_pixels 1048576 \ + --height 1024 \ + --width 1024 \ + --dataset_repeat 10 \ + --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/JoyAI-Image-Edit_lora" \ + --trainable_models "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --initialize_model_on_cpu \ + --find_unused_parameters diff --git a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit-test.sh b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit-test.sh new file mode 100644 index 000000000..eef77d345 --- /dev/null +++ b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit-test.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Minimal LoRA training smoke test for JoyAI-Image-Edit +# Runs only 1 epoch with minimal data to verify training pipeline works + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +DIFFSYNTH_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" + +cd "$DIFFSYNTH_ROOT" + +export CONDA_DEFAULT_ENV=joyai-image-diffsynth +export PATH="/root/miniconda3/envs/joyai-image-diffsynth/bin:$PATH" +export CUDA_VISIBLE_DEVICES=0 +export PYTHONPATH="$DIFFSYNTH_ROOT:$PYTHONPATH" + +accelerate launch --config_file examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml \ + examples/joyai_image/model_training/train.py \ + --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \ + --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \ + --data_file_keys "image,input_image" \ + --extra_inputs "input_image" \ + --max_pixels 262144 \ + --height 256 \ + --width 256 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/JoyAI-Image-Edit_lora_test" \ + --trainable_models "dit" \ + --lora_target_modules "to_q,to_k,to_v,to_out" \ + --lora_rank 4 \ + --use_gradient_checkpointing \ + --initialize_model_on_cpu \ + --find_unused_parameters diff --git a/examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml b/examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml new file mode 100644 index 000000000..a5bae8c59 --- /dev/null +++ b/examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: cpu + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/joyai_image/model_training/train.py b/examples/joyai_image/model_training/train.py index 4beedbb2e..5b4f88a7d 100644 --- a/examples/joyai_image/model_training/train.py +++ b/examples/joyai_image/model_training/train.py @@ -63,28 +63,34 @@ def get_pipeline_inputs(self, data): "use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, "edit_image_auto_resize": True, + "edit_image_basesize": 1024, } - # Handle input image for image editing + # Handle input image for image editing — maps to edit_images in the pipeline if "input_image" in data and data["input_image"] is not None: if isinstance(data["input_image"], list): inputs_shared.update({ - "input_image": data["input_image"], + "edit_images": data["input_image"], "height": data["input_image"][0].size[1], "width": data["input_image"][0].size[0], }) else: inputs_shared.update({ - "input_image": data["input_image"], + "edit_images": data["input_image"], "height": data["input_image"].size[1], "width": data["input_image"].size[0], }) else: inputs_shared.update({ - "input_image": None, + "edit_images": None, "height": 1024, "width": 1024, }) inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + # input_image is None for training (latents initialized from noise) + inputs_shared["input_image"] = None + # Ground truth image for computing input_latents in loss computation + if "image" in data and data["image"] is not None: + inputs_shared["image"] = data["image"] return inputs_shared, inputs_posi, inputs_nega def forward(self, data, inputs=None): diff --git a/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py index 2dee12168..e1461770b 100644 --- a/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py @@ -16,10 +16,10 @@ origin_file_pattern="JoyAI-Image-Und/", ), ) -state_dict = load_state_dict("models/train/JoyAI-Image-Edit_lora/epoch-1.safetensors") +state_dict = load_state_dict("models/train/JoyAI-Image-Edit_lora/epoch-4.safetensors") pipe.dit.load_state_dict(state_dict, strict=False) prompt = "Turn the plate blue" -image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)) -image = pipe(prompt, input_image=image, seed=0, num_inference_steps=50, 
cfg_scale=5.0) -image.save(f"image.jpg") +image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").resize((1024, 1024)) +image = pipe(prompt, edit_images=image, seed=0, num_inference_steps=50, cfg_scale=5.0) +image.save(f"models/train/JoyAI-Image-Edit_lora/val_epoch-4.jpg") diff --git a/test.py b/test.py deleted file mode 100644 index de2b2b0e0..000000000 --- a/test.py +++ /dev/null @@ -1,5 +0,0 @@ -from diffsynth.models.model_loader import ModelPool - -pool = ModelPool() -pool.auto_load_model("models/jd-opensource/JoyAI-Image-Edit/vae/Wan2.1_VAE.pth") -model = pool.fetch_model("wan_video_vae") \ No newline at end of file From d51fabb2847b71efa40a39c525bbac6b361dd449 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 14 Apr 2026 11:17:24 +0800 Subject: [PATCH 4/8] ready --- .../configs/vram_management_module_maps.py | 16 ++++ diffsynth/models/joyai_image_dit.py | 2 +- diffsynth/pipelines/joyai_image.py | 75 ++++--------------- .../model_inference/JoyAI-Image-Edit.py | 18 ++++- .../JoyAI-Image-Edit.py | 58 ++++++++------ .../full/JoyAI-Image-Edit-test.sh | 30 -------- .../model_training/full/JoyAI-Image-Edit.sh | 58 ++++++++++---- .../full/accelerate_config_single_gpu.yaml | 10 --- .../full/accelerate_config_zero2offload.yaml | 22 ------ .../accelerate_config_zero3.yaml} | 7 +- .../lora/JoyAI-Image-Edit-full-test.sh | 34 --------- .../lora/JoyAI-Image-Edit-test.sh | 35 --------- .../model_training/lora/JoyAI-Image-Edit.sh | 48 +++++++----- .../lora/accelerate_config.yaml | 22 ------ examples/joyai_image/model_training/train.py | 50 ++----------- .../validate_full/JoyAI-Image-Edit.py | 20 ++++- .../validate_lora/JoyAI-Image-Edit.py | 24 ++++-- 17 files changed, 198 insertions(+), 331 deletions(-) delete mode 100644 examples/joyai_image/model_training/full/JoyAI-Image-Edit-test.sh delete mode 100644 examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml delete mode 100644 examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml rename examples/joyai_image/model_training/{lora/accelerate_config_single_gpu.yaml => full/accelerate_config_zero3.yaml} (82%) delete mode 100644 examples/joyai_image/model_training/lora/JoyAI-Image-Edit-full-test.sh delete mode 100644 examples/joyai_image/model_training/lora/JoyAI-Image-Edit-test.sh delete mode 100644 examples/joyai_image/model_training/lora/accelerate_config.yaml diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index de276891f..50cf5a72e 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -267,6 +267,22 @@ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule", }, + "diffsynth.models.joyai_image_dit.Transformer3DModel": { + "diffsynth.models.joyai_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.joyai_image_dit.ModulateWan": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + 
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } def QwenImageTextEncoder_Module_Map_Updater(): diff --git a/diffsynth/models/joyai_image_dit.py b/diffsynth/models/joyai_image_dit.py index 23de2c03a..b025b9b04 100644 --- a/diffsynth/models/joyai_image_dit.py +++ b/diffsynth/models/joyai_image_dit.py @@ -632,7 +632,7 @@ def forward( if num_items > 1: img = torch.cat([img[:, 1:], img[:, :1]], dim=1) - return (img, txt) + return img def unpatchify(self, x, t, h, w): c = self.out_channels diff --git a/diffsynth/pipelines/joyai_image.py b/diffsynth/pipelines/joyai_image.py index d1e45cbf6..47e597eea 100644 --- a/diffsynth/pipelines/joyai_image.py +++ b/diffsynth/pipelines/joyai_image.py @@ -78,7 +78,6 @@ def __call__( cfg_scale: float = 5.0, input_image: Image.Image = None, edit_images: Union[Image.Image, List[Image.Image]] = None, - edit_image_basesize: int = 1024, denoising_strength: float = 1.0, height: int = 1024, width: int = 1024, @@ -95,12 +94,12 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift) # 2. Three dictionaries - inputs_posi = {"prompt": prompt, "positive": True} - inputs_nega = {"negative_prompt": negative_prompt, "positive": True} + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} inputs_shared = { "cfg_scale": cfg_scale, "input_image": input_image, - "edit_images": edit_images, "edit_image_basesize": edit_image_basesize, + "edit_images": edit_images, "denoising_strength": denoising_strength, "height": height, "width": width, @@ -206,39 +205,8 @@ def _encode_with_image(self, pipe, prompt, edit_images, max_sequence_length): return prompt_embeds, prompt_embeds_mask def _encode_text_only(self, pipe, prompt, max_sequence_length): - # TODO: - template = "system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:\n{}assistant\n" - drop_idx = 34 - - txt = template.format(prompt) - txt_tokens = pipe.processor.tokenizer( - [txt], max_length=max_sequence_length + drop_idx, - padding=True, truncation=True, return_tensors="pt" - ).to(pipe.device) - - hidden_states = pipe.text_encoder( - input_ids=txt_tokens.input_ids, - attention_mask=txt_tokens.attention_mask, - output_hidden_states=True, - ).hidden_states[-1] - - bool_mask = txt_tokens.attention_mask.bool() - split_hidden = [] - for i in range(hidden_states.size(0)): - mask = bool_mask[i] - valid_states = hidden_states[i][mask] - valid_states = valid_states[drop_idx:] - split_hidden.append(valid_states) - attn_masks = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden] - - max_seq_len = min(max_sequence_length, max(u.size(0) for u in split_hidden)) - prompt_embeds = torch.stack([ - torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden - ]) - encoder_attention_mask = torch.stack([ - torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_masks - ]) - + # TODO: may support for text-only encoding in the future. 
+ raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_images for now.") return prompt_embeds, encoder_attention_mask @@ -254,15 +222,14 @@ def __init__(self): ) def process(self, pipe: "JoyAIImagePipeline", edit_images, tiled, tile_size, tile_stride, edit_image_basesize, height, width): - pipe.load_models_to_device(self.onload_model_names) if edit_images is None: return {} if isinstance(edit_images, Image.Image): edit_images = [edit_images] + pipe.load_models_to_device(self.onload_model_names) assert len(edit_images) == 1, "Currently only supports single edit image for reference. Multiple edit images will be supported in the future." # Resize edit images to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents edit_images = [img.resize((width, height), Image.LANCZOS) for img in edit_images] - images = [pipe.preprocess_image(img).transpose(0, 1) for img in edit_images] latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(edit_images))).to(device=pipe.device, dtype=pipe.torch_dtype) @@ -287,31 +254,21 @@ def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device): class JoyAIImageUnit_InputImageEmbedder(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_image", "image", "noise", "tiled", "tile_size", "tile_stride"), + input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), output_params=("latents", "input_latents"), onload_model_names=("vae",), ) - def process(self, pipe: JoyAIImagePipeline, input_image, image, noise, tiled, tile_size, tile_stride): - pipe.load_models_to_device(self.onload_model_names) + def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride): if input_image is None: - # Training mode: VAE-encode ground truth image to get input_latents - if image is not None: - if isinstance(image, list): - image = image[0] - # Derive target image size from noise shape to ensure latent compatibility - # noise shape: [b, n, c, f, h, w] - latent_h = noise.shape[-2] - latent_w = noise.shape[-1] - img_h = latent_h * pipe.vae.upsampling_factor - img_w = latent_w * pipe.vae.upsampling_factor - image = image.resize((img_w, img_h), Image.LANCZOS) - img_tensor = [pipe.preprocess_image(image).transpose(0, 1)] - input_latents = pipe.vae.encode(img_tensor, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - input_latents = rearrange(input_latents, "(b n) c f h w -> b n c f h w", n=1) - return {"latents": noise, "input_latents": input_latents} return {"latents": noise} - raise NotImplementedError("Input image to latents is not implemented yet. 
Currently only supports noise initialization when input_image is None.") + pipe.load_models_to_device(self.onload_model_names) + if isinstance(input_image, Image.Image): + input_image = [input_image] + input_image = [pipe.preprocess_image(img).transpose(0, 1) for img in input_image] + latents = pipe.vae.encode(input_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + input_latents = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(input_image))) + return {"latents": noise, "input_latents": input_latents} # ============================================================ # model_fn — DiT forward call @@ -330,7 +287,7 @@ def model_fn_joyai_image( img = torch.cat([ref_latents, latents], dim=1) if ref_latents is not None else latents - img, _ = dit( + img = dit( hidden_states=img, timestep=timestep, encoder_hidden_states=prompt_embeds, diff --git a/examples/joyai_image/model_inference/JoyAI-Image-Edit.py b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py index 6285fd220..6c89bd6d0 100644 --- a/examples/joyai_image/model_inference/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py @@ -1,6 +1,14 @@ from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig import torch from PIL import Image +from modelscope import dataset_snapshot_download + +# Download dataset +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="joyai_image/JoyAI-Image-Edit/*" +) pipe = JoyAIImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, @@ -24,13 +32,15 @@ origin_file_pattern="JoyAI-Image-Und/", ), ) -prompt = "Turn the plate blue" -input_image = Image.open("/mnt/nas1/zhanghong/project26/main_project/opencode/packages/joyai-image/JoyAI-Image/test_images/test_1.jpg").convert("RGB") + +# Use first sample from dataset +dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" +prompt = "将裙子改为粉色" +edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") output = pipe( prompt=prompt, - edit_images=[input_image], - edit_image_basesize=1024, + edit_images=[edit_images], height=1024, width=1024, seed=1, diff --git a/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py b/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py index 98ac5f6a2..e5c9fc451 100644 --- a/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py @@ -1,6 +1,25 @@ from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig import torch from PIL import Image +from modelscope import dataset_snapshot_download + +# Download dataset +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="joyai_image/JoyAI-Image-Edit/*" +) + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} pipe = JoyAIImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, @@ -9,44 +28,39 @@ ModelConfig( model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", + **vram_config, ), ModelConfig( model_id="jd-opensource/JoyAI-Image-Edit", 
origin_file_pattern="JoyAI-Image-Und/model*.safetensors", + **vram_config, ), ModelConfig( model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", + **vram_config, ), ], processor_config=ModelConfig( model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/", ), - vram_limit=0.8, + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) -prompt = "Turn the plate blue" -input_image = None # Image.open("input.jpg").convert("RGB") +# Use first sample from dataset +dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" +prompt = "将裙子改为粉色" +edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") -if input_image is not None: - output = pipe( - prompt=prompt, - input_image=input_image, - denoising_strength=1.0, - seed=42, - num_inference_steps=50, - cfg_scale=5.0, - ) -else: - output = pipe( - prompt=prompt, - seed=42, - num_inference_steps=50, - cfg_scale=5.0, - height=1024, - width=1024, - ) +output = pipe( + prompt=prompt, + edit_images=[edit_images], + height=1024, + width=1024, + seed=0, + num_inference_steps=30, + cfg_scale=5.0, +) output.save("output_joyai_edit_low_vram.png") -print("Saved output_joyai_edit_low_vram.png") diff --git a/examples/joyai_image/model_training/full/JoyAI-Image-Edit-test.sh b/examples/joyai_image/model_training/full/JoyAI-Image-Edit-test.sh deleted file mode 100644 index 5ddbbf315..000000000 --- a/examples/joyai_image/model_training/full/JoyAI-Image-Edit-test.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -# Minimal training smoke test for JoyAI-Image-Edit -# Runs only 0.01 epochs (1-2 steps) to verify training pipeline works - -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -DIFFSYNTH_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" - -cd "$DIFFSYNTH_ROOT" - -export CONDA_DEFAULT_ENV=joyai-image-diffsynth -export PATH="/root/miniconda3/envs/joyai-image-diffsynth/bin:$PATH" -export CUDA_VISIBLE_DEVICES=0 -export PYTHONPATH="$DIFFSYNTH_ROOT:$PYTHONPATH" - -accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml \ - examples/joyai_image/model_training/train.py \ - --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \ - --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \ - --data_file_keys "image,input_image" \ - --extra_inputs "input_image" \ - --max_pixels 1048576 \ - --dataset_repeat 1 \ - --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ - --learning_rate 1e-5 \ - --num_epochs 1 \ - --remove_prefix_in_ckpt "pipe.dit." 
\
-    --output_path "./models/train/JoyAI-Image-Edit_full_test" \
-    --trainable_models "dit" \
-    --use_gradient_checkpointing \
-    --find_unused_parameters
diff --git a/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh b/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh
index 111649f43..59730749a 100644
--- a/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh
+++ b/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh
@@ -1,17 +1,43 @@
-modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset
+# Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/
+# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset
 
-accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml examples/joyai_image/model_training/train.py \
-    --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \
-    --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \
-    --data_file_keys "image,input_image" \
-    --extra_inputs "input_image" \
-    --max_pixels 1048576 \
-    --dataset_repeat 50 \
-    --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \
-    --learning_rate 1e-5 \
-    --num_epochs 2 \
-    --remove_prefix_in_ckpt "pipe.dit." \
-    --output_path "./models/train/JoyAI-Image-Edit_full" \
-    --trainable_models "dit" \
-    --use_gradient_checkpointing \
-    --find_unused_parameters
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+DIFFSYNTH_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)"
+cd "$DIFFSYNTH_ROOT"
+
+export CONDA_DEFAULT_ENV=joyai-image-diffsynth
+export PATH="/root/miniconda3/envs/joyai-image-diffsynth/bin:$PATH"
+export CUDA_VISIBLE_DEVICES=0
+export PYTHONPATH="$DIFFSYNTH_ROOT:$PYTHONPATH"
+
+# ==================== Stage 1: preprocessing ====================
+# Load the VAE + text encoder and cache the encoded results to disk
+accelerate launch examples/joyai_image/model_training/train.py \
+    --dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \
+    --dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \
+    --max_pixels 1048576 --dataset_repeat 1 \
+    --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \
+    --learning_rate 1e-5 --num_epochs 2 \
+    --remove_prefix_in_ckpt "pipe.dit." --output_path "./models/train/JoyAI-Image-Edit-full-cache" \
+    --use_gradient_checkpointing --find_unused_parameters \
+    --data_file_keys "image,edit_images" \
+    --extra_inputs "edit_images" \
+    --task "sft:data_process"
+
+# ==================== Stage 2: DiT training ====================
+# Read from the cache and train only the DiT
+# NOTE: Full training of the 16B DiT will likely OOM on a single GPU. If it does,
+# install DeepSpeed, lift the CUDA_VISIBLE_DEVICES restriction above, and pass
+# --config_file examples/joyai_image/model_training/full/accelerate_config_zero3.yaml below.
+accelerate launch examples/joyai_image/model_training/train.py \
+    --dataset_base_path "./models/train/JoyAI-Image-Edit-full-cache" \
+    --max_pixels 1048576 --dataset_repeat 50 \
+    --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
+    --learning_rate 1e-5 --num_epochs 2 \
+    --remove_prefix_in_ckpt "pipe.dit." --output_path "./models/train/JoyAI-Image-Edit-full" \
+    --trainable_models "dit" \
+    --use_gradient_checkpointing --find_unused_parameters \
+    --data_file_keys "image,edit_images" \
+    --extra_inputs "edit_images" \
+    --task "sft:train"
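Before committing to a long run, the same two-stage flow can be smoke-tested with the reduced settings that the removed *-test.sh scripts used. The sketch below is assembled only from flags that appear elsewhere in this PR; the cache and output paths are illustrative, and full DiT training may still exceed a single GPU's memory even at this resolution:

# Stage 1: cache a single pass over the dataset at 256x256
accelerate launch examples/joyai_image/model_training/train.py \
    --dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \
    --dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \
    --max_pixels 262144 --height 256 --width 256 --dataset_repeat 1 \
    --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \
    --data_file_keys "image,edit_images" --extra_inputs "edit_images" \
    --output_path "./models/train/JoyAI-Image-Edit-smoke-cache" \
    --task "sft:data_process"

# Stage 2: one epoch over the cached samples
accelerate launch examples/joyai_image/model_training/train.py \
    --dataset_base_path "./models/train/JoyAI-Image-Edit-smoke-cache" \
    --max_pixels 262144 --dataset_repeat 1 \
    --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
    --learning_rate 1e-5 --num_epochs 1 \
    --remove_prefix_in_ckpt "pipe.dit." --output_path "./models/train/JoyAI-Image-Edit-smoke" \
    --trainable_models "dit" --use_gradient_checkpointing --find_unused_parameters \
    --data_file_keys "image,edit_images" --extra_inputs "edit_images" \
    --task "sft:train"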
diff --git a/examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml b/examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml
deleted file mode 100644
index 3947d488c..000000000
--- a/examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-compute_environment: LOCAL_MACHINE
-debug: false
-distributed_type: 'NO'
-downcast_bf16: 'no'
-machine_rank: 0
-main_training_function: main
-mixed_precision: bf16
-num_machines: 1
-num_processes: 1
-use_cpu: false
diff --git a/examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml b/examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml
deleted file mode 100644
index 8a75f3d91..000000000
--- a/examples/joyai_image/model_training/full/accelerate_config_zero2offload.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-compute_environment: LOCAL_MACHINE
-debug: false
-deepspeed_config:
-  gradient_accumulation_steps: 1
-  offload_optimizer_device: 'cpu'
-  offload_param_device: 'cpu'
-  zero3_init_flag: false
-  zero_stage: 2
-distributed_type: DEEPSPEED
-downcast_bf16: 'no'
-enable_cpu_affinity: false
-machine_rank: 0
-main_training_function: main
-mixed_precision: bf16
-num_machines: 1
-num_processes: 8
-rdzv_backend: static
-same_network: true
-tpu_env: []
-tpu_use_cluster: false
-tpu_use_sudo: false
-use_cpu: false
diff --git a/examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml b/examples/joyai_image/model_training/full/accelerate_config_zero3.yaml
similarity index 82%
rename from examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml
rename to examples/joyai_image/model_training/full/accelerate_config_zero3.yaml
index a5bae8c59..e7ae29c07 100644
--- a/examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml
+++ b/examples/joyai_image/model_training/full/accelerate_config_zero3.yaml
@@ -4,8 +4,9 @@ deepspeed_config:
   gradient_accumulation_steps: 1
   offload_optimizer_device: cpu
   offload_param_device: none
-  zero3_init_flag: false
-  zero_stage: 2
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
 distributed_type: DEEPSPEED
 downcast_bf16: 'no'
 enable_cpu_affinity: false
@@ -13,7 +14,7 @@ machine_rank: 0
 main_training_function: main
 mixed_precision: bf16
 num_machines: 1
-num_processes: 1
+num_processes: 8
 rdzv_backend: static
 same_network: true
 tpu_env: []
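The renamed config above is what enables sharded training of the 16B DiT: ZeRO stage 3 partitions parameters across the 8 processes, zero3_init_flag instantiates the model directly in sharded form, and zero3_save_16bit_model writes a consolidated 16-bit checkpoint on save. A rough command-line equivalent using accelerate's standard DeepSpeed launcher flags (a sketch; flag availability depends on your installed accelerate version):

accelerate launch \
    --use_deepspeed \
    --zero_stage 3 \
    --offload_optimizer_device cpu \
    --num_processes 8 \
    --mixed_precision bf16 \
    examples/joyai_image/model_training/train.py \
    --dataset_base_path "./models/train/JoyAI-Image-Edit-full-cache" \
    --max_pixels 1048576 --dataset_repeat 50 \
    --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
    --learning_rate 1e-5 --num_epochs 2 \
    --remove_prefix_in_ckpt "pipe.dit." --output_path "./models/train/JoyAI-Image-Edit-full" \
    --trainable_models "dit" \
    --use_gradient_checkpointing --find_unused_parameters \
    --data_file_keys "image,edit_images" --extra_inputs "edit_images" \
    --task "sft:train"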
&& pwd)" -DIFFSYNTH_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" - -cd "$DIFFSYNTH_ROOT" - -export CONDA_DEFAULT_ENV=joyai-image-diffsynth -export PATH="/root/miniconda3/envs/joyai-image-diffsynth/bin:$PATH" -export CUDA_VISIBLE_DEVICES=0 -export PYTHONPATH="$DIFFSYNTH_ROOT:$PYTHONPATH" - -accelerate launch --config_file examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml \ - examples/joyai_image/model_training/train.py \ - --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \ - --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \ - --data_file_keys "image,input_image" \ - --extra_inputs "input_image" \ - --max_pixels 1048576 \ - --height 1024 \ - --width 1024 \ - --dataset_repeat 10 \ - --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ - --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/JoyAI-Image-Edit_lora" \ - --trainable_models "dit" \ - --lora_target_modules "to_q,to_k,to_v,to_out" \ - --lora_rank 32 \ - --use_gradient_checkpointing \ - --initialize_model_on_cpu \ - --find_unused_parameters diff --git a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit-test.sh b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit-test.sh deleted file mode 100644 index eef77d345..000000000 --- a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit-test.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# Minimal LoRA training smoke test for JoyAI-Image-Edit -# Runs only 1 epoch with minimal data to verify training pipeline works - -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -DIFFSYNTH_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" - -cd "$DIFFSYNTH_ROOT" - -export CONDA_DEFAULT_ENV=joyai-image-diffsynth -export PATH="/root/miniconda3/envs/joyai-image-diffsynth/bin:$PATH" -export CUDA_VISIBLE_DEVICES=0 -export PYTHONPATH="$DIFFSYNTH_ROOT:$PYTHONPATH" - -accelerate launch --config_file examples/joyai_image/model_training/lora/accelerate_config_single_gpu.yaml \ - examples/joyai_image/model_training/train.py \ - --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \ - --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \ - --data_file_keys "image,input_image" \ - --extra_inputs "input_image" \ - --max_pixels 262144 \ - --height 256 \ - --width 256 \ - --dataset_repeat 1 \ - --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ - --num_epochs 1 \ - --remove_prefix_in_ckpt "pipe.dit." 
\ - --output_path "./models/train/JoyAI-Image-Edit_lora_test" \ - --trainable_models "dit" \ - --lora_target_modules "to_q,to_k,to_v,to_out" \ - --lora_rank 4 \ - --use_gradient_checkpointing \ - --initialize_model_on_cpu \ - --find_unused_parameters diff --git a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh index fdb164344..a461a2fd6 100644 --- a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh +++ b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh @@ -1,19 +1,31 @@ -modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset +# Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/ +# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset -accelerate launch --config_file examples/joyai_image/model_training/lora/accelerate_config.yaml examples/joyai_image/model_training/train.py \ - --dataset_base_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit \ - --dataset_metadata_path data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv \ - --data_file_keys "image,input_image" \ - --extra_inputs "input_image" \ - --max_pixels 1048576 \ - --dataset_repeat 50 \ - --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth,jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ - --learning_rate 1e-4 \ - --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." \ - --output_path "./models/train/JoyAI-Image-Edit_lora" \ - --trainable_models "dit" \ - --lora_target_modules "to_q,to_k,to_v,to_out" \ - --lora_rank 32 \ - --use_gradient_checkpointing \ - --find_unused_parameters +# ==================== 第一阶段:前处理计算 ==================== +# 加载 VAE + TextEncoder,缓存编码结果到硬盘 +accelerate launch examples/joyai_image/model_training/train.py \ + --dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \ + --dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \ + --max_pixels 1048576 --dataset_repeat 1 \ + --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ + --learning_rate 1e-4 --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." --output_path "./models/train/JoyAI-Image-Edit-split-cache" \ + --lora_base_model "dit" --lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" --lora_rank 32 \ + --use_gradient_checkpointing --find_unused_parameters \ + --data_file_keys "image,edit_images" \ + --extra_inputs "edit_images" \ + --task "sft:data_process" + +# ==================== 第二阶段:DiT 训练 ==================== +# 从缓存读取,仅训练 DiT +accelerate launch examples/joyai_image/model_training/train.py \ + --dataset_base_path "./models/train/JoyAI-Image-Edit-split-cache" \ + --max_pixels 1048576 --dataset_repeat 50 \ + --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \ + --learning_rate 1e-4 --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
--output_path "./models/train/JoyAI-Image-Edit-lora" \ + --lora_base_model "dit" --lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" --lora_rank 32 \ + --use_gradient_checkpointing --find_unused_parameters \ + --data_file_keys "image,edit_images" \ + --extra_inputs "edit_images" \ + --task "sft:train" diff --git a/examples/joyai_image/model_training/lora/accelerate_config.yaml b/examples/joyai_image/model_training/lora/accelerate_config.yaml deleted file mode 100644 index 83280f73f..000000000 --- a/examples/joyai_image/model_training/lora/accelerate_config.yaml +++ /dev/null @@ -1,22 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -deepspeed_config: - gradient_accumulation_steps: 1 - offload_optimizer_device: none - offload_param_device: none - zero3_init_flag: false - zero_stage: 2 -distributed_type: DEEPSPEED -downcast_bf16: 'no' -enable_cpu_affinity: false -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: 8 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/examples/joyai_image/model_training/train.py b/examples/joyai_image/model_training/train.py index 5b4f88a7d..6c4fd8f30 100644 --- a/examples/joyai_image/model_training/train.py +++ b/examples/joyai_image/model_training/train.py @@ -45,52 +45,27 @@ def __init__( self.task = task self.task_to_loss = { "sft:data_process": lambda pipe, *args: args, - "direct_distill:data_process": lambda pipe, *args: args, "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), - "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), - "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), } def get_pipeline_inputs(self, data): inputs_posi = {"prompt": data["prompt"]} inputs_nega = {"negative_prompt": ""} inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], # Please do not modify the following parameters # unless you clearly know what this will cause. 
"cfg_scale": 1, "rand_device": self.pipe.device, "use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, - "edit_image_auto_resize": True, - "edit_image_basesize": 1024, } - # Handle input image for image editing — maps to edit_images in the pipeline - if "input_image" in data and data["input_image"] is not None: - if isinstance(data["input_image"], list): - inputs_shared.update({ - "edit_images": data["input_image"], - "height": data["input_image"][0].size[1], - "width": data["input_image"][0].size[0], - }) - else: - inputs_shared.update({ - "edit_images": data["input_image"], - "height": data["input_image"].size[1], - "width": data["input_image"].size[0], - }) - else: - inputs_shared.update({ - "edit_images": None, - "height": 1024, - "width": 1024, - }) inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) - # input_image is None for training (latents initialized from noise) - inputs_shared["input_image"] = None - # Ground truth image for computing input_latents in loss computation - if "image" in data and data["image"] is not None: - inputs_shared["image"] = data["image"] return inputs_shared, inputs_posi, inputs_nega def forward(self, data, inputs=None): @@ -103,7 +78,7 @@ def forward(self, data, inputs=None): def joyai_image_parser(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = argparse.ArgumentParser(description="JoyAI-Image training.") parser = add_general_config(parser) parser = add_image_size_config(parser) parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor.") @@ -131,16 +106,6 @@ def joyai_image_parser(): height_division_factor=16, width_division_factor=16, ), - special_operator_map={ - "input_image": RouteByType(operator_map=[ - (str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16)), - (list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16))), - ]), - "image": RouteByType(operator_map=[ - (str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16)), - (list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16))), - ]), - } ) model = JoyAIImageTrainingModule( model_paths=args.model_paths, @@ -167,10 +132,7 @@ def joyai_image_parser(): ) launcher_map = { "sft:data_process": launch_data_process_task, - "direct_distill:data_process": launch_data_process_task, "sft": launch_training_task, "sft:train": launch_training_task, - "direct_distill": launch_training_task, - "direct_distill:train": launch_training_task, } launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) diff --git a/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py b/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py index d9f88f7fa..92deb79b2 100644 --- a/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py @@ -16,10 +16,22 @@ origin_file_pattern="JoyAI-Image-Und/", ), ) + +# Load full training weights state_dict = 
load_state_dict("models/train/JoyAI-Image-Edit_full/epoch-1.safetensors") pipe.dit.load_state_dict(state_dict) -prompt = "Turn the plate blue" -image = Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)) -image = pipe(prompt, input_image=image, seed=0, num_inference_steps=50, cfg_scale=5.0) -image.save(f"image.jpg") +# Use training dataset prompt and edit_images +prompt = "将裙子改为粉色" +edit_images = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB") + +image = pipe( + prompt=prompt, + edit_images=[edit_images], + height=1024, + width=1024, + seed=0, + num_inference_steps=50, + cfg_scale=5.0, +) +image.save("image_full.jpg") diff --git a/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py index e1461770b..c215e4fc5 100644 --- a/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py @@ -1,7 +1,6 @@ import torch from PIL import Image from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig -from diffsynth import load_state_dict pipe = JoyAIImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, @@ -16,10 +15,21 @@ origin_file_pattern="JoyAI-Image-Und/", ), ) -state_dict = load_state_dict("models/train/JoyAI-Image-Edit_lora/epoch-4.safetensors") -pipe.dit.load_state_dict(state_dict, strict=False) -prompt = "Turn the plate blue" -image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").resize((1024, 1024)) -image = pipe(prompt, edit_images=image, seed=0, num_inference_steps=50, cfg_scale=5.0) -image.save(f"models/train/JoyAI-Image-Edit_lora/val_epoch-4.jpg") +# Load LoRA weights from dual-stage training output +pipe.load_lora(pipe.dit, "models/train/JoyAI-Image-Edit-lora/epoch-4.safetensors") + +# Use training dataset prompt and edit_images +prompt = "将裙子改为粉色" +edit_images = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB") + +image = pipe( + prompt=prompt, + edit_images=[edit_images], + height=1024, + width=1024, + seed=0, + num_inference_steps=30, + cfg_scale=5.0, +) +image.save("image_lora.jpg") From 016fbf620664703e5b49d6eda33a701078c1026f Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 14 Apr 2026 11:42:12 +0800 Subject: [PATCH 5/8] styling --- diffsynth/configs/model_configs.py | 15 +- diffsynth/models/joyai_image_common.py | 135 ------------------ diffsynth/pipelines/joyai_image.py | 47 +++--- .../model_inference/JoyAI-Image-Edit.py | 20 +-- .../JoyAI-Image-Edit.py | 23 +-- .../model_training/full/JoyAI-Image-Edit.sh | 42 +++--- .../full/accelerate_config_zero3.yaml | 2 +- .../model_training/lora/JoyAI-Image-Edit.sh | 36 +++-- .../validate_full/JoyAI-Image-Edit.py | 5 +- .../validate_lora/JoyAI-Image-Edit.py | 5 +- 10 files changed, 71 insertions(+), 259 deletions(-) delete mode 100644 diffsynth/models/joyai_image_common.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index e740849c5..4ca67ac7b 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -890,20 +890,7 @@ "model_hash": "56592ddfd7d0249d3aa527d24161a863", "model_name": "joyai_image_dit", "model_class": "diffsynth.models.joyai_image_dit.Transformer3DModel", - "extra_kwargs": { - "patch_size": [1, 2, 2], - "in_channels": 16, - "out_channels": 16, - "hidden_size": 
4096, - "heads_num": 32, - "text_states_dim": 4096, - "mlp_width_ratio": 4.0, - "mm_double_blocks_depth": 40, - "rope_dim_list": [16, 56, 56], - "rope_type": "rope", - "dit_modulation_type": "wanx", - "theta": 10000, - }, + "extra_kwargs": {"patch_size": [1, 2, 2], "in_channels": 16, "out_channels": 16, "hidden_size": 4096, "heads_num": 32, "text_states_dim": 4096, "mlp_width_ratio": 4.0, "mm_double_blocks_depth": 40, "rope_dim_list": [16, 56, 56], "rope_type": "rope", "dit_modulation_type": "wanx", "theta": 10000}, "state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_dit.JoyAIImageDiTStateDictConverter", }, { diff --git a/diffsynth/models/joyai_image_common.py b/diffsynth/models/joyai_image_common.py deleted file mode 100644 index 501a23360..000000000 --- a/diffsynth/models/joyai_image_common.py +++ /dev/null @@ -1,135 +0,0 @@ -from PIL import Image -from typing import Tuple -import math -import torchvision.transforms.functional as TF - -class BucketGroup: - """Manages dynamic batch grouping buckets for image inference.""" - - def __init__( - self, - bucket_configs: list[tuple[int, int, int, int, int]], - prioritize_frame_matching: bool = True, - ): - """ - Initialize bucket group with predefined configurations. - - Args: - bucket_configs: List of (batch_size, num_items, num_frames, height, width) tuples - prioritize_frame_matching: Unused, kept for API compatibility. - """ - self.bucket_configs = [tuple(b) for b in bucket_configs] - - def find_best_bucket(self, media_shape: tuple[int, int, int, int]) -> tuple[int, int, int, int, int]: - """ - Find the best matching bucket for given media dimensions. - - Args: - media_shape: (num_items, num_frames, height, width) of input media - - Returns: - Best matching bucket as (batch_size, num_items, num_frames, height, width) - """ - num_items, num_frames, height, width = media_shape - target_aspect_ratio = height / width - - if num_frames != 1: - raise ValueError( - f"Only image inference (num_frames=1) is supported, got num_frames={num_frames}") - - valid_buckets = [ - b for b in self.bucket_configs - if b[1] == num_items and b[2] == 1 - ] - if not valid_buckets: - raise ValueError( - f"No image buckets found for shape {media_shape}") - - return min( - valid_buckets, - key=lambda bucket: abs( - (bucket[3] / bucket[4]) - target_aspect_ratio) - ) - - def __repr__(self) -> str: - return ( - f"BucketGroup(" - f"total_buckets={len(self.bucket_configs)}, " - f"configs={self.bucket_configs})" - ) - - -def _generate_hw_buckets(base_height=256, base_width=256, step_width=16, step_height=16, max_ratio=4.0) -> list[tuple[int, int, int, int, int]]: - """Generate dimension buckets based on aspect ratios.""" - buckets = [] - target_pixels = base_height * base_width - - height = target_pixels // step_width - width = step_width - - while height >= step_height: - if max(height, width) / min(height, width) <= max_ratio: - buckets.append((1, 1, 1, height, width)) - if height * (width + step_width) <= target_pixels: - width += step_width - else: - height -= step_height - - return buckets - - -def generate_video_image_bucket(basesize=256, min_temporal=65, max_temporal=129, bs_img=8, bs_vid=1, bs_mimg=4, min_items=1, max_items=1): - """Generate bucket configs for image inference. - - Returns: - List of (batch_size, num_items, num_frames, height, width) tuples. 
- """ - assert basesize in [ - 256, 512, 768, 1024], f"[generate_video_image_bucket] wrong basesize {basesize}" - bucket_list = [] - - base_bucket_list = _generate_hw_buckets() - # image - for _bucket in base_bucket_list: - bucket = list(_bucket) - bucket[0] = bs_img - bucket_list.append(bucket) - # multiple images - for num_items in range(min_items, max_items + 1): - for _bucket in base_bucket_list: - bucket = list(_bucket) - bucket[0] = bs_mimg - bucket[1] = num_items - bucket_list.append(bucket) - # spatial resize - if basesize > 256: - ratio = basesize // 256 - - def resize(bucket, r): - bucket[-2] *= r - bucket[-1] *= r - return bucket - bucket_list = [resize(bucket, ratio) for bucket in bucket_list] - return bucket_list - - -def _dynamic_resize_from_bucket(image: Image, basesize: int = 512): - def resize_center_crop(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image: - w, h = img.size # PIL: (width, height) - bh, bw = target_size - scale = max(bh / h, bw / w) - resize_h, resize_w = math.ceil(h * scale), math.ceil(w * scale) - img = TF.resize(img, (resize_h, resize_w), - interpolation=TF.InterpolationMode.BILINEAR, antialias=True) - img = TF.center_crop(img, target_size) - return img - - bucket_config = generate_video_image_bucket( - basesize=basesize, min_temporal=56, max_temporal=56, bs_img=4, bs_vid=4, bs_mimg=8, min_items=2, max_items=2 - ) - bucket_group = BucketGroup(bucket_config) - img_w, img_h = image.size - bucket = bucket_group.find_best_bucket((1, 1, img_h, img_w)) - target_height, target_width = bucket[-2], bucket[-1] # (height, width) - img_proc = resize_center_crop(image, (target_height, target_width)) - return img_proc diff --git a/diffsynth/pipelines/joyai_image.py b/diffsynth/pipelines/joyai_image.py index 47e597eea..4d498ea4f 100644 --- a/diffsynth/pipelines/joyai_image.py +++ b/diffsynth/pipelines/joyai_image.py @@ -1,9 +1,8 @@ import torch from PIL import Image -from typing import Union, List +from typing import Union, List, Optional from tqdm import tqdm from einops import rearrange -from typing import Optional, Union from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler @@ -11,12 +10,8 @@ from ..diffusion.base_pipeline import BasePipeline, PipelineUnit from ..models.joyai_image_dit import Transformer3DModel from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder -from ..models.joyai_image_common import _dynamic_resize_from_bucket from ..models.wan_video_vae import WanVideoVAE -# ============================================================ -# JoyAIImagePipeline -# ============================================================ class JoyAIImagePipeline(BasePipeline): """ Pipeline for JoyAI-Image model. 
@@ -49,7 +44,9 @@ def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], + # Processor processor_config: ModelConfig = None, + # Optional vram_limit: float = None, ): pipe = JoyAIImagePipeline(device=device, torch_dtype=torch_dtype) @@ -67,54 +64,56 @@ def from_pretrained( pipe.vram_management_enabled = pipe.check_vram_management_state() return pipe - # ============================================================ - # __call__ — Orchestration only - # ============================================================ @torch.no_grad() def __call__( self, + # Prompt prompt: str, negative_prompt: str = "", cfg_scale: float = 5.0, + # Image input_image: Image.Image = None, edit_images: Union[Image.Image, List[Image.Image]] = None, denoising_strength: float = 1.0, + # Shape height: int = 1024, width: int = 1024, + # Randomness seed: int = None, + # Steps max_sequence_length: int = 4096, num_inference_steps: int = 30, + # Tiling tiled: Optional[bool] = False, tile_size: Optional[tuple[int, int]] = (30, 52), tile_stride: Optional[tuple[int, int]] = (15, 26), + # Scheduler shift: Optional[float] = 4.0, + # Progress bar progress_bar_cmd=tqdm, ): - # 1. Scheduler + # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift) - # 2. Three dictionaries + # Parameters inputs_posi = {"prompt": prompt} inputs_nega = {"negative_prompt": negative_prompt} inputs_shared = { "cfg_scale": cfg_scale, - "input_image": input_image, - "edit_images": edit_images, + "input_image": input_image, "edit_images": edit_images, "denoising_strength": denoising_strength, - "height": height, - "width": width, - "seed": seed, - "max_sequence_length": max_sequence_length, + "height": height, "width": width, + "seed": seed, "max_sequence_length": max_sequence_length, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, } - # 3. Unit chain + # Unit chain for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner( unit, self, inputs_shared, inputs_posi, inputs_nega ) - # 4. Denoise loop + # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): @@ -125,8 +124,8 @@ def __call__( **models, timestep=timestep, progress_id=progress_id ) inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) - - # 5. 
VAE decode + + # Decode self.load_models_to_device(['vae']) latents = rearrange(inputs_shared["latents"], "b n c f h w -> (b n) c f h w") image = self.vae.decode(latents, device=self.device)[0] @@ -135,9 +134,6 @@ def __call__( return image -# ============================================================ -# PipelineUnits -# ============================================================ class JoyAIImageUnit_ShapeChecker(PipelineUnit): """Validates height/width divisible by 16.""" def __init__(self): @@ -270,9 +266,6 @@ def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size input_latents = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(input_image))) return {"latents": noise, "input_latents": input_latents} -# ============================================================ -# model_fn — DiT forward call -# ============================================================ def model_fn_joyai_image( dit, latents, diff --git a/examples/joyai_image/model_inference/JoyAI-Image-Edit.py b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py index 6c89bd6d0..152068f4c 100644 --- a/examples/joyai_image/model_inference/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py @@ -14,23 +14,11 @@ torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - origin_file_pattern="transformer/transformer.pth", - ), - ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - origin_file_pattern="JoyAI-Image-Und/model*.safetensors", - ), - ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - origin_file_pattern="vae/Wan2.1_VAE.pth", - ), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"), ], - processor_config=ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - origin_file_pattern="JoyAI-Image-Und/", - ), + processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), ) # Use first sample from dataset diff --git a/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py b/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py index e5c9fc451..a1f74f2d1 100644 --- a/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py @@ -25,26 +25,11 @@ torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - origin_file_pattern="transformer/transformer.pth", - **vram_config, - ), - ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - origin_file_pattern="JoyAI-Image-Und/model*.safetensors", - **vram_config, - ), - ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - origin_file_pattern="vae/Wan2.1_VAE.pth", - **vram_config, - ), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config), ], - processor_config=ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - 
origin_file_pattern="JoyAI-Image-Und/", - ), + processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, ) diff --git a/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh b/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh index 59730749a..385bb8e23 100644 --- a/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh +++ b/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh @@ -1,43 +1,35 @@ # Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/ # Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -DIFFSYNTH_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" -cd "$DIFFSYNTH_ROOT" - -export CONDA_DEFAULT_ENV=joyai-image-diffsynth -export PATH="/root/miniconda3/envs/joyai-image-diffsynth/bin:$PATH" -export CUDA_VISIBLE_DEVICES=0 -export PYTHONPATH="$DIFFSYNTH_ROOT:$PYTHONPATH" - -# ==================== 第一阶段:前处理计算 ==================== -# 加载 VAE + TextEncoder,缓存编码结果到硬盘 accelerate launch examples/joyai_image/model_training/train.py \ --dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \ --dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \ - --max_pixels 1048576 --dataset_repeat 1 \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ - --learning_rate 1e-5 --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." --output_path "./models/train/JoyAI-Image-Edit-full-cache" \ - --use_gradient_checkpointing --find_unused_parameters \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/JoyAI-Image-Edit-full-cache" \ + --use_gradient_checkpointing \ + --find_unused_parameters \ --data_file_keys "image,edit_images" \ --extra_inputs "edit_images" \ --task "sft:data_process" -# ==================== 第二阶段:DiT 训练 ==================== -# 从缓存读取,仅训练 DiT -# NOTE: Full training of 16B DiT requires DeepSpeed ZeRO-3 with multiple GPUs. -# This script uses single GPU config. If OOM occurs, install DeepSpeed and use -# accelerate_config_zero3.yaml instead. -accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_single_gpu.yaml \ +accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_zero3.yaml \ examples/joyai_image/model_training/train.py \ --dataset_base_path "./models/train/JoyAI-Image-Edit-full-cache" \ - --max_pixels 1048576 --dataset_repeat 50 \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \ - --learning_rate 1e-5 --num_epochs 2 \ - --remove_prefix_in_ckpt "pipe.dit." --output_path "./models/train/JoyAI-Image-Edit-full" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/JoyAI-Image-Edit-full" \ --trainable_models "dit" \ - --use_gradient_checkpointing --find_unused_parameters \ + --use_gradient_checkpointing \ + --find_unused_parameters \ --data_file_keys "image,edit_images" \ --extra_inputs "edit_images" \ --task "sft:train" diff --git a/examples/joyai_image/model_training/full/accelerate_config_zero3.yaml b/examples/joyai_image/model_training/full/accelerate_config_zero3.yaml index e7ae29c07..e6a8d2733 100644 --- a/examples/joyai_image/model_training/full/accelerate_config_zero3.yaml +++ b/examples/joyai_image/model_training/full/accelerate_config_zero3.yaml @@ -2,7 +2,7 @@ compute_environment: LOCAL_MACHINE debug: false deepspeed_config: gradient_accumulation_steps: 1 - offload_optimizer_device: cpu + offload_optimizer_device: none offload_param_device: none zero3_init_flag: true zero3_save_16bit_model: true diff --git a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh index a461a2fd6..a3c74d652 100644 --- a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh +++ b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh @@ -1,31 +1,39 @@ # Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/ # Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset -# ==================== 第一阶段:前处理计算 ==================== -# 加载 VAE + TextEncoder,缓存编码结果到硬盘 accelerate launch examples/joyai_image/model_training/train.py \ --dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \ --dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \ - --max_pixels 1048576 --dataset_repeat 1 \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \ - --learning_rate 1e-4 --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." --output_path "./models/train/JoyAI-Image-Edit-split-cache" \ - --lora_base_model "dit" --lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" --lora_rank 32 \ - --use_gradient_checkpointing --find_unused_parameters \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/JoyAI-Image-Edit-split-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --find_unused_parameters \ --data_file_keys "image,edit_images" \ --extra_inputs "edit_images" \ --task "sft:data_process" -# ==================== 第二阶段:DiT 训练 ==================== -# 从缓存读取,仅训练 DiT accelerate launch examples/joyai_image/model_training/train.py \ --dataset_base_path "./models/train/JoyAI-Image-Edit-split-cache" \ - --max_pixels 1048576 --dataset_repeat 50 \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ --model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \ - --learning_rate 1e-4 --num_epochs 5 \ - --remove_prefix_in_ckpt "pipe.dit." 
--output_path "./models/train/JoyAI-Image-Edit-lora" \ - --lora_base_model "dit" --lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" --lora_rank 32 \ - --use_gradient_checkpointing --find_unused_parameters \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/JoyAI-Image-Edit-lora" \ + --lora_base_model "dit" \ + --lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --find_unused_parameters \ --data_file_keys "image,edit_images" \ --extra_inputs "edit_images" \ --task "sft:train" diff --git a/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py b/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py index 92deb79b2..8df50096b 100644 --- a/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py @@ -11,10 +11,7 @@ ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"), ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"), ], - processor_config=ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - origin_file_pattern="JoyAI-Image-Und/", - ), + processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), ) # Load full training weights diff --git a/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py index c215e4fc5..4369b5627 100644 --- a/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py @@ -10,10 +10,7 @@ ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"), ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"), ], - processor_config=ModelConfig( - model_id="jd-opensource/JoyAI-Image-Edit", - origin_file_pattern="JoyAI-Image-Und/", - ), + processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), ) # Load LoRA weights from dual-stage training output From 8a411831e9031a104a38a05738a76e169a6ef605 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 14 Apr 2026 12:46:09 +0800 Subject: [PATCH 6/8] joyai-image docs --- README.md | 78 +++++++++++++ README_zh.md | 78 +++++++++++++ docs/en/Model_Details/JoyAI-Image.md | 157 +++++++++++++++++++++++++++ docs/en/index.rst | 1 + docs/zh/Model_Details/JoyAI-Image.md | 157 +++++++++++++++++++++++++++ docs/zh/index.rst | 1 + 6 files changed, 472 insertions(+) create mode 100644 docs/en/Model_Details/JoyAI-Image.md create mode 100644 docs/zh/Model_Details/JoyAI-Image.md diff --git a/README.md b/README.md index 39af15add..8ad67e216 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,8 @@ We believe that a well-developed open-source code framework can lower the thresh > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update. 
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. +- **April 14, 2026**: JoyAI-Image is now open-source; welcome a new member to the image editing model family! Supported features include instruction-guided image editing, low-VRAM inference, and training. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/). + - **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available. - **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/). @@ -875,6 +877,82 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/) +#### JoyAI-Image: [/docs/en/Model_Details/JoyAI-Image.md](/docs/en/Model_Details/JoyAI-Image.md) + +
+ +Quick Start + +Running the following code will quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 4GB VRAM. + +```python +from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig +import torch +from PIL import Image +from modelscope import dataset_snapshot_download + +# Download dataset +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="joyai_image/JoyAI-Image-Edit/*" +) + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +pipe = JoyAIImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config), + ], + processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +# Use first sample from dataset +dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" +prompt = "将裙子改为粉色" +edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") + +output = pipe( + prompt=prompt, + edit_images=[edit_images], + height=1024, + width=1024, + seed=0, + num_inference_steps=30, + cfg_scale=5.0, +) + +output.save("output_joyai_edit_low_vram.png") +``` + +
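Once loaded, the pipeline can be reused for any number of edit instructions without reloading the weights. The sketch below illustrates that pattern; it assumes the `pipe` and `edit_images` objects from the quick start above, and the prompts and output file names are illustrative rather than taken from the repository.

```python
# Minimal sketch: reuse the already-loaded pipeline for several edits.
# Assumes `pipe` and `edit_images` from the quick start above; the
# prompts and output names below are illustrative.
for i, prompt in enumerate(["将裙子改为蓝色", "将背景改为花园"]):
    output = pipe(
        prompt=prompt,
        edit_images=[edit_images],  # same source image for every edit
        height=1024,
        width=1024,
        seed=0,  # a fixed seed keeps results comparable across prompts
        num_inference_steps=30,
        cfg_scale=5.0,
    )
    output.save(f"output_joyai_edit_{i}.png")
```

Because the seed is held fixed while only the prompt changes, differences between the outputs can be attributed to the edit instruction itself.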
+ +
+ +Examples + +Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples/joyai_image/) + +| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | +|-|-|-|-|-|-|-| +|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)| + +
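Once a LoRA has been trained, validation reduces to a few lines on top of the same pipeline. The condensed sketch below is distilled from the LoRA validation example linked in the table; it assumes the `pipe` from the quick start above, and the checkpoint path and epoch index depend on your own training run.

```python
from PIL import Image

# Attach trained LoRA weights to the DiT of an already-constructed
# pipeline, then run an edit. The checkpoint path matches the training
# script's --output_path; the epoch index depends on your run.
pipe.load_lora(pipe.dit, "models/train/JoyAI-Image-Edit-lora/epoch-4.safetensors")

edit_images = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB")
image = pipe(
    prompt="将裙子改为粉色",
    edit_images=[edit_images],
    height=1024,
    width=1024,
    seed=0,
    num_inference_steps=30,
    cfg_scale=5.0,
)
image.save("image_lora.jpg")
```

Full-finetune checkpoints are loaded differently: the full-training validation example reads the weights with `load_state_dict` and applies them via `pipe.dit.load_state_dict`, instead of `pipe.load_lora`.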
+ ## Innovative Achievements DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements. diff --git a/README_zh.md b/README_zh.md index 98a61360f..91224e88a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -33,6 +33,8 @@ DiffSynth 目前包括两个开源项目: > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 +- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。 + - **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。 - **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持,模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting,框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。 @@ -876,6 +878,82 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/) +#### JoyAI-Image: [/docs/zh/Model_Details/JoyAI-Image.md](/docs/zh/Model_Details/JoyAI-Image.md) + +
+ +快速开始 + +运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。 + +```python +from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig +import torch +from PIL import Image +from modelscope import dataset_snapshot_download + +# Download dataset +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="joyai_image/JoyAI-Image-Edit/*" +) + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +pipe = JoyAIImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config), + ], + processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +# Use first sample from dataset +dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" +prompt = "将裙子改为粉色" +edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") + +output = pipe( + prompt=prompt, + edit_images=[edit_images], + height=1024, + width=1024, + seed=0, + num_inference_steps=30, + cfg_scale=5.0, +) + +output.save("output_joyai_edit_low_vram.png") +``` + +
+ +
+ +示例代码 + +JoyAI-Image 的示例代码位于:[/examples/joyai_image/](/examples/joyai_image/) + +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)| + +
+ ## 创新成果 DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。 diff --git a/docs/en/Model_Details/JoyAI-Image.md b/docs/en/Model_Details/JoyAI-Image.md new file mode 100644 index 000000000..4cf36747b --- /dev/null +++ b/docs/en/Model_Details/JoyAI-Image.md @@ -0,0 +1,157 @@ +# JoyAI-Image + +JoyAI-Image is a unified multi-modal foundation model open-sourced by JD.com, supporting image understanding, text-to-image generation, and instruction-guided image editing. + +## Installation + +Before performing model inference and training, please install DiffSynth-Studio first. + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e . +``` + +For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md). + +> **Note**: JoyAI-Image requires a specific version of `transformers`; please install `transformers>=4.57.0,<4.58.0`. + +## Quick Start + +Running the following code will load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model for inference. VRAM management is enabled: the framework automatically controls parameter loading based on available VRAM, so the model can run with a minimum of 4GB VRAM. + +```python +from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig +import torch +from PIL import Image +from modelscope import dataset_snapshot_download + +# Download dataset +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="joyai_image/JoyAI-Image-Edit/*" +) + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +pipe = JoyAIImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config), + ], + processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +# Use first sample from dataset +dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" +prompt = "将裙子改为粉色" +edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") + +output = pipe( + prompt=prompt, + edit_images=[edit_images], + height=1024, + width=1024, + seed=0, + num_inference_steps=30, + cfg_scale=5.0, +) + +output.save("output_joyai_edit_low_vram.png") +``` + +## Model Overview + +|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation| +|-|-|-|-|-|-|-|
+|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)| + +## Model Inference + +The model is loaded via `JoyAIImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details. + +The input parameters for `JoyAIImagePipeline` inference include: + +* `prompt`: Text prompt describing the desired image editing effect. +* `negative_prompt`: Negative prompt specifying what should not appear in the result, defaults to empty string. +* `cfg_scale`: Classifier-free guidance scale factor, defaults to 5.0. Higher values make the output more closely follow the prompt. +* `input_image`: Input image (img2img mode), optional. +* `edit_images`: Image(s) to be edited, can be a single image or a list of images. +* `denoising_strength`: Denoising strength controlling how much the input image is repainted, defaults to 1.0. +* `height`: Height of the output image, defaults to 1024. Must be divisible by 16. +* `width`: Width of the output image, defaults to 1024. Must be divisible by 16. +* `seed`: Random seed for reproducibility. Set to `None` for random seed. +* `max_sequence_length`: Maximum sequence length for the text encoder, defaults to 4096. +* `num_inference_steps`: Number of inference steps, defaults to 30. More steps typically yield better quality. +* `tiled`: Whether to enable tiling for reduced VRAM usage, defaults to False. +* `tile_size`: Tile size, defaults to (30, 52). +* `tile_stride`: Tile stride, defaults to (15, 26). +* `shift`: Shift parameter for the scheduler, controlling the Flow Match scheduling curve, defaults to 4.0. +* `progress_bar_cmd`: Progress bar display mode, defaults to tqdm. + +## Model Training + +Models in the joyai_image series are trained uniformly via `examples/joyai_image/model_training/train.py`. The script parameters include: + +* General Training Parameters + * Dataset Configuration + * `--dataset_base_path`: Root directory of the dataset. + * `--dataset_metadata_path`: Path to the dataset metadata file. + * `--dataset_repeat`: Number of dataset repeats per epoch. + * `--dataset_num_workers`: Number of processes per DataLoader. + * `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`. + * Model Loading Configuration + * `--model_paths`: Paths to load models from, in JSON format. + * `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas. + * `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`. + * `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients. 
+ * Basic Training Configuration + * `--learning_rate`: Learning rate. + * `--num_epochs`: Number of epochs. + * `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`. + * `--find_unused_parameters`: Whether unused parameters exist in DDP training. + * `--weight_decay`: Weight decay magnitude. + * `--task`: Training task, defaults to `sft`. + * Output Configuration + * `--output_path`: Path to save the model. + * `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict. + * `--save_steps`: Interval in training steps to save the model. + * LoRA Configuration + * `--lora_base_model`: Which model to add LoRA to. + * `--lora_target_modules`: Which layers to add LoRA to. + * `--lora_rank`: Rank of LoRA. + * `--lora_checkpoint`: Path to LoRA checkpoint. + * `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training. + * `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`. + * Gradient Configuration + * `--use_gradient_checkpointing`: Whether to enable gradient checkpointing. + * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory. + * `--gradient_accumulation_steps`: Number of gradient accumulation steps. + * Resolution Configuration + * `--height`: Height of the image/video. Leave empty to enable dynamic resolution. + * `--width`: Width of the image/video. Leave empty to enable dynamic resolution. + * `--max_pixels`: Maximum pixel area; during dynamic resolution, images larger than this will be scaled down. + * `--num_frames`: Number of frames for video (video generation models only). +* JoyAI-Image Specific Parameters + * `--processor_path`: Path to the processor used to prepare text and image encoder inputs. + * `--initialize_model_on_cpu`: Whether to initialize models on CPU. By default, models are initialized on the accelerator device. + +The example dataset used by the recommended training scripts can be downloaded with: + +```shell +modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset +``` + +We provide recommended training scripts for each model; please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/). diff --git a/docs/en/index.rst b/docs/en/index.rst index 4b933cac2..eb428d6d9 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -29,6 +29,7 @@ Welcome to DiffSynth-Studio's Documentation Model_Details/Z-Image Model_Details/Anima Model_Details/LTX-2 + Model_Details/JoyAI-Image .. toctree:: :maxdepth: 2 diff --git a/docs/zh/Model_Details/JoyAI-Image.md b/docs/zh/Model_Details/JoyAI-Image.md new file mode 100644 index 000000000..edd45b3b0 --- /dev/null +++ b/docs/zh/Model_Details/JoyAI-Image.md @@ -0,0 +1,157 @@ +# JoyAI-Image + +JoyAI-Image 是京东开源的统一多模态基础模型,支持图像理解、文生图生成和指令引导的图像编辑。 + +## 安装 + +在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。 + +```shell +git clone https://github.com/modelscope/DiffSynth-Studio.git +cd DiffSynth-Studio +pip install -e .
+``` + +更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 + +> **注意**:JoyAI-Image 需要特定版本的 `transformers`,请安装 `transformers>=4.57.0,<4.58.0`。 + +## 快速开始 + +运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。 + +```python +from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig +import torch +from PIL import Image +from modelscope import dataset_snapshot_download + +# Download dataset +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="joyai_image/JoyAI-Image-Edit/*" +) + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +pipe = JoyAIImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config), + ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config), + ], + processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +# Use first sample from dataset +dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" +prompt = "将裙子改为粉色" +edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") + +output = pipe( + prompt=prompt, + edit_images=[edit_images], + height=1024, + width=1024, + seed=0, + num_inference_steps=30, + cfg_scale=5.0, +) + +output.save("output_joyai_edit_low_vram.png") +``` + +## 模型总览 + +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)| + +## 模型推理 + +模型通过 `JoyAIImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。 + +`JoyAIImagePipeline` 推理的输入参数包括: + +* `prompt`: 文本提示词,用于描述期望的图像编辑效果。 +* `negative_prompt`: 负向提示词,指定不希望出现在结果中的内容,默认为空字符串。 +* `cfg_scale`: 分类器自由引导的缩放系数,默认为 5.0。值越大,生成结果越贴近 prompt 描述。 +* `input_image`: 输入图像(img2img 模式),可选参数。 +* `edit_images`: 待编辑的图像,可以是单张或多张图片。 +* `denoising_strength`: 降噪强度,控制输入图像被重绘的程度,默认为 1.0。 +* `height`: 输出图像的高度,默认为 1024。需能被 16 整除。 +* `width`: 输出图像的宽度,默认为 1024。需能被 16 整除。 +* `seed`: 
随机种子,用于控制生成的可复现性。设为 `None` 时使用随机种子。 +* `max_sequence_length`: 文本编码器处理的最大序列长度,默认为 4096。 +* `num_inference_steps`: 推理步数,默认为 30。步数越多,生成质量通常越好。 +* `tiled`: 是否启用分块处理,用于降低显存占用,默认为 False。 +* `tile_size`: 分块大小,默认为 (30, 52)。 +* `tile_stride`: 分块步幅,默认为 (15, 26)。 +* `shift`: 调度器的 shift 参数,用于控制 Flow Match 的调度曲线,默认为 4.0。 +* `progress_bar_cmd`: 进度条显示方式,默认为 tqdm。 + +## 模型训练 + +joyai_image 系列模型统一通过 `examples/joyai_image/model_training/train.py` 进行训练,脚本的参数包括: + +* 通用训练参数 + * 数据集基础配置 + * `--dataset_base_path`: 数据集的根目录。 + * `--dataset_metadata_path`: 数据集的元数据文件路径。 + * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloader 的进程数量。 + * `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。 + * 模型加载配置 + * `--model_paths`: 要加载的模型路径。JSON 格式。 + * `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。 + * `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。 + * `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。 + * 训练基础配置 + * `--learning_rate`: 学习率。 + * `--num_epochs`: 轮数(Epoch)。 + * `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。 + * `--weight_decay`: 权重衰减大小。 + * `--task`: 训练任务,默认为 `sft`。 + * 输出配置 + * `--output_path`: 模型保存路径。 + * `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。 + * `--save_steps`: 保存模型的训练步数间隔。 + * LoRA 配置 + * `--lora_base_model`: LoRA 添加到哪个模型上。 + * `--lora_target_modules`: LoRA 添加到哪些层上。 + * `--lora_rank`: LoRA 的秩(Rank)。 + * `--lora_checkpoint`: LoRA 检查点的路径。 + * `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。 + * `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。 + * 梯度配置 + * `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。 + * `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。 + * `--gradient_accumulation_steps`: 梯度累积步数。 + * 分辨率配置 + * `--height`: 图像/视频的高度。留空启用动态分辨率。 + * `--width`: 图像/视频的宽度。留空启用动态分辨率。 + * `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。 + * `--num_frames`: 视频的帧数(仅视频生成模型)。 +* JoyAI-Image 专有参数 + * `--processor_path`: Processor 路径,用于处理文本和图像的编码器输入。 + * `--initialize_model_on_cpu`: 是否在 CPU 上初始化模型,默认在加速设备上初始化。 + +```shell +modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset +``` + +关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。 diff --git a/docs/zh/index.rst b/docs/zh/index.rst index 42256b3b5..ede4cabcb 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -29,6 +29,7 @@ Model_Details/Z-Image Model_Details/Anima Model_Details/LTX-2 + Model_Details/JoyAI-Image .. toctree:: :maxdepth: 2 From 18227e97274dd1da5c4299f4afbe597759f6369b Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 14 Apr 2026 13:15:52 +0800 Subject: [PATCH 7/8] update readme --- README.md | 274 +++++++++++++++++++++++++-------------------------- README_zh.md | 274 +++++++++++++++++++++++++-------------------------- 2 files changed, 274 insertions(+), 274 deletions(-) diff --git a/README.md b/README.md index d5742dafb..0b243c054 100644 --- a/README.md +++ b/README.md @@ -600,6 +600,143 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/) +#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md) + +
+
+
+Quick Start
+
+Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
+
+```python
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device='cuda',
+    model_configs=[
+        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
+        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+    ],
+    tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+image = pipe(
+    prompt="一只黑白相间的中华田园犬",
+    negative_prompt="",
+    height=1024,
+    width=1024,
+    seed=42,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("output.jpg")
+```
+
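+The `vram_limit` argument above is plain PyTorch arithmetic rather than a special API; a rough sketch of what the expression computes:
+
+```python
+import torch
+
+# torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the device
+free_bytes, total_bytes = torch.cuda.mem_get_info("cuda")
+vram_limit_gib = total_bytes / (1024 ** 3) - 0.5  # total VRAM in GiB, minus ~0.5 GiB of headroom
+```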
+ +
+
+Examples
+
+Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
+
+| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|-|-|-|-|-|-|-|
+|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
+|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
+
+ +#### JoyAI-Image: [/docs/en/Model_Details/JoyAI-Image.md](/docs/en/Model_Details/JoyAI-Image.md) + +
+
+
+Quick Start
+
+Running the following code will quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 4GB VRAM.
+
+```python
+from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
+import torch
+from PIL import Image
+from modelscope import dataset_snapshot_download
+
+# Download dataset
+dataset_snapshot_download(
+    dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
+    local_dir="data/diffsynth_example_dataset",
+    allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
+)
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+
+pipe = JoyAIImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
+        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
+        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
+    ],
+    processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+# Use first sample from dataset
+dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
+prompt = "将裙子改为粉色"
+edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
+
+output = pipe(
+    prompt=prompt,
+    edit_images=[edit_images],
+    height=1024,
+    width=1024,
+    seed=0,
+    num_inference_steps=30,
+    cfg_scale=5.0,
+)
+
+output.save("output_joyai_edit_low_vram.png")
+```
+
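+Note that `height` and `width` must be divisible by 16, and the edit image is resized to the target resolution internally. A hypothetical helper (`snap16` is not part of the library) for snapping arbitrary sizes down to a valid value:
+
+```python
+def snap16(x: int) -> int:
+    """Round a target dimension down to the nearest multiple of 16 (minimum 16)."""
+    return max(16, (x // 16) * 16)
+
+height, width = snap16(1000), snap16(768)  # -> (992, 768)
+```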
+ +
+
+Examples
+
+Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples/joyai_image/)
+
+| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
+|-|-|-|-|-|-|-|
+|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
+
+ ### Video Synthesis https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 @@ -879,143 +1016,6 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/) -#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md) - -
-
-
-Quick Start
-
-Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
-
-```python
-from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
-import torch
-
-vram_config = {
-    "offload_dtype": torch.bfloat16,
-    "offload_device": "cpu",
-    "onload_dtype": torch.bfloat16,
-    "onload_device": "cpu",
-    "preparing_dtype": torch.bfloat16,
-    "preparing_device": "cuda",
-    "computation_dtype": torch.bfloat16,
-    "computation_device": "cuda",
-}
-pipe = ErnieImagePipeline.from_pretrained(
-    torch_dtype=torch.bfloat16,
-    device='cuda',
-    model_configs=[
-        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
-        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
-        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
-    ],
-    tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
-    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
-)
-
-image = pipe(
-    prompt="一只黑白相间的中华田园犬",
-    negative_prompt="",
-    height=1024,
-    width=1024,
-    seed=42,
-    num_inference_steps=50,
-    cfg_scale=4.0,
-)
-image.save("output.jpg")
-```
-
- -
-
-Examples
-
-Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
-
-| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
-|-|-|-|-|-|-|-|
-|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
-|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
-
- -#### JoyAI-Image: [/docs/en/Model_Details/JoyAI-Image.md](/docs/en/Model_Details/JoyAI-Image.md) - -
-
-
-Quick Start
-
-Running the following code will quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 4GB VRAM.
-
-```python
-from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
-import torch
-from PIL import Image
-from modelscope import dataset_snapshot_download
-
-# Download dataset
-dataset_snapshot_download(
-    dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
-    local_dir="data/diffsynth_example_dataset",
-    allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
-)
-
-vram_config = {
-    "offload_dtype": torch.bfloat16,
-    "offload_device": "cpu",
-    "onload_dtype": torch.bfloat16,
-    "onload_device": "cpu",
-    "preparing_dtype": torch.bfloat16,
-    "preparing_device": "cuda",
-    "computation_dtype": torch.bfloat16,
-    "computation_device": "cuda",
-}
-
-pipe = JoyAIImagePipeline.from_pretrained(
-    torch_dtype=torch.bfloat16,
-    device="cuda",
-    model_configs=[
-        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
-        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
-        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
-    ],
-    processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
-    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
-)
-
-# Use first sample from dataset
-dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
-prompt = "将裙子改为粉色"
-edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
-
-output = pipe(
-    prompt=prompt,
-    edit_images=[edit_images],
-    height=1024,
-    width=1024,
-    seed=0,
-    num_inference_steps=30,
-    cfg_scale=5.0,
-)
-
-output.save("output_joyai_edit_low_vram.png")
-```
-
- -
-
-Examples
-
-Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples/joyai_image/)
-
-| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
-|-|-|-|-|-|-|-|
-|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
-
- ## Innovative Achievements DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements. diff --git a/README_zh.md b/README_zh.md index aed549e36..11f78d6a3 100644 --- a/README_zh.md +++ b/README_zh.md @@ -600,6 +600,143 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/) +#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md) + +
+
+
+快速开始
+
+运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
+
+```python
+from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
+import torch
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = ErnieImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device='cuda',
+    model_configs=[
+        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
+        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
+    ],
+    tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+image = pipe(
+    prompt="一只黑白相间的中华田园犬",
+    negative_prompt="",
+    height=1024,
+    width=1024,
+    seed=42,
+    num_inference_steps=50,
+    cfg_scale=4.0,
+)
+image.save("output.jpg")
+```
+
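+上面的 `vram_limit` 只是普通的 PyTorch 运算而非特殊接口,其计算过程大致如下(仅为示意):
+
+```python
+import torch
+
+# torch.cuda.mem_get_info 返回当前设备的 (free_bytes, total_bytes)
+free_bytes, total_bytes = torch.cuda.mem_get_info("cuda")
+vram_limit_gib = total_bytes / (1024 ** 3) - 0.5  # 总显存(GiB)减去约 0.5 GiB 余量
+```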
+ +
+
+示例代码
+
+ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
+
+| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
+|-|-|-|-|-|-|-|
+|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
+|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
+
+ +#### JoyAI-Image: [/docs/zh/Model_Details/JoyAI-Image.md](/docs/zh/Model_Details/JoyAI-Image.md) + +
+
+
+快速开始
+
+运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
+
+```python
+from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
+import torch
+from PIL import Image
+from modelscope import dataset_snapshot_download
+
+# Download dataset
+dataset_snapshot_download(
+    dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
+    local_dir="data/diffsynth_example_dataset",
+    allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
+)
+
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+
+pipe = JoyAIImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
+        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
+        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
+    ],
+    processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
+)
+
+# Use first sample from dataset
+dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
+prompt = "将裙子改为粉色"
+edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
+
+output = pipe(
+    prompt=prompt,
+    edit_images=[edit_images],
+    height=1024,
+    width=1024,
+    seed=0,
+    num_inference_steps=30,
+    cfg_scale=5.0,
+)
+
+output.save("output_joyai_edit_low_vram.png")
+```
+
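+注意 `height` 与 `width` 需能被 16 整除,编辑图像会在内部被缩放到目标分辨率。下面是一个假设性的辅助函数(`snap16` 并非库中接口),用于把任意尺寸向下取整到合法值:
+
+```python
+def snap16(x: int) -> int:
+    """将目标尺寸向下取整到 16 的倍数(最小为 16)。"""
+    return max(16, (x // 16) * 16)
+
+height, width = snap16(1000), snap16(768)  # -> (992, 768)
+```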
+ +
+
+示例代码
+
+JoyAI-Image 的示例代码位于:[/examples/joyai_image/](/examples/joyai_image/)
+
+|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|
+|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
+
+ ### 视频生成模型 https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314 @@ -879,143 +1016,6 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/) -#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md) - -
-
-
-快速开始
-
-运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
-
-```python
-from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
-import torch
-
-vram_config = {
-    "offload_dtype": torch.bfloat16,
-    "offload_device": "cpu",
-    "onload_dtype": torch.bfloat16,
-    "onload_device": "cpu",
-    "preparing_dtype": torch.bfloat16,
-    "preparing_device": "cuda",
-    "computation_dtype": torch.bfloat16,
-    "computation_device": "cuda",
-}
-pipe = ErnieImagePipeline.from_pretrained(
-    torch_dtype=torch.bfloat16,
-    device='cuda',
-    model_configs=[
-        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
-        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
-        ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
-    ],
-    tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
-    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
-)
-
-image = pipe(
-    prompt="一只黑白相间的中华田园犬",
-    negative_prompt="",
-    height=1024,
-    width=1024,
-    seed=42,
-    num_inference_steps=50,
-    cfg_scale=4.0,
-)
-image.save("output.jpg")
-```
-
- -
-
-示例代码
-
-ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
-
-| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
-|-|-|-|-|-|-|-|
-|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
-|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
-
- -#### JoyAI-Image: [/docs/zh/Model_Details/JoyAI-Image.md](/docs/zh/Model_Details/JoyAI-Image.md) - -
-
-
-快速开始
-
-运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
-
-```python
-from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
-import torch
-from PIL import Image
-from modelscope import dataset_snapshot_download
-
-# Download dataset
-dataset_snapshot_download(
-    dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
-    local_dir="data/diffsynth_example_dataset",
-    allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
-)
-
-vram_config = {
-    "offload_dtype": torch.bfloat16,
-    "offload_device": "cpu",
-    "onload_dtype": torch.bfloat16,
-    "onload_device": "cpu",
-    "preparing_dtype": torch.bfloat16,
-    "preparing_device": "cuda",
-    "computation_dtype": torch.bfloat16,
-    "computation_device": "cuda",
-}
-
-pipe = JoyAIImagePipeline.from_pretrained(
-    torch_dtype=torch.bfloat16,
-    device="cuda",
-    model_configs=[
-        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
-        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
-        ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
-    ],
-    processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
-    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
-)
-
-# Use first sample from dataset
-dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
-prompt = "将裙子改为粉色"
-edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
-
-output = pipe(
-    prompt=prompt,
-    edit_images=[edit_images],
-    height=1024,
-    width=1024,
-    seed=0,
-    num_inference_steps=30,
-    cfg_scale=5.0,
-)
-
-output.save("output_joyai_edit_low_vram.png")
-```
-
- -
-
-示例代码
-
-JoyAI-Image 的示例代码位于:[/examples/joyai_image/](/examples/joyai_image/)
-
-|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
-|-|-|-|-|-|-|-|
-|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
-
- ## 创新成果 DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。 From 8ddd3a67db21039b6c03b0e55eee09ad5386f538 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 14 Apr 2026 17:38:20 +0800 Subject: [PATCH 8/8] pr review --- README.md | 4 +- README_zh.md | 4 +- diffsynth/configs/model_configs.py | 4 +- diffsynth/models/joyai_image_dit.py | 69 +++++++++---------- diffsynth/models/joyai_image_text_encoder.py | 9 ++- diffsynth/pipelines/joyai_image.py | 55 ++++++--------- .../state_dict_converters/joyai_image_dit.py | 24 ------- docs/en/Model_Details/JoyAI-Image.md | 9 +-- docs/zh/Model_Details/JoyAI-Image.md | 9 +-- .../model_inference/JoyAI-Image-Edit.py | 4 +- .../JoyAI-Image-Edit.py | 4 +- .../model_training/full/JoyAI-Image-Edit.sh | 8 +-- .../model_training/lora/JoyAI-Image-Edit.sh | 8 +-- .../validate_full/JoyAI-Image-Edit.py | 6 +- .../validate_lora/JoyAI-Image-Edit.py | 6 +- 15 files changed, 87 insertions(+), 136 deletions(-) delete mode 100644 diffsynth/utils/state_dict_converters/joyai_image_dit.py diff --git a/README.md b/README.md index 0b243c054..ff469777c 100644 --- a/README.md +++ b/README.md @@ -708,11 +708,11 @@ pipe = JoyAIImagePipeline.from_pretrained( # Use first sample from dataset dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" prompt = "将裙子改为粉色" -edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") +edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") output = pipe( prompt=prompt, - edit_images=[edit_images], + edit_image=edit_image, height=1024, width=1024, seed=0, diff --git a/README_zh.md b/README_zh.md index 11f78d6a3..77dfaf0f9 100644 --- a/README_zh.md +++ b/README_zh.md @@ -708,11 +708,11 @@ pipe = JoyAIImagePipeline.from_pretrained( # Use first sample from dataset dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" prompt = "将裙子改为粉色" -edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") +edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") output = pipe( prompt=prompt, - edit_images=[edit_images], + edit_image=edit_image, height=1024, width=1024, seed=0, diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 7681a5f30..5fc95c3ed 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -905,9 +905,7 @@ # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth") "model_hash": "56592ddfd7d0249d3aa527d24161a863", "model_name": "joyai_image_dit", - "model_class": "diffsynth.models.joyai_image_dit.Transformer3DModel", - "extra_kwargs": {"patch_size": [1, 2, 2], "in_channels": 16, "out_channels": 16, "hidden_size": 4096, "heads_num": 32, "text_states_dim": 4096, "mlp_width_ratio": 4.0, "mm_double_blocks_depth": 40, "rope_dim_list": [16, 56, 56], "rope_type": "rope", "dit_modulation_type": "wanx", "theta": 10000}, - "state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_dit.JoyAIImageDiTStateDictConverter", + "model_class": "diffsynth.models.joyai_image_dit.JoyAIImageDiT", }, { # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors") diff --git a/diffsynth/models/joyai_image_dit.py b/diffsynth/models/joyai_image_dit.py index b025b9b04..ba88f6060 100644 --- a/diffsynth/models/joyai_image_dit.py +++ b/diffsynth/models/joyai_image_dit.py @@ -1,5 +1,5 @@ import math -from typing import 
List, Optional, Tuple, Union, Dict +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -108,6 +108,18 @@ def forward(self, caption): return hidden_states +class GELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = F.gelu(hidden_states, approximate=self.approximate) + return hidden_states + + class FeedForward(nn.Module): def __init__( self, @@ -124,47 +136,28 @@ def __init__( if inner_dim is None: inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - if activation_fn == "gelu-approximate": - self.proj = nn.Linear(dim, inner_dim, bias=bias) - self.act = lambda x: F.gelu(x, approximate="tanh") - elif activation_fn == "gelu": - self.proj = nn.Linear(dim, inner_dim, bias=bias) - self.act = F.gelu + + # Build activation + projection matching diffusers pattern + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) else: - self.proj = nn.Linear(dim, inner_dim, bias=bias) - self.act = F.gelu - self.drop = nn.Dropout(dropout) - self.out_proj = nn.Linear(inner_dim, dim_out, bias=bias) + act_fn = GELU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + self.net.append(act_fn) + self.net.append(nn.Dropout(dropout)) + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) if final_dropout: - self.final_drop = nn.Dropout(dropout) - else: - self.final_drop = None + self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: - hidden_states = self.proj(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.drop(hidden_states) - hidden_states = self.out_proj(hidden_states) - if self.final_drop is not None: - hidden_states = self.final_drop(hidden_states) + for module in self.net: + hidden_states = module(hidden_states) return hidden_states - -def get_cu_seqlens(text_mask, img_len): - batch_size = text_mask.shape[0] - text_len = text_mask.sum(dim=1) - max_len = text_mask.shape[1] + img_len - cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") - for i in range(batch_size): - s = text_len[i] + img_len - s1 = i * max_len + s - s2 = (i + 1) * max_len - cu_seqlens[2 * i + 1] = s1 - cu_seqlens[2 * i + 2] = s2 - return cu_seqlens - - def _to_tuple(x, dim=2): if isinstance(x, int): return (x,) * dim @@ -495,14 +488,14 @@ def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor): return temb, timestep_proj, encoder_hidden_states -class Transformer3DModel(nn.Module): +class JoyAIImageDiT(nn.Module): _supports_gradient_checkpointing = True def __init__( self, patch_size: list = [1, 2, 2], in_channels: int = 16, - out_channels: int = None, + out_channels: int = 16, hidden_size: int = 4096, heads_num: int = 32, text_states_dim: int = 4096, @@ -513,7 +506,7 @@ def __init__( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, dit_modulation_type: str = "wanx", - theta: int = 256, + theta: int = 10000, ): super().__init__() self.out_channels = out_channels or in_channels diff --git a/diffsynth/models/joyai_image_text_encoder.py b/diffsynth/models/joyai_image_text_encoder.py index 
71e606340..292c5f6bd 100644 --- a/diffsynth/models/joyai_image_text_encoder.py +++ b/diffsynth/models/joyai_image_text_encoder.py @@ -67,11 +67,16 @@ def forward( image_grid_thw: Optional[torch.LongTensor] = None, **kwargs, ): - outputs = self.model( + pre_norm_output = [None] + def hook_fn(module, args, kwargs_output=None): + pre_norm_output[0] = args[0] + self.model.model.language_model.norm.register_forward_hook(hook_fn) + _ = self.model( input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw, attention_mask=attention_mask, + output_hidden_states=True, **kwargs, ) - return outputs + return pre_norm_output[0] \ No newline at end of file diff --git a/diffsynth/pipelines/joyai_image.py b/diffsynth/pipelines/joyai_image.py index 4d498ea4f..734a8e425 100644 --- a/diffsynth/pipelines/joyai_image.py +++ b/diffsynth/pipelines/joyai_image.py @@ -1,6 +1,6 @@ import torch from PIL import Image -from typing import Union, List, Optional +from typing import Union, Optional from tqdm import tqdm from einops import rearrange @@ -8,14 +8,11 @@ from ..diffusion import FlowMatchScheduler from ..core import ModelConfig from ..diffusion.base_pipeline import BasePipeline, PipelineUnit -from ..models.joyai_image_dit import Transformer3DModel +from ..models.joyai_image_dit import JoyAIImageDiT from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder from ..models.wan_video_vae import WanVideoVAE class JoyAIImagePipeline(BasePipeline): - """ - Pipeline for JoyAI-Image model. - """ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): super().__init__( @@ -24,7 +21,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): ) self.scheduler = FlowMatchScheduler("Wan") self.text_encoder: JoyAIImageTextEncoder = None - self.dit: Transformer3DModel = None + self.dit: JoyAIImageDiT = None self.vae: WanVideoVAE = None self.processor = None self.in_iteration_models = ("dit",) @@ -72,8 +69,7 @@ def __call__( negative_prompt: str = "", cfg_scale: float = 5.0, # Image - input_image: Image.Image = None, - edit_images: Union[Image.Image, List[Image.Image]] = None, + edit_image: Image.Image = None, denoising_strength: float = 1.0, # Shape height: int = 1024, @@ -100,7 +96,7 @@ def __call__( inputs_nega = {"negative_prompt": negative_prompt} inputs_shared = { "cfg_scale": cfg_scale, - "input_image": input_image, "edit_images": edit_images, + "edit_image": edit_image, "denoising_strength": denoising_strength, "height": height, "width": width, "seed": seed, "max_sequence_length": max_sequence_length, @@ -135,7 +131,6 @@ def __call__( class JoyAIImageUnit_ShapeChecker(PipelineUnit): - """Validates height/width divisible by 16.""" def __init__(self): super().__init__( input_params=("height", "width"), @@ -162,24 +157,24 @@ def __init__(self): seperate_cfg=True, input_params_posi={"prompt": "prompt", "positive": "positive"}, input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, - input_params=("edit_images", "max_sequence_length"), + input_params=("edit_image", "max_sequence_length"), output_params=("prompt_embeds", "prompt_embeds_mask"), onload_model_names=("joyai_image_text_encoder",), ) - def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_images, max_sequence_length): + def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_image, max_sequence_length): pipe.load_models_to_device(self.onload_model_names) - has_image = edit_images is not None + has_image = edit_image is not None if has_image: - 
prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_images, max_sequence_length) + prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_image, max_sequence_length) else: prompt_embeds, prompt_embeds_mask = self._encode_text_only(pipe, prompt, max_sequence_length) return {"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask} - def _encode_with_image(self, pipe, prompt, edit_images, max_sequence_length): + def _encode_with_image(self, pipe, prompt, edit_image, max_sequence_length): template = self.prompt_template_encode['multiple_images'] drop_idx = self.prompt_template_encode_start_idx['multiple_images'] @@ -187,9 +182,8 @@ def _encode_with_image(self, pipe, prompt, edit_images, max_sequence_length): prompt = f"<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n" prompt = prompt.replace('\n', '<|vision_start|><|image_pad|><|vision_end|>') prompt = template.format(prompt) - inputs = pipe.processor(text=[prompt], images=edit_images, padding=True, return_tensors="pt").to(pipe.device) - encoder_hidden_states = pipe.text_encoder(**inputs, output_hidden_states=True) - last_hidden_states = encoder_hidden_states.hidden_states[-1] + inputs = pipe.processor(text=[prompt], images=[edit_image], padding=True, return_tensors="pt").to(pipe.device) + last_hidden_states = pipe.text_encoder(**inputs) prompt_embeds = last_hidden_states[:, drop_idx:] prompt_embeds_mask = inputs['attention_mask'][:, drop_idx:] @@ -202,35 +196,29 @@ def _encode_with_image(self, pipe, prompt, edit_images, max_sequence_length): def _encode_text_only(self, pipe, prompt, max_sequence_length): # TODO: may support for text-only encoding in the future. - raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_images for now.") + raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_image for now.") return prompt_embeds, encoder_attention_mask class JoyAIImageUnit_EditImageEmbedder(PipelineUnit): - """ - Encodes edit images into reference latents using VAE. - """ def __init__(self): super().__init__( - input_params=("edit_images", "tiled", "tile_size", "tile_stride", "edit_image_basesize", "height", "width"), + input_params=("edit_image", "tiled", "tile_size", "tile_stride", "height", "width"), output_params=("ref_latents", "num_items", "is_multi_item"), onload_model_names=("wan_video_vae",), ) - def process(self, pipe: "JoyAIImagePipeline", edit_images, tiled, tile_size, tile_stride, edit_image_basesize, height, width): - if edit_images is None: + def process(self, pipe: "JoyAIImagePipeline", edit_image, tiled, tile_size, tile_stride, height, width): + if edit_image is None: return {} - if isinstance(edit_images, Image.Image): - edit_images = [edit_images] pipe.load_models_to_device(self.onload_model_names) - assert len(edit_images) == 1, "Currently only supports single edit image for reference. Multiple edit images will be supported in the future." 
- # Resize edit images to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents - edit_images = [img.resize((width, height), Image.LANCZOS) for img in edit_images] - images = [pipe.preprocess_image(img).transpose(0, 1) for img in edit_images] + # Resize edit image to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents + edit_image = edit_image.resize((width, height), Image.LANCZOS) + images = [pipe.preprocess_image(edit_image).transpose(0, 1)] latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(edit_images))).to(device=pipe.device, dtype=pipe.torch_dtype) + ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=1).to(device=pipe.device, dtype=pipe.torch_dtype) - return {"ref_latents": ref_vae, "edit_images": edit_images} + return {"ref_latents": ref_vae, "edit_image": edit_image} class JoyAIImageUnit_NoiseInitializer(PipelineUnit): @@ -239,6 +227,7 @@ def __init__(self): input_params=("seed", "height", "width", "rand_device"), output_params=("noise"), ) + def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device): latent_h = height // pipe.vae.upsampling_factor latent_w = width // pipe.vae.upsampling_factor diff --git a/diffsynth/utils/state_dict_converters/joyai_image_dit.py b/diffsynth/utils/state_dict_converters/joyai_image_dit.py deleted file mode 100644 index 77921f26b..000000000 --- a/diffsynth/utils/state_dict_converters/joyai_image_dit.py +++ /dev/null @@ -1,24 +0,0 @@ -def JoyAIImageDiTStateDictConverter(state_dict): - """Convert JoyAI-Image DiT checkpoint to model state dict. - - Handle: - 1. "model." prefix stripping from checkpoint - 2. FeedForward key mapping: diffusers uses "net.0.proj" / "net.2" - while DiffSynth uses "proj" / "out_proj" - """ - state_dict_ = {} - for name in state_dict: - if name.startswith("model."): - name = name[len("model."):] - - # Map diffusers FeedForward keys to DiffSynth keys - # Pattern: double_blocks.N.{img_mlp|txt_mlp}.net.0.proj.* -> double_blocks.N.{img_mlp|txt_mlp}.proj.* - new_name = name - if ".net.0.proj." in name: - new_name = name.replace(".net.0.proj.", ".proj.") - elif ".net.2." in name: - new_name = name.replace(".net.2.", ".out_proj.") - - state_dict_[new_name] = state_dict[name] - - return state_dict_ diff --git a/docs/en/Model_Details/JoyAI-Image.md b/docs/en/Model_Details/JoyAI-Image.md index 4cf36747b..bf34a064e 100644 --- a/docs/en/Model_Details/JoyAI-Image.md +++ b/docs/en/Model_Details/JoyAI-Image.md @@ -14,8 +14,6 @@ pip install -e . For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md). -> **Note**: JoyAI-Image requires a specific version of `transformers`, please install `transformers>=4.57.0,<4.58.0`. - ## Quick Start Running the following code will load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 4GB VRAM. 
@@ -59,11 +57,11 @@ pipe = JoyAIImagePipeline.from_pretrained( # Use first sample from dataset dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" prompt = "将裙子改为粉色" -edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") +edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") output = pipe( prompt=prompt, - edit_images=[edit_images], + edit_image=edit_image, height=1024, width=1024, seed=0, @@ -89,8 +87,7 @@ The input parameters for `JoyAIImagePipeline` inference include: * `prompt`: Text prompt describing the desired image editing effect. * `negative_prompt`: Negative prompt specifying what should not appear in the result, defaults to empty string. * `cfg_scale`: Classifier-free guidance scale factor, defaults to 5.0. Higher values make the output more closely follow the prompt. -* `input_image`: Input image (img2img mode), optional. -* `edit_images`: Image(s) to be edited, can be a single image or a list of images. +* `edit_image`: Image to be edited. * `denoising_strength`: Denoising strength controlling how much the input image is repainted, defaults to 1.0. * `height`: Height of the output image, defaults to 1024. Must be divisible by 16. * `width`: Width of the output image, defaults to 1024. Must be divisible by 16. diff --git a/docs/zh/Model_Details/JoyAI-Image.md b/docs/zh/Model_Details/JoyAI-Image.md index edd45b3b0..8904862a3 100644 --- a/docs/zh/Model_Details/JoyAI-Image.md +++ b/docs/zh/Model_Details/JoyAI-Image.md @@ -14,8 +14,6 @@ pip install -e . 更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。 -> **注意**:JoyAI-Image 需要特定版本的 `transformers`,请安装 `transformers>=4.57.0,<4.58.0`。 - ## 快速开始 运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。 @@ -59,11 +57,11 @@ pipe = JoyAIImagePipeline.from_pretrained( # Use first sample from dataset dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" prompt = "将裙子改为粉色" -edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") +edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") output = pipe( prompt=prompt, - edit_images=[edit_images], + edit_image=edit_image, height=1024, width=1024, seed=0, @@ -89,8 +87,7 @@ output.save("output_joyai_edit_low_vram.png") * `prompt`: 文本提示词,用于描述期望的图像编辑效果。 * `negative_prompt`: 负向提示词,指定不希望出现在结果中的内容,默认为空字符串。 * `cfg_scale`: 分类器自由引导的缩放系数,默认为 5.0。值越大,生成结果越贴近 prompt 描述。 -* `input_image`: 输入图像(img2img 模式),可选参数。 -* `edit_images`: 待编辑的图像,可以是单张或多张图片。 +* `edit_image`: 待编辑的单张图像。 * `denoising_strength`: 降噪强度,控制输入图像被重绘的程度,默认为 1.0。 * `height`: 输出图像的高度,默认为 1024。需能被 16 整除。 * `width`: 输出图像的宽度,默认为 1024。需能被 16 整除。 diff --git a/examples/joyai_image/model_inference/JoyAI-Image-Edit.py b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py index 152068f4c..a8aa0d7dc 100644 --- a/examples/joyai_image/model_inference/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_inference/JoyAI-Image-Edit.py @@ -24,11 +24,11 @@ # Use first sample from dataset dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" prompt = "将裙子改为粉色" -edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") +edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") output = pipe( prompt=prompt, - edit_images=[edit_images], + edit_image=edit_image, height=1024, width=1024, seed=1, diff --git 
a/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py b/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py index a1f74f2d1..d6a67bcc9 100644 --- a/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py @@ -36,11 +36,11 @@ # Use first sample from dataset dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" prompt = "将裙子改为粉色" -edit_images = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") +edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB") output = pipe( prompt=prompt, - edit_images=[edit_images], + edit_image=edit_image, height=1024, width=1024, seed=0, diff --git a/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh b/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh index 385bb8e23..1af9f1829 100644 --- a/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh +++ b/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh @@ -13,8 +13,8 @@ accelerate launch examples/joyai_image/model_training/train.py \ --output_path "./models/train/JoyAI-Image-Edit-full-cache" \ --use_gradient_checkpointing \ --find_unused_parameters \ - --data_file_keys "image,edit_images" \ - --extra_inputs "edit_images" \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ --task "sft:data_process" accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_zero3.yaml \ @@ -30,6 +30,6 @@ accelerate launch --config_file examples/joyai_image/model_training/full/acceler --trainable_models "dit" \ --use_gradient_checkpointing \ --find_unused_parameters \ - --data_file_keys "image,edit_images" \ - --extra_inputs "edit_images" \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ --task "sft:train" diff --git a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh index a3c74d652..b9a5fd431 100644 --- a/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh +++ b/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh @@ -16,8 +16,8 @@ accelerate launch examples/joyai_image/model_training/train.py \ --lora_rank 32 \ --use_gradient_checkpointing \ --find_unused_parameters \ - --data_file_keys "image,edit_images" \ - --extra_inputs "edit_images" \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ --task "sft:data_process" accelerate launch examples/joyai_image/model_training/train.py \ @@ -34,6 +34,6 @@ accelerate launch examples/joyai_image/model_training/train.py \ --lora_rank 32 \ --use_gradient_checkpointing \ --find_unused_parameters \ - --data_file_keys "image,edit_images" \ - --extra_inputs "edit_images" \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ --task "sft:train" diff --git a/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py b/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py index 8df50096b..f245deb3f 100644 --- a/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py @@ -14,17 +14,15 @@ processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), ) -# Load full training weights state_dict = load_state_dict("models/train/JoyAI-Image-Edit_full/epoch-1.safetensors") pipe.dit.load_state_dict(state_dict) -# Use training dataset 
prompt and edit_images prompt = "将裙子改为粉色" -edit_images = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB") +edit_image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB") image = pipe( prompt=prompt, - edit_images=[edit_images], + edit_image=edit_image, height=1024, width=1024, seed=0, diff --git a/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py index 4369b5627..16f06aa64 100644 --- a/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py +++ b/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py @@ -13,16 +13,14 @@ processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"), ) -# Load LoRA weights from dual-stage training output pipe.load_lora(pipe.dit, "models/train/JoyAI-Image-Edit-lora/epoch-4.safetensors") -# Use training dataset prompt and edit_images prompt = "将裙子改为粉色" -edit_images = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB") +edit_image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB") image = pipe( prompt=prompt, - edit_images=[edit_images], + edit_image=edit_image, height=1024, width=1024, seed=0,