diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4c1c220d58e8..fa652c071ae2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -353,6 +353,8 @@ title: JoyImageEditTransformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel + - local: api/models/lens_transformer2d + title: LensTransformer2DModel - local: api/models/longcat_image_transformer2d title: LongCatImageTransformer2DModel - local: api/models/ltx2_video_transformer3d @@ -553,6 +555,8 @@ title: Kandinsky 5.0 Image - local: api/pipelines/kolors title: Kolors + - local: api/pipelines/lens + title: Lens - local: api/pipelines/latent_consistency_models title: Latent Consistency Models - local: api/pipelines/latent_diffusion diff --git a/docs/source/en/api/models/lens_transformer2d.md b/docs/source/en/api/models/lens_transformer2d.md new file mode 100644 index 000000000000..d19a3ba3befa --- /dev/null +++ b/docs/source/en/api/models/lens_transformer2d.md @@ -0,0 +1,23 @@ + + +# LensTransformer2DModel + +A Transformer model for image-like data from [Lens](https://huggingface.co/microsoft/Lens). + +## LensTransformer2DModel + +[[autodoc]] LensTransformer2DModel + +## LensTransformer2DModelOutput + +[[autodoc]] models.transformers.transformer_lens.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/lens.md b/docs/source/en/api/pipelines/lens.md new file mode 100644 index 000000000000..ac402be829c7 --- /dev/null +++ b/docs/source/en/api/pipelines/lens.md @@ -0,0 +1,52 @@ + + +# Lens + +
+
+ +Lens is a 3.8B-parameter foundational text-to-image model designed for efficient training and fast high-resolution generation. It combines dense-caption pre-training, mixed-resolution learning, GPT-OSS multi-layer text features, and the FLUX.2 semantic VAE to reach competitive quality with substantially less training compute than larger T2I models. For more details, please refer to the [model card](https://huggingface.co/microsoft/Lens). + +The abstract from the paper is: + +*Lens is a 3.8B-parameter foundational text-to-image model designed for efficient training and fast high-resolution generation. It combines dense-caption pre-training, mixed-resolution learning, GPT-OSS multi-layer text features, and the FLUX.2 semantic VAE to reach competitive quality with substantially less training compute than larger T2I models.* + +## Usage Example + +```python +import torch +from diffusers import LensPipeline + +pipe = LensPipeline.from_pretrained("microsoft/Lens", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +image = pipe( + prompt="A cat holding a sign that says hello world", + height=1440, + width=1440, + num_inference_steps=20, + guidance_scale=5.0, +).images[0] +image.save("lens.png") +``` + +## LensPipeline + +[[autodoc]] LensPipeline + +- all +- __call__ + +## LensPipelineOutput + +[[autodoc]] pipelines.lens.pipeline_output.LensPipelineOutput diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index 5e89f26fce54..0ba4a2854165 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -50,6 +50,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting | | [Kandinsky 3](kandinsky3) | text2image, image2image | | [Kolors](kolors) | text2image | +| [Lens](lens) | text2image | | [Latent Consistency Models](latent_consistency_models) | text2image | | [Latent Diffusion](latent_diffusion) | text2image, super-resolution | | [Latte](latte) | text2image | diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d75e2d9a5010..4bc7f0ff4eb1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -260,6 +260,7 @@ "JoyImageEditTransformer3DModel", "Kandinsky3UNet", "Kandinsky5Transformer3DModel", + "LensTransformer2DModel", "LatteTransformer3DModel", "LongCatAudioDiTTransformer", "LongCatAudioDiTVae", @@ -621,6 +622,7 @@ "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", "LattePipeline", + "LensPipeline", "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", @@ -1096,6 +1098,7 @@ JoyImageEditTransformer3DModel, Kandinsky3UNet, Kandinsky5Transformer3DModel, + LensTransformer2DModel, LatteTransformer3DModel, LongCatAudioDiTTransformer, LongCatAudioDiTVae, @@ -1432,6 +1435,7 @@ LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, LattePipeline, + LensPipeline, LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e661dd1002a2..78c0361bf401 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -119,6 +119,7 @@ _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] + _import_structure["transformers.transformer_lens"] = ["LensTransformer2DModel"] _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"] _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] @@ -248,6 +249,7 @@ HunyuanVideoTransformer3DModel, JoyImageEditTransformer3DModel, Kandinsky5Transformer3DModel, + LensTransformer2DModel, LatteTransformer3DModel, LongCatAudioDiTTransformer, LongCatImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 98558c22d7df..228eb8b9d51d 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -41,6 +41,7 @@ from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_joyimage import JoyImageEditTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel + from .transformer_lens import LensTransformer2DModel from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_lens.py b/src/diffusers/models/transformers/transformer_lens.py new file mode 100644 index 000000000000..3b7eee5996eb --- /dev/null +++ b/src/diffusers/models/transformers/transformer_lens.py @@ -0,0 +1,622 @@ +# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import apply_lora_scale, logging +from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, RMSNorm + + +logger = logging.get_logger(__name__) + + +def apply_rotary_emb_lens( + x: torch.Tensor, + freqs_cis: torch.Tensor, +) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3) + return x_out.type_as(x) + + +class GateMLP(nn.Module): + def __init__(self, dim: int, hidden_dim: int) -> None: + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class LensTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim: int) -> None: + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: + proj = self.time_proj(timestep) + return self.timestep_embedder(proj.to(dtype=hidden_states.dtype)) + + +class LensEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: list[int], scale_rope: bool = False) -> None: + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.scale_rope = scale_rope + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.register_buffer( + "pos_freqs", + torch.cat([self._rope_params(pos_index, d, theta) for d in axes_dim], dim=1), + persistent=False, + ) + self.register_buffer( + "neg_freqs", + torch.cat([self._rope_params(neg_index, d, theta) for d in axes_dim], dim=1), + persistent=False, + ) + + @staticmethod + def _rope_params(index: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor: + assert dim % 2 == 0 + freqs = torch.outer( + index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float().div(dim)) + ) + return torch.polar(torch.ones_like(freqs), freqs) + + @lru_cache_unless_export(maxsize=None) + def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + return self.pos_freqs.to(device), self.neg_freqs.to(device) + + def forward( + self, + video_fhw: list[tuple[int, int, int]] | tuple[int, int, int], + txt_seq_len: int | torch.Tensor, + device: torch.device = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if isinstance(video_fhw, list) and len(video_fhw) > 1: + first_fhw = video_fhw[0] + if not all(fhw == first_fhw for fhw in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in LensEmbedRope. " + "All images in the batch should have the same dimensions (frame, height, width). " + f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + video_freq = self._compute_video_freqs(frame, height, width, idx, device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_txt_seq_len_int = int(txt_seq_len) + pos_freqs_device, _ = self._get_device_freqs(device) + txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @lru_cache_unless_export(maxsize=128) + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None + ) -> torch.Tensor: + seq_lens = frame * height * width + pos_freqs, neg_freqs = ( + self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs) + ) + + freqs_pos = pos_freqs.split([d // 2 for d in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([d // 2 for d in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat( + [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0 + ) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class LensDoubleStreamAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: "LensJointAttention", + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.FloatTensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + bsz, seq_img, _ = hidden_states.shape + seq_txt = encoder_hidden_states.shape[1] + + img_qkv = attn.img_qkv(hidden_states).view(bsz, seq_img, 3, attn.heads, attn.dim_head) + txt_qkv = attn.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, attn.heads, attn.dim_head) + img_q, img_k, img_v = img_qkv.unbind(dim=2) + txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2) + + if attn.norm_q is not None: + img_q = attn.norm_q(img_q) + if attn.norm_k is not None: + img_k = attn.norm_k(img_k) + if attn.norm_added_q is not None: + txt_q = attn.norm_added_q(txt_q) + if attn.norm_added_k is not None: + txt_k = attn.norm_added_k(txt_k) + + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + if img_freqs.shape[0] < seq_img: + raise ValueError( + f"Image RoPE length {img_freqs.shape[0]} is shorter than image sequence length {seq_img}." + ) + img_freqs = img_freqs[:seq_img] + img_q = apply_rotary_emb_lens(img_q, img_freqs) + img_k = apply_rotary_emb_lens(img_k, img_freqs) + if seq_txt > 0: + if txt_freqs.shape[0] < seq_txt: + raise ValueError( + f"Text RoPE length {txt_freqs.shape[0]} is shorter than text sequence length {seq_txt}." + ) + txt_freqs = txt_freqs[:seq_txt] + txt_q = apply_rotary_emb_lens(txt_q, txt_freqs) + txt_k = apply_rotary_emb_lens(txt_k, txt_freqs) + + joint_query = torch.cat([img_q, txt_q], dim=1) + joint_key = torch.cat([img_k, txt_k], dim=1) + joint_value = torch.cat([img_v, txt_v], dim=1) + + joint_hidden_states = dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + img_attn_output = joint_hidden_states[:, :seq_img, :] + txt_attn_output = joint_hidden_states[:, seq_img:, :] + + img_attn_output = attn.to_out[0](img_attn_output.contiguous()) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) + + txt_attn_output = attn.to_add_out(txt_attn_output.contiguous()) + + return img_attn_output, txt_attn_output + + +class LensJointAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = LensDoubleStreamAttnProcessor + _available_processors = [LensDoubleStreamAttnProcessor] + + def __init__( + self, + query_dim: int, + added_kv_proj_dim: int, + dim_head: int = 64, + heads: int = 8, + out_dim: int | None = None, + eps: float = 1e-5, + ) -> None: + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.heads = self.inner_dim // dim_head + self.dim_head = dim_head + self.out_dim = out_dim if out_dim is not None else query_dim + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + + self.img_qkv = nn.Linear(query_dim, 3 * self.inner_dim, bias=True) + self.txt_qkv = nn.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True) + + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, self.out_dim, bias=True), nn.Identity()]) + self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=True) + + self.set_processor(LensDoubleStreamAttnProcessor()) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.processor( + self, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + ) + + +@maybe_allow_in_graph +class LensTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + rms_norm: bool = False, + gate_mlp: bool = False, + ) -> None: + super().__init__() + self.attn = LensJointAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + eps=eps, + ) + + norm_cls = (lambda d: RMSNorm(d, eps=eps)) if rms_norm else ( + lambda d: nn.LayerNorm(d, elementwise_affine=False, eps=eps) + ) + if gate_mlp: + mlp_cls = lambda: GateMLP(dim, int(dim / 3 * 8)) + else: + mlp_cls = lambda: FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.img_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)) + self.img_norm1 = norm_cls(dim) + self.img_norm2 = norm_cls(dim) + self.img_mlp = mlp_cls() + + self.txt_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)) + self.txt_norm1 = norm_cls(dim) + self.txt_norm2 = norm_cls(dim) + self.txt_mlp = mlp_cls() + + @staticmethod + def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1) + txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1) + + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) + txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) + + joint_attention_kwargs = joint_attention_kwargs or {} + block_attention_mask = joint_attention_kwargs.pop("attention_mask", attention_mask) + img_attn, txt_attn = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + image_rotary_emb=image_rotary_emb, + attention_mask=block_attention_mask, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + img_gate1 * img_attn + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn + + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) + hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2) + + txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2) + + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class LensTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + """ + The Transformer model introduced in Lens. + + Reference: https://huggingface.co/microsoft/Lens + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `32`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `48`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `64`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + inner_dim (`int`, defaults to `1536`): + The inner dimension of the transformer. If not specified, it defaults to + `num_attention_heads * attention_head_dim`. + enc_hidden_dim (`int`, defaults to `2880`): + The hidden dimension of the text encoder outputs. + axes_dims_rope (`tuple[int, int, int]`, defaults to `(8, 28, 28)`): + The dimensions to use for the rotary positional embeddings for frame, height, and width axes. + gate_mlp (`bool`, defaults to `True`): + Whether to use a gated MLP (SwiGLU) in the transformer blocks. + rms_norm (`bool`, defaults to `True`): + Whether to use RMS normalization instead of LayerNorm in the transformer blocks. + multi_layer_encoder_feature (`bool`, defaults to `True`): + Whether to use multi-layer text encoder features by selecting specific layers. + selected_layer_index (`tuple[int, ...]`, defaults to `(5, 11, 17, 23)`): + The indices of the text encoder layers to select when `multi_layer_encoder_feature` is True. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["LensTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["LensTransformerBlock"] + _cp_plan = { + "transformer_blocks.0": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "pos_embed": { + 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 128, + out_channels: int | None = 32, + num_layers: int = 48, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + inner_dim: int = 1536, + enc_hidden_dim: int = 2880, + axes_dims_rope: tuple[int, int, int] = (8, 28, 28), + gate_mlp: bool = True, + rms_norm: bool = True, + multi_layer_encoder_feature: bool = True, + selected_layer_index: tuple[int, ...] = (5, 11, 17, 23), + ) -> None: + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.multi_layer_encoder_feature = multi_layer_encoder_feature + self.selected_layer_index = list(selected_layer_index) + + self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + self.time_text_embed = LensTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + if self.multi_layer_encoder_feature: + self.txt_norm = nn.ModuleList( + [RMSNorm(enc_hidden_dim, eps=1e-5) for _ in self.selected_layer_index] + ) + self.txt_in = nn.Linear(enc_hidden_dim * len(self.selected_layer_index), self.inner_dim) + else: + self.txt_norm = RMSNorm(enc_hidden_dim, eps=1e-5) + self.txt_in = nn.Linear(enc_hidden_dim, self.inner_dim) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + LensTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + rms_norm=rms_norm, + gate_mlp=gate_mlp, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + @apply_lora_scale("attention_kwargs") + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + encoder_hidden_states_mask: torch.Tensor | None = None, + timestep: torch.LongTensor = None, + img_shapes: list[tuple[int, int, int]] | None = None, + attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + + ) -> torch.Tensor | Transformer2DModelOutput: + """ + The [`LensTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` or `list[torch.Tensor]`): + Conditional embeddings computed from the input conditions such as prompts. When + `multi_layer_encoder_feature` is True, a list of per-layer text tensors is expected. + encoder_hidden_states_mask (`torch.Tensor`, *optional*): + Boolean mask for the encoder hidden states, where `True` indicates valid tokens. + timestep (`torch.LongTensor`): + Used to indicate denoising step. + img_shapes (`list[tuple[int, int, int]]`, *optional*): + List of (frame, height, width) tuples for each image in the batch, used to compute + rotary positional embeddings. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + bsz, img_len, _ = hidden_states.shape + + if self.multi_layer_encoder_feature: + if not isinstance(encoder_hidden_states, (list, tuple)): + raise ValueError( + "multi_layer_encoder_feature=True expects a list of per-layer text tensors." + ) + if len(encoder_hidden_states) != len(self.selected_layer_index): + raise ValueError( + f"Expected {len(self.selected_layer_index)} text feature layers, " + f"got {len(encoder_hidden_states)}." + ) + text_seq_len = encoder_hidden_states[0].shape[1] + else: + if not isinstance(encoder_hidden_states, torch.Tensor): + raise ValueError( + "multi_layer_encoder_feature=False expects a single text feature tensor." + ) + text_seq_len = encoder_hidden_states.shape[1] + + attention_mask = None + if encoder_hidden_states_mask is not None: + attention_mask = self._build_joint_attention_mask(encoder_hidden_states_mask, img_len) + + hidden_states = self.img_in(hidden_states) + timestep = timestep.to(hidden_states.dtype) + + if self.multi_layer_encoder_feature: + normed = [ + self.txt_norm[i](encoder_hidden_states[i]) + for i in range(len(self.selected_layer_index)) + ] + encoder_hidden_states = torch.cat(normed, dim=-1) + else: + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + temb = self.time_text_embed(timestep, hidden_states) + + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_len=text_seq_len, device=hidden_states.device) + + block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} + if attention_mask is not None: + block_attention_kwargs["attention_mask"] = attention_mask + + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + None, + block_attention_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=block_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + @staticmethod + def _build_joint_attention_mask( + text_mask: torch.Tensor, img_len: int + ) -> torch.Tensor: + if text_mask.dtype != torch.bool: + text_mask = text_mask.bool() + bsz = text_mask.shape[0] + img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device) + joint = torch.cat([img_ones, text_mask], dim=1) + additive = torch.zeros_like(joint, dtype=torch.float32) + additive.masked_fill_(~joint, float("-inf")) + return additive[:, None, None, :] diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 720548e38fd4..45528e903915 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -324,6 +324,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] + _import_structure["lens"] = ["LensPipeline"] _import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] _import_structure["ltx"] = [ "LTXPipeline", @@ -783,6 +784,7 @@ ) from .latent_diffusion import LDMTextToImagePipeline from .latte import LattePipeline + from .lens import LensPipeline from .ledits_pp import ( LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput, diff --git a/src/diffusers/pipelines/lens/__init__.py b/src/diffusers/pipelines/lens/__init__.py new file mode 100644 index 000000000000..31ad5fe7b960 --- /dev/null +++ b/src/diffusers/pipelines/lens/__init__.py @@ -0,0 +1,65 @@ +# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["LensPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lens"] = ["LensGptOssEncoder", "LensPipeline"] + + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_lens import LensGptOssEncoder, LensPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/lens/pipeline_lens.py b/src/diffusers/pipelines/lens/pipeline_lens.py new file mode 100644 index 000000000000..4616e69c949f --- /dev/null +++ b/src/diffusers/pipelines/lens/pipeline_lens.py @@ -0,0 +1,680 @@ +# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Sequence, Union + +import numpy as np +import torch +from transformers import PreTrainedTokenizerBase + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLFlux2, LensTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LensPipelineOutput +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM + + + + +logger = logging.get_logger(__name__) + + +class LensGptOssEncoder(GptOssForCausalLM): + """`GptOssForCausalLM` subclass that exposes selected hidden states. + + This text encoder extracts hidden states from specific intermediate layers of the + GPT-OSS model, enabling multi-layer text features for the Lens transformer. It + early-exits after the last selected layer to avoid unnecessary computation. + + Call `set_selected_layers` before the first forward pass to configure which + layers to capture. + """ + + def set_selected_layers(self, layer_indices: Sequence[int]) -> None: + layers = [int(i) for i in layer_indices] + if not layers: + raise ValueError("layer_indices must be non-empty") + if len(set(layers)) != len(layers): + raise ValueError(f"layer_indices must be unique; got {layers}") + if min(layers) < 0 or max(layers) >= len(self.model.layers): + raise ValueError( + f"layer_indices out of range; got {layers}, " + f"model has {len(self.model.layers)} layers" + ) + self._lens_selected_layers = layers + self._lens_max_layer = max(layers) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> list[torch.Tensor]: + """The [`LensGptOssEncoder`] forward method. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Mask to avoid performing attention on padding token indices. + + Returns: + `list[torch.Tensor]`: Hidden states at the configured selected layers. + """ + if not hasattr(self, "_lens_selected_layers"): + raise RuntimeError("Call set_selected_layers(...) before forward().") + + target_device = self.model.embed_tokens.weight.device + if target_device.type != "meta": + if input_ids is not None and input_ids.device != target_device: + input_ids = input_ids.to(target_device) + if attention_mask is not None and attention_mask.device != target_device: + attention_mask = attention_mask.to(target_device) + + model = self.model + inputs_embeds = model.embed_tokens(input_ids) + position_ids = torch.arange( + inputs_embeds.shape[1], device=inputs_embeds.device + ).unsqueeze(0).expand_as(input_ids) + + mask_kwargs = { + "config": model.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": None, + "position_ids": position_ids, + } + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + hidden_states = inputs_embeds + position_embeddings = model.rotary_emb(hidden_states, position_ids) + + captured: list[torch.Tensor | None] = [None] * len(self._lens_selected_layers) + index_lookup = {idx: pos for pos, idx in enumerate(self._lens_selected_layers)} + + for i, decoder_layer in enumerate(model.layers): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[model.config.layer_types[i]], + position_embeddings=position_embeddings, + position_ids=position_ids, + past_key_values=None, + use_cache=False, + ) + if i in index_lookup: + captured[index_lookup[i]] = hidden_states + if i == self._lens_max_layer: + break + + for pos, layer_idx in enumerate(self._lens_selected_layers): + if captured[pos] is None: + raise RuntimeError( + f"Failed to capture hidden state for layer {layer_idx}" + ) + return captured + + +import transformers as _transformers + +if not hasattr(_transformers, "LensGptOssEncoder"): + _transformers.LensGptOssEncoder = LensGptOssEncoder + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LensPipeline + + >>> pipe = LensPipeline.from_pretrained("microsoft/Lens", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt, num_inference_steps=50).images[0] + >>> image.save("lens.png") + ``` +""" + + +RESOLUTION_BUCKETS = { + 1024: { + "1:2": (1472, 736), + "9:16": (1376, 768), + "2:3": (1248, 832), + "3:4": (1152, 864), + "1:1": (1024, 1024), + "4:3": (864, 1152), + "3:2": (832, 1248), + "16:9": (768, 1376), + "2:1": (736, 1472), + }, + 1440: { + "1:2": (2080, 1040), + "9:16": (1936, 1088), + "2:3": (1760, 1168), + "3:4": (1616, 1216), + "1:1": (1440, 1440), + "4:3": (1216, 1616), + "3:2": (1168, 1760), + "16:9": (1088, 1936), + "2:1": (1040, 2080), + }, +} + + +def resolve_resolution(base_resolution: int, aspect_ratio: str): + if base_resolution not in RESOLUTION_BUCKETS: + raise ValueError( + f"Unsupported base_resolution={base_resolution}. " + f"Supported: {tuple(RESOLUTION_BUCKETS.keys())}" + ) + table = RESOLUTION_BUCKETS[base_resolution] + if aspect_ratio not in table: + raise ValueError( + f"Unsupported aspect_ratio={aspect_ratio!r}. " + f"Supported: {tuple(RESOLUTION_BUCKETS[1024].keys())}" + ) + return table[aspect_ratio] + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + if image_seq_len > 4300: + return float(a2 * image_seq_len + b2) + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + return float(a * num_steps + b) + + +CHAT_SYSTEM = ( + "Describe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background." +) +CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description." +DEFAULT_TXT_OFFSET = 97 + + +class LensPipeline(DiffusionPipeline): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + @property + def guidance_scale(self) -> float: + return self._guidance_scale + + @property + def num_timesteps(self) -> int: + return self._num_timesteps + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder, + tokenizer: PreTrainedTokenizerBase, + transformer: LensTransformer2DModel, + ): + super().__init__() + self.register_modules( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + if self.tokenizer is not None and self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + if self.tokenizer is not None: + self.tokenizer.padding_side = "right" + self.vae_scale_factor = 16 + self.latent_channels = self.transformer.config.in_channels if self.transformer is not None else None + self.txt_offset = DEFAULT_TXT_OFFSET + + if self.text_encoder is not None and hasattr(self.text_encoder, "set_selected_layers"): + if self.transformer is not None: + selected_layers = list(self.transformer.config.selected_layer_index) + else: + selected_layers = [0] + self.text_encoder.set_selected_layers(selected_layers) + self.default_sample_size = 1024 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _build_chat_inputs( + self, prompts: list[str], max_sequence_length: int, device: torch.device + ): + rendered = [] + for prompt in prompts: + conversation = [ + {"role": "system", "content": CHAT_SYSTEM, "thinking": None}, + {"role": "user", "content": prompt, "thinking": None}, + {"role": "assistant", "thinking": CHAT_ASSISTANT_THINKING, "content": ""}, + ] + text = self.tokenizer.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=False + ) + text = text.split("<|return|>")[0] + rendered.append(text) + + encoded = self.tokenizer( + rendered, + padding=True, + truncation=True, + max_length=max_sequence_length, + return_tensors="pt", + add_special_tokens=True, + ) + return encoded["input_ids"].to(device), encoded["attention_mask"].to(device) + + def _get_text_embeddings( + self, prompts: list[str], max_sequence_length: int, device: torch.device + ): + input_ids, attn_mask = self._build_chat_inputs(prompts, max_sequence_length, device) + layer_outputs = self.text_encoder(input_ids, attention_mask=attn_mask) + + offset = self.txt_offset + if input_ids.shape[1] > offset: + features = [feat[:, offset:, :].contiguous() for feat in layer_outputs] + mask = attn_mask[:, offset:].bool() + else: + zero_shape = (input_ids.shape[0], 0, layer_outputs[0].shape[-1]) + features = [layer_outputs[0].new_zeros(zero_shape) for _ in layer_outputs] + mask = torch.zeros((input_ids.shape[0], 0), dtype=torch.bool, device=device) + return features, mask + + def encode_prompt( + self, + prompt: Union[str, list[str]], + negative_prompt: Union[str, list[str]] = "", + num_images_per_prompt: int = 1, + prompt_embeds: Optional[list[torch.Tensor]] = None, + prompt_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[list[torch.Tensor]] = None, + negative_prompt_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompts = None + if prompt_embeds is None: + prompts = [prompt] if isinstance(prompt, str) else list(prompt) + n = int(num_images_per_prompt) + + negatives = None + if negative_prompt_embeds is None: + prompt_batch_size = len(prompts) if prompts is not None else prompt_embeds[0].shape[0] + if isinstance(negative_prompt, str): + negatives = [negative_prompt] * prompt_batch_size + else: + negatives = list(negative_prompt) + if len(negatives) == 1: + negatives = negatives * prompt_batch_size + if len(negatives) != prompt_batch_size: + raise ValueError( + "negative_prompt must be a string or a list of the same length as prompt" + ) + + if prompt_embeds is None: + prompt_embeds, prompt_mask = self._get_text_embeddings(prompts, max_sequence_length, device) + prompt_embeds, prompt_mask = self._repeat_for_n(prompt_embeds, prompt_mask, n) + elif prompt_mask is None: + raise ValueError("`prompt_mask` must be provided when passing `prompt_embeds`.") + + if negative_prompt_embeds is None: + if all(isinstance(neg, str) and not neg.strip() for neg in negatives): + negative_prompt_embeds = [feat.new_zeros(feat.shape) for feat in prompt_embeds] + negative_prompt_mask = torch.zeros_like(prompt_mask, dtype=torch.bool) + else: + negative_prompt_embeds, negative_prompt_mask = self._get_text_embeddings( + negatives, max_sequence_length, device + ) + negative_prompt_embeds, negative_prompt_mask = self._repeat_for_n( + negative_prompt_embeds, negative_prompt_mask, n + ) + elif negative_prompt_mask is None: + raise ValueError( + "`negative_prompt_mask` must be provided when passing `negative_prompt_embeds`." + ) + return prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask + + @staticmethod + def _repeat_for_n(features: list[torch.Tensor], mask: torch.Tensor, n: int): + if n == 1: + return features, mask + features = [f.repeat_interleave(n, dim=0) for f in features] + mask = mask.repeat_interleave(n, dim=0) + return features, mask + + @staticmethod + def _align_text_features( + pos_features: list[torch.Tensor], + pos_mask: torch.Tensor, + neg_features: list[torch.Tensor], + neg_mask: torch.Tensor, + ): + seq_pos = pos_features[0].shape[1] + seq_neg = neg_features[0].shape[1] + target = max(seq_pos, seq_neg) + + def pad(features, cur): + if cur == target: + return features + pad_len = target - cur + return [ + torch.cat([feat, feat.new_zeros((feat.shape[0], pad_len, feat.shape[-1]))], dim=1) + for feat in features + ] + + def pad_mask(mask, cur): + if cur == target: + return mask + return torch.cat( + [mask, torch.zeros((mask.shape[0], target - cur), dtype=torch.bool, device=mask.device)], dim=1 + ) + + pos_features = pad(pos_features, seq_pos) + neg_features = pad(neg_features, seq_neg) + pos_mask = pad_mask(pos_mask.bool(), seq_pos) + neg_mask = pad_mask(neg_mask.bool(), seq_neg) + return pos_features, pos_mask, neg_features, neg_mask + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + latent_h = height // self.vae_scale_factor + latent_w = width // self.vae_scale_factor + shape = (batch_size, latent_h * latent_w, num_channels_latents) + if latents is not None: + return latents.to(device=device, dtype=dtype) + return randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + ): + if height is None or width is None: + raise ValueError("height and width must be provided (or use base_resolution + aspect_ratio).") + if height % self.vae_scale_factor or width % self.vae_scale_factor: + raise ValueError( + f"height and width must be divisible by {self.vae_scale_factor}; got ({height}, {width})." + ) + if prompt is None and prompt_embeds is None: + raise ValueError("Either `prompt` or `prompt_embeds` must be provided.") + if callback_on_step_end_tensor_inputs is not None: + for k in callback_on_step_end_tensor_inputs: + if k not in self._callback_tensor_inputs: + raise ValueError( + f"callback_on_step_end_tensor_inputs entry {k!r} is not in {self._callback_tensor_inputs}." + ) + + @staticmethod + def _patchify_latents(latents: torch.Tensor) -> torch.Tensor: + b, c, h, w = latents.shape + latents = latents.view(b, c, h // 2, 2, w // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + return latents.reshape(b, c * 4, h // 2, w // 2) + + @staticmethod + def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: + b, c, h, w = latents.shape + latents = latents.reshape(b, c // 4, 2, 2, h, w) + latents = latents.permute(0, 1, 4, 2, 5, 3) + return latents.reshape(b, c // 4, h * 2, w * 2) + + def _decode(self, latents: torch.Tensor, latent_h: int, latent_w: int): + b = latents.shape[0] + p1, p2 = 2, 2 + h, w = latent_h, latent_w + c = latents.shape[-1] // (p1 * p2) + latents = latents.reshape(b, h, w, c, p1, p2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + latents = latents.reshape(b, c, h * p1, w * p2) + latents = latents.to(self.vae.dtype) + + bn = self.vae.bn + mean = bn.running_mean.view(1, -1, 1, 1) + var = bn.running_var.view(1, -1, 1, 1) + std = torch.sqrt(var + self.vae.config.batch_norm_eps) + shift = (-mean).to(device=latents.device, dtype=latents.dtype) + scale = (1.0 / std).to(device=latents.device, dtype=latents.dtype) + x = self._patchify_latents(latents) + x = x / scale - shift + x = self._unpatchify_latents(x) + return self.vae.decode(x).sample + + @staticmethod + def _to_pil(image: torch.Tensor): + from PIL import Image + + image = image.clamp(-1.0, 1.0) + image = (image + 1.0) * (255.0 / 2.0) + image = image.permute(0, 2, 3, 1).to(device="cpu", dtype=torch.uint8).numpy() + return [Image.fromarray(im) for im in image] + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, list[str]] = None, + negative_prompt: Union[str, list[str]] = "", + height: Optional[int] = None, + width: Optional[int] = None, + base_resolution: Optional[int] = None, + aspect_ratio: Optional[str] = None, + num_inference_steps: int = 50, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[list[torch.Tensor]] = None, + prompt_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[list[torch.Tensor]] = None, + negative_prompt_mask: Optional[torch.Tensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable] = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*, defaults to `""`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.default_sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.default_sample_size): + The width in pixels of the generated image. + base_resolution (`int`, *optional*): + The base resolution for the resolution bucket. When provided together with `aspect_ratio`, overrides + `height` and `width`. Supported values: 1024, 1440. + aspect_ratio (`str`, *optional*): + The aspect ratio for the resolution bucket. When provided together with `base_resolution`, overrides + `height` and `width`. Supported values: "1:2", "9:16", "2:3", "3:4", "1:1", "4:3", "3:2", "16:9", + "2:1". + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a + latents tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If + not provided, text embeddings will be generated from `prompt` input argument. Must be a list of + tensors, one per selected text encoder layer. + prompt_mask (`torch.Tensor`, *optional*): + Boolean mask for the prompt embeddings. Must be provided when passing `prompt_embeds`. + negative_prompt_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated negative text embeddings. Must be a list of tensors, one per selected text encoder + layer. + negative_prompt_mask (`torch.Tensor`, *optional*): + Boolean mask for the negative prompt embeddings. Must be provided when passing + `negative_prompt_embeds`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.ndarray`) or `"latent"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.lens.LensPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising step during inference. + callback_on_step_end_tensor_inputs (`list`, *optional*, defaults to `["latents"]`): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the text encoder. + + Examples: + + Returns: + [`~pipelines.lens.LensPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.lens.LensPipelineOutput`] is returned, otherwise a `tuple` + is returned where the first element is a list of the generated images. + """ + if base_resolution is not None and aspect_ratio is not None: + height, width = resolve_resolution(base_resolution, aspect_ratio) + elif height is None or width is None: + height = width = self.default_sample_size + + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + device = self._execution_device + dtype = self.transformer.dtype + self._guidance_scale = guidance_scale + + prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + prompt_mask=prompt_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_mask=negative_prompt_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + + prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self._align_text_features( + prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask + ) + + encoder_features = [ + torch.cat([pf, nf], dim=0).to(dtype=dtype) + for pf, nf in zip(prompt_embeds, negative_prompt_embeds) + ] + encoder_mask = torch.cat([prompt_mask, negative_prompt_mask], dim=0) + + batch_size = prompt_embeds[0].shape[0] + latent_h = height // self.vae_scale_factor + latent_w = width // self.vae_scale_factor + seq_len = latent_h * latent_w + latents = self.prepare_latents( + batch_size, self.latent_channels, height, width, + dtype=dtype, device=device, generator=generator, latents=latents, + ) + + mu = compute_empirical_mu(seq_len, num_inference_steps) + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) + self.scheduler.set_timesteps(sigmas=sigmas, device=device, mu=mu) + self._num_timesteps = len(self.scheduler.timesteps) + + img_shapes = [(1, latent_h, latent_w)] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(self.scheduler.timesteps): + timestep = t.expand(batch_size * 2).to(latents.dtype) + hidden_states = latents.repeat(2, 1, 1) + + noise = self.transformer( + hidden_states=hidden_states, + encoder_hidden_states=encoder_features, + encoder_hidden_states_mask=encoder_mask, + timestep=timestep / 1000, + img_shapes=img_shapes, + return_dict=False, + )[0] + + cond, uncond = noise.chunk(2) + comb = uncond + self._guidance_scale * (cond - uncond) + cond_norm = torch.norm(cond, dim=-1, keepdim=True) + comb_norm = torch.norm(comb, dim=-1, keepdim=True) + scale = torch.where( + comb_norm > 0, + cond_norm / comb_norm.clamp_min(1e-12), + torch.ones_like(comb_norm), + ) + noise_pred = comb * scale + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + cb_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} + cb_out = callback_on_step_end(self, i, t, cb_kwargs) + latents = cb_out.pop("latents", latents) + + progress_bar.update() + + if output_type == "latent": + images = latents + else: + decoded = self._decode(latents, latent_h, latent_w) + if output_type == "pil": + images = self._to_pil(decoded) + elif output_type == "np": + decoded = decoded.clamp(-1.0, 1.0) + decoded = (decoded + 1.0) * 0.5 + images = decoded.permute(0, 2, 3, 1).to("cpu", torch.float32).numpy() + elif output_type == "pt": + images = decoded + else: + raise ValueError(f"output_type must be one of 'pil', 'np', 'pt', 'latent'; got {output_type!r}.") + + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + return LensPipelineOutput(images=images) diff --git a/src/diffusers/pipelines/lens/pipeline_output.py b/src/diffusers/pipelines/lens/pipeline_output.py new file mode 100644 index 000000000000..1543ca0ec557 --- /dev/null +++ b/src/diffusers/pipelines/lens/pipeline_output.py @@ -0,0 +1,35 @@ +# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class LensPipelineOutput(BaseOutput): + """ + Output class for Lens image generation pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): + List of denoised PIL images of length `batch_size`, numpy array, or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/tests/models/transformers/test_models_transformer_lens.py b/tests/models/transformers/test_models_transformer_lens.py new file mode 100644 index 000000000000..0a91709e819e --- /dev/null +++ b/tests/models/transformers/test_models_transformer_lens.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers import LensTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + ContextParallelTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class LensTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return LensTransformer2DModel + + @property + def pretrained_model_name_or_path(self): + return "" + + @property + def pretrained_model_kwargs(self): + return {"subfolder": "transformer"} + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def model_split_percents(self) -> list: + return [0.1, 0.1, 0.1] + + def get_init_dict(self) -> dict: + return { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 1, + "attention_head_dim": 20, + "num_attention_heads": 1, + "inner_dim": 20, + "enc_hidden_dim": 32, + "axes_dims_rope": [4, 8, 8], + "gate_mlp": True, + "rms_norm": True, + "multi_layer_encoder_feature": True, + "selected_layer_index": (0,), + } + + def get_dummy_inputs(self, batch_size: int = 1) -> dict: + height = width = 4 + num_latent_channels = 16 + text_seq_len = 8 + enc_hidden_dim = 32 + + return { + "hidden_states": randn_tensor( + (batch_size, height * width, num_latent_channels), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ), + "encoder_hidden_states": [ + randn_tensor( + (batch_size, text_seq_len, enc_hidden_dim), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ) + ], + "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype).expand(batch_size), + "img_shapes": [(1, height, width)], + } + + @property + def input_shape(self) -> tuple: + return (16, 16) + + @property + def output_shape(self) -> tuple: + return (16, 16) + + +class TestLensTransformerModel(LensTransformerTesterConfig, ModelTesterMixin): + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + pytest.skip("Tolerance requirements too high for meaningful test") + + def test_model_parallelism(self, tmp_path): + pytest.skip("Tiny Lens test config does not split meaningfully across multiple GPUs") + + +class TestLensTransformerMemory(LensTransformerTesterConfig, MemoryTesterMixin): + def test_layerwise_casting_memory(self): + pytest.skip("Tiny Lens test config does not give a stable layerwise-casting memory ordering signal") + + def test_cpu_offload(self, tmp_path): + pytest.skip("Tiny Lens test config does not split meaningfully for CPU offload") + + def test_disk_offload_without_safetensors(self, tmp_path): + pytest.skip("Tiny Lens test config does not split meaningfully for disk offload") + + def test_disk_offload_with_safetensors(self, tmp_path): + pytest.skip("Tiny Lens test config does not split meaningfully for disk offload") + + +class TestLensTransformerTorchCompile(LensTransformerTesterConfig, TorchCompileTesterMixin): + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict: + num_latent_channels = 16 + text_seq_len = 8 + enc_hidden_dim = 32 + batch_size = 1 + + return { + "hidden_states": randn_tensor( + (batch_size, height * width, num_latent_channels), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ), + "encoder_hidden_states": [ + randn_tensor( + (batch_size, text_seq_len, enc_hidden_dim), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ) + ], + "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype).expand(batch_size), + "img_shapes": [(1, height, width)], + } + + +@pytest.mark.skip(reason="Tiny Lens test config is not suitable for context-parallel coverage") +class TestLensTransformerContextParallel(LensTransformerTesterConfig, ContextParallelTesterMixin): + pass + + +class TestLensTransformerTraining(LensTransformerTesterConfig, TrainingTesterMixin): + pass + + +class TestLensTransformerAttention(LensTransformerTesterConfig, AttentionTesterMixin): + def test_fuse_unfuse_qkv_projections(self): + pytest.skip("LensJointAttention does not use the generic to_q/to_k/to_v fusion path") + + def test_attention_processor_count_mismatch_raises_error(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.to(torch_device) + + if not hasattr(model, "set_attn_processor"): + pytest.skip("Model does not support setting attention processors.") + + current_processors = model.attn_processors + if len(current_processors) <= 1: + pytest.skip("Lens test config exposes only one attention processor.") + + return super().test_attention_processor_count_mismatch_raises_error() diff --git a/tests/pipelines/lens/__init__.py b/tests/pipelines/lens/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/lens/test_pipeline_lens.py b/tests/pipelines/lens/test_pipeline_lens.py new file mode 100644 index 000000000000..ca15de857ab5 --- /dev/null +++ b/tests/pipelines/lens/test_pipeline_lens.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2025 Microsoft and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from tokenizers import Tokenizer +from tokenizers.decoders import WordPiece as WordPieceDecoder +from tokenizers.models import WordPiece +from tokenizers.pre_tokenizers import Whitespace +from transformers import GptOssConfig, PreTrainedTokenizerFast + +from diffusers import ( + AutoencoderKLFlux2, + FlowMatchEulerDiscreteScheduler, + LensPipeline, + LensTransformer2DModel, +) +from diffusers.pipelines.lens.pipeline_lens import LensGptOssEncoder +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +def _build_dummy_gptoss_tokenizer(): + vocab = { + "[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3, "<|return|>": 4, + "hello": 5, "world": 6, "a": 7, "cat": 8, "the": 9, "is": 10, + "on": 11, "mat": 12, "painting": 13, "of": 14, "squirrel": 15, + "eating": 16, "burger": 17, "Describe": 18, "image": 19, + } + tok = Tokenizer(WordPiece(vocab, unk_token="[UNK]")) + tok.pre_tokenizer = Whitespace() + tok.decoder = WordPieceDecoder() + fast = PreTrainedTokenizerFast(tokenizer_object=tok) + fast.pad_token = "[PAD]" + fast.eos_token = "[SEP]" + fast.chat_template = ( + "{% for message in messages %}" + "{{ message.content }}" + "{% if not loop.last %} {% endif %}" + "{% endfor %}" + ) + return fast + + +class LensPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LensPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + required_optional_params = PipelineTesterMixin.required_optional_params + test_layerwise_casting = True + test_group_offloading = True + supports_dduf = False + + def get_dummy_components(self, num_layers: int = 1): + torch.manual_seed(0) + transformer = LensTransformer2DModel( + patch_size=2, + in_channels=64, + out_channels=16, + num_layers=num_layers, + attention_head_dim=20, + num_attention_heads=1, + inner_dim=20, + enc_hidden_dim=32, + axes_dims_rope=(4, 8, 8), + gate_mlp=True, + rms_norm=True, + multi_layer_encoder_feature=True, + selected_layer_index=(0,), + ) + torch.manual_seed(0) + vae = AutoencoderKLFlux2( + sample_size=32, + in_channels=3, + out_channels=3, + latent_channels=16, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(32,), + layers_per_block=1, + norm_num_groups=32, + use_quant_conv=False, + use_post_quant_conv=False, + ) + scheduler = FlowMatchEulerDiscreteScheduler() + + config = GptOssConfig( + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + intermediate_size=64, + head_dim=16, + layer_types=["full_attention", "full_attention"], + vocab_size=1000, + ) + text_encoder = LensGptOssEncoder(config) + tokenizer = _build_dummy_gptoss_tokenizer() + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder.eval(), + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 16, + "width": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 2, 2)) + expected_image = torch.randn(3, 2, 2) + max_diff = np.abs(generated_image - expected_image).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + for tensor_name, tensor_value in callback_kwargs.items(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + for tensor_name, tensor_value in callback_kwargs.items(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + self.assertIsNotNone(output) + + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + self.assertIsNotNone(output)