diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 4c1c220d58e8..fa652c071ae2 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -353,6 +353,8 @@
title: JoyImageEditTransformer3DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
+ - local: api/models/lens_transformer2d
+ title: LensTransformer2DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
- local: api/models/ltx2_video_transformer3d
@@ -553,6 +555,8 @@
title: Kandinsky 5.0 Image
- local: api/pipelines/kolors
title: Kolors
+ - local: api/pipelines/lens
+ title: Lens
- local: api/pipelines/latent_consistency_models
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion
diff --git a/docs/source/en/api/models/lens_transformer2d.md b/docs/source/en/api/models/lens_transformer2d.md
new file mode 100644
index 000000000000..d19a3ba3befa
--- /dev/null
+++ b/docs/source/en/api/models/lens_transformer2d.md
@@ -0,0 +1,23 @@
+
+
+# LensTransformer2DModel
+
+A Transformer model for image-like data from [Lens](https://huggingface.co/microsoft/Lens).
+
+## LensTransformer2DModel
+
+[[autodoc]] LensTransformer2DModel
+
+## LensTransformer2DModelOutput
+
+[[autodoc]] models.transformers.transformer_lens.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/lens.md b/docs/source/en/api/pipelines/lens.md
new file mode 100644
index 000000000000..ac402be829c7
--- /dev/null
+++ b/docs/source/en/api/pipelines/lens.md
@@ -0,0 +1,52 @@
+
+
+# Lens
+
+
+
+
+Lens is a 3.8B-parameter foundational text-to-image model designed for efficient training and fast high-resolution generation. It combines dense-caption pre-training, mixed-resolution learning, GPT-OSS multi-layer text features, and the FLUX.2 semantic VAE to reach competitive quality with substantially less training compute than larger T2I models. For more details, please refer to the [model card](https://huggingface.co/microsoft/Lens).
+
+The abstract from the paper is:
+
+*Lens is a 3.8B-parameter foundational text-to-image model designed for efficient training and fast high-resolution generation. It combines dense-caption pre-training, mixed-resolution learning, GPT-OSS multi-layer text features, and the FLUX.2 semantic VAE to reach competitive quality with substantially less training compute than larger T2I models.*
+
+## Usage Example
+
+```python
+import torch
+from diffusers import LensPipeline
+
+pipe = LensPipeline.from_pretrained("microsoft/Lens", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+image = pipe(
+ prompt="A cat holding a sign that says hello world",
+ height=1440,
+ width=1440,
+ num_inference_steps=20,
+ guidance_scale=5.0,
+).images[0]
+image.save("lens.png")
+```
+
+## LensPipeline
+
+[[autodoc]] LensPipeline
+
+- all
+- __call__
+
+## LensPipelineOutput
+
+[[autodoc]] pipelines.lens.pipeline_output.LensPipelineOutput
diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md
index 5e89f26fce54..0ba4a2854165 100644
--- a/docs/source/en/api/pipelines/overview.md
+++ b/docs/source/en/api/pipelines/overview.md
@@ -50,6 +50,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |
| [Kandinsky 3](kandinsky3) | text2image, image2image |
| [Kolors](kolors) | text2image |
+| [Lens](lens) | text2image |
| [Latent Consistency Models](latent_consistency_models) | text2image |
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [Latte](latte) | text2image |
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index d75e2d9a5010..4bc7f0ff4eb1 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -260,6 +260,7 @@
"JoyImageEditTransformer3DModel",
"Kandinsky3UNet",
"Kandinsky5Transformer3DModel",
+ "LensTransformer2DModel",
"LatteTransformer3DModel",
"LongCatAudioDiTTransformer",
"LongCatAudioDiTVae",
@@ -621,6 +622,7 @@
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
"LattePipeline",
+ "LensPipeline",
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
@@ -1096,6 +1098,7 @@
JoyImageEditTransformer3DModel,
Kandinsky3UNet,
Kandinsky5Transformer3DModel,
+ LensTransformer2DModel,
LatteTransformer3DModel,
LongCatAudioDiTTransformer,
LongCatAudioDiTVae,
@@ -1432,6 +1435,7 @@
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
LattePipeline,
+ LensPipeline,
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index e661dd1002a2..78c0361bf401 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -119,6 +119,7 @@
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
+ _import_structure["transformers.transformer_lens"] = ["LensTransformer2DModel"]
_import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -248,6 +249,7 @@
HunyuanVideoTransformer3DModel,
JoyImageEditTransformer3DModel,
Kandinsky5Transformer3DModel,
+ LensTransformer2DModel,
LatteTransformer3DModel,
LongCatAudioDiTTransformer,
LongCatImageTransformer2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 98558c22d7df..228eb8b9d51d 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -41,6 +41,7 @@
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
from .transformer_joyimage import JoyImageEditTransformer3DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
+ from .transformer_lens import LensTransformer2DModel
from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer
from .transformer_longcat_image import LongCatImageTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_lens.py b/src/diffusers/models/transformers/transformer_lens.py
new file mode 100644
index 000000000000..3b7eee5996eb
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_lens.py
@@ -0,0 +1,622 @@
+# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import apply_lora_scale, logging
+from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import TimestepEmbedding, Timesteps
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, RMSNorm
+
+
+logger = logging.get_logger(__name__)
+
+
+def apply_rotary_emb_lens(
+ x: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> torch.Tensor:
+ x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(1)
+ x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
+ return x_out.type_as(x)
+
+
+class GateMLP(nn.Module):
+ def __init__(self, dim: int, hidden_dim: int) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class LensTimestepProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim: int) -> None:
+ super().__init__()
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
+ proj = self.time_proj(timestep)
+ return self.timestep_embedder(proj.to(dtype=hidden_states.dtype))
+
+
+class LensEmbedRope(nn.Module):
+ def __init__(self, theta: int, axes_dim: list[int], scale_rope: bool = False) -> None:
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+ self.scale_rope = scale_rope
+ pos_index = torch.arange(4096)
+ neg_index = torch.arange(4096).flip(0) * -1 - 1
+ self.register_buffer(
+ "pos_freqs",
+ torch.cat([self._rope_params(pos_index, d, theta) for d in axes_dim], dim=1),
+ persistent=False,
+ )
+ self.register_buffer(
+ "neg_freqs",
+ torch.cat([self._rope_params(neg_index, d, theta) for d in axes_dim], dim=1),
+ persistent=False,
+ )
+
+ @staticmethod
+ def _rope_params(index: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor:
+ assert dim % 2 == 0
+ freqs = torch.outer(
+ index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float().div(dim))
+ )
+ return torch.polar(torch.ones_like(freqs), freqs)
+
+ @lru_cache_unless_export(maxsize=None)
+ def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+ return self.pos_freqs.to(device), self.neg_freqs.to(device)
+
+ def forward(
+ self,
+ video_fhw: list[tuple[int, int, int]] | tuple[int, int, int],
+ txt_seq_len: int | torch.Tensor,
+ device: torch.device = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if isinstance(video_fhw, list) and len(video_fhw) > 1:
+ first_fhw = video_fhw[0]
+ if not all(fhw == first_fhw for fhw in video_fhw):
+ logger.warning(
+ "Batch inference with variable-sized images is not currently supported in LensEmbedRope. "
+ "All images in the batch should have the same dimensions (frame, height, width). "
+ f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} "
+ "for RoPE computation, which may lead to incorrect results for other images in the batch."
+ )
+
+ if isinstance(video_fhw, list):
+ video_fhw = video_fhw[0]
+ if not isinstance(video_fhw, list):
+ video_fhw = [video_fhw]
+
+ vid_freqs = []
+ max_vid_index = 0
+ for idx, fhw in enumerate(video_fhw):
+ frame, height, width = fhw
+ video_freq = self._compute_video_freqs(frame, height, width, idx, device)
+ vid_freqs.append(video_freq)
+
+ if self.scale_rope:
+ max_vid_index = max(height // 2, width // 2, max_vid_index)
+ else:
+ max_vid_index = max(height, width, max_vid_index)
+
+ max_txt_seq_len_int = int(txt_seq_len)
+ pos_freqs_device, _ = self._get_device_freqs(device)
+ txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
+ vid_freqs = torch.cat(vid_freqs, dim=0)
+
+ return vid_freqs, txt_freqs
+
+ @lru_cache_unless_export(maxsize=128)
+ def _compute_video_freqs(
+ self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
+ ) -> torch.Tensor:
+ seq_lens = frame * height * width
+ pos_freqs, neg_freqs = (
+ self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
+ )
+
+ freqs_pos = pos_freqs.split([d // 2 for d in self.axes_dim], dim=1)
+ freqs_neg = neg_freqs.split([d // 2 for d in self.axes_dim], dim=1)
+
+ freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+ if self.scale_rope:
+ freqs_height = torch.cat(
+ [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
+ )
+ freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+ freqs_width = torch.cat(
+ [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0
+ )
+ freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+ else:
+ freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+ freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+ freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+ return freqs.clone().contiguous()
+
+
+class LensDoubleStreamAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: "LensJointAttention",
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+ attention_mask: torch.FloatTensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ bsz, seq_img, _ = hidden_states.shape
+ seq_txt = encoder_hidden_states.shape[1]
+
+ img_qkv = attn.img_qkv(hidden_states).view(bsz, seq_img, 3, attn.heads, attn.dim_head)
+ txt_qkv = attn.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, attn.heads, attn.dim_head)
+ img_q, img_k, img_v = img_qkv.unbind(dim=2)
+ txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
+
+ if attn.norm_q is not None:
+ img_q = attn.norm_q(img_q)
+ if attn.norm_k is not None:
+ img_k = attn.norm_k(img_k)
+ if attn.norm_added_q is not None:
+ txt_q = attn.norm_added_q(txt_q)
+ if attn.norm_added_k is not None:
+ txt_k = attn.norm_added_k(txt_k)
+
+ if image_rotary_emb is not None:
+ img_freqs, txt_freqs = image_rotary_emb
+ if img_freqs.shape[0] < seq_img:
+ raise ValueError(
+ f"Image RoPE length {img_freqs.shape[0]} is shorter than image sequence length {seq_img}."
+ )
+ img_freqs = img_freqs[:seq_img]
+ img_q = apply_rotary_emb_lens(img_q, img_freqs)
+ img_k = apply_rotary_emb_lens(img_k, img_freqs)
+ if seq_txt > 0:
+ if txt_freqs.shape[0] < seq_txt:
+ raise ValueError(
+ f"Text RoPE length {txt_freqs.shape[0]} is shorter than text sequence length {seq_txt}."
+ )
+ txt_freqs = txt_freqs[:seq_txt]
+ txt_q = apply_rotary_emb_lens(txt_q, txt_freqs)
+ txt_k = apply_rotary_emb_lens(txt_k, txt_freqs)
+
+ joint_query = torch.cat([img_q, txt_q], dim=1)
+ joint_key = torch.cat([img_k, txt_k], dim=1)
+ joint_value = torch.cat([img_v, txt_v], dim=1)
+
+ joint_hidden_states = dispatch_attention_fn(
+ joint_query,
+ joint_key,
+ joint_value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ joint_hidden_states = joint_hidden_states.flatten(2, 3)
+ joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
+
+ img_attn_output = joint_hidden_states[:, :seq_img, :]
+ txt_attn_output = joint_hidden_states[:, seq_img:, :]
+
+ img_attn_output = attn.to_out[0](img_attn_output.contiguous())
+ if len(attn.to_out) > 1:
+ img_attn_output = attn.to_out[1](img_attn_output)
+
+ txt_attn_output = attn.to_add_out(txt_attn_output.contiguous())
+
+ return img_attn_output, txt_attn_output
+
+
+class LensJointAttention(nn.Module, AttentionModuleMixin):
+ _default_processor_cls = LensDoubleStreamAttnProcessor
+ _available_processors = [LensDoubleStreamAttnProcessor]
+
+ def __init__(
+ self,
+ query_dim: int,
+ added_kv_proj_dim: int,
+ dim_head: int = 64,
+ heads: int = 8,
+ out_dim: int | None = None,
+ eps: float = 1e-5,
+ ) -> None:
+ super().__init__()
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.heads = self.inner_dim // dim_head
+ self.dim_head = dim_head
+ self.out_dim = out_dim if out_dim is not None else query_dim
+
+ self.norm_q = RMSNorm(dim_head, eps=eps)
+ self.norm_k = RMSNorm(dim_head, eps=eps)
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+
+ self.img_qkv = nn.Linear(query_dim, 3 * self.inner_dim, bias=True)
+ self.txt_qkv = nn.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True)
+
+ self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, self.out_dim, bias=True), nn.Identity()])
+ self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=True)
+
+ self.set_processor(LensDoubleStreamAttnProcessor())
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+ attention_mask: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ return self.processor(
+ self,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ )
+
+
+@maybe_allow_in_graph
+class LensTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ eps: float = 1e-6,
+ rms_norm: bool = False,
+ gate_mlp: bool = False,
+ ) -> None:
+ super().__init__()
+ self.attn = LensJointAttention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ eps=eps,
+ )
+
+ norm_cls = (lambda d: RMSNorm(d, eps=eps)) if rms_norm else (
+ lambda d: nn.LayerNorm(d, elementwise_affine=False, eps=eps)
+ )
+ if gate_mlp:
+ mlp_cls = lambda: GateMLP(dim, int(dim / 3 * 8))
+ else:
+ mlp_cls = lambda: FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ self.img_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
+ self.img_norm1 = norm_cls(dim)
+ self.img_norm2 = norm_cls(dim)
+ self.img_mlp = mlp_cls()
+
+ self.txt_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
+ self.txt_norm1 = norm_cls(dim)
+ self.txt_norm2 = norm_cls(dim)
+ self.txt_mlp = mlp_cls()
+
+ @staticmethod
+ def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None = None,
+ joint_attention_kwargs: dict[str, Any] | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1)
+ txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1)
+
+ img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+ txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ block_attention_mask = joint_attention_kwargs.pop("attention_mask", attention_mask)
+ img_attn, txt_attn = self.attn(
+ hidden_states=img_modulated,
+ encoder_hidden_states=txt_modulated,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=block_attention_mask,
+ **joint_attention_kwargs,
+ )
+
+ hidden_states = hidden_states + img_gate1 * img_attn
+ encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn
+
+ img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
+ hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
+
+ txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
+ encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
+
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class LensTransformer2DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ """
+ The Transformer model introduced in Lens.
+
+ Reference: https://huggingface.co/microsoft/Lens
+
+ Args:
+ patch_size (`int`, defaults to `2`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `128`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `32`):
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
+ num_layers (`int`, defaults to `48`):
+ The number of layers of dual stream DiT blocks to use.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of dimensions to use for each attention head.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of attention heads to use.
+ inner_dim (`int`, defaults to `1536`):
+ The inner dimension of the transformer. If not specified, it defaults to
+ `num_attention_heads * attention_head_dim`.
+ enc_hidden_dim (`int`, defaults to `2880`):
+ The hidden dimension of the text encoder outputs.
+ axes_dims_rope (`tuple[int, int, int]`, defaults to `(8, 28, 28)`):
+ The dimensions to use for the rotary positional embeddings for frame, height, and width axes.
+ gate_mlp (`bool`, defaults to `True`):
+ Whether to use a gated MLP (SwiGLU) in the transformer blocks.
+ rms_norm (`bool`, defaults to `True`):
+ Whether to use RMS normalization instead of LayerNorm in the transformer blocks.
+ multi_layer_encoder_feature (`bool`, defaults to `True`):
+ Whether to use multi-layer text encoder features by selecting specific layers.
+ selected_layer_index (`tuple[int, ...]`, defaults to `(5, 11, 17, 23)`):
+ The indices of the text encoder layers to select when `multi_layer_encoder_feature` is True.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["LensTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+ _repeated_blocks = ["LensTransformerBlock"]
+ _cp_plan = {
+ "transformer_blocks.0": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "pos_embed": {
+ 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 128,
+ out_channels: int | None = 32,
+ num_layers: int = 48,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 24,
+ inner_dim: int = 1536,
+ enc_hidden_dim: int = 2880,
+ axes_dims_rope: tuple[int, int, int] = (8, 28, 28),
+ gate_mlp: bool = True,
+ rms_norm: bool = True,
+ multi_layer_encoder_feature: bool = True,
+ selected_layer_index: tuple[int, ...] = (5, 11, 17, 23),
+ ) -> None:
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+ self.multi_layer_encoder_feature = multi_layer_encoder_feature
+ self.selected_layer_index = list(selected_layer_index)
+
+ self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+ self.time_text_embed = LensTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+ if self.multi_layer_encoder_feature:
+ self.txt_norm = nn.ModuleList(
+ [RMSNorm(enc_hidden_dim, eps=1e-5) for _ in self.selected_layer_index]
+ )
+ self.txt_in = nn.Linear(enc_hidden_dim * len(self.selected_layer_index), self.inner_dim)
+ else:
+ self.txt_norm = RMSNorm(enc_hidden_dim, eps=1e-5)
+ self.txt_in = nn.Linear(enc_hidden_dim, self.inner_dim)
+
+ self.img_in = nn.Linear(in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ LensTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ rms_norm=rms_norm,
+ gate_mlp=gate_mlp,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ @apply_lora_scale("attention_kwargs")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor | list[torch.Tensor],
+ encoder_hidden_states_mask: torch.Tensor | None = None,
+ timestep: torch.LongTensor = None,
+ img_shapes: list[tuple[int, int, int]] | None = None,
+ attention_kwargs: dict[str, Any] | None = None,
+ return_dict: bool = True,
+
+ ) -> torch.Tensor | Transformer2DModelOutput:
+ """
+ The [`LensTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` or `list[torch.Tensor]`):
+ Conditional embeddings computed from the input conditions such as prompts. When
+ `multi_layer_encoder_feature` is True, a list of per-layer text tensors is expected.
+ encoder_hidden_states_mask (`torch.Tensor`, *optional*):
+ Boolean mask for the encoder hidden states, where `True` indicates valid tokens.
+ timestep (`torch.LongTensor`):
+ Used to indicate denoising step.
+ img_shapes (`list[tuple[int, int, int]]`, *optional*):
+ List of (frame, height, width) tuples for each image in the batch, used to compute
+ rotary positional embeddings.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+
+ bsz, img_len, _ = hidden_states.shape
+
+ if self.multi_layer_encoder_feature:
+ if not isinstance(encoder_hidden_states, (list, tuple)):
+ raise ValueError(
+ "multi_layer_encoder_feature=True expects a list of per-layer text tensors."
+ )
+ if len(encoder_hidden_states) != len(self.selected_layer_index):
+ raise ValueError(
+ f"Expected {len(self.selected_layer_index)} text feature layers, "
+ f"got {len(encoder_hidden_states)}."
+ )
+ text_seq_len = encoder_hidden_states[0].shape[1]
+ else:
+ if not isinstance(encoder_hidden_states, torch.Tensor):
+ raise ValueError(
+ "multi_layer_encoder_feature=False expects a single text feature tensor."
+ )
+ text_seq_len = encoder_hidden_states.shape[1]
+
+ attention_mask = None
+ if encoder_hidden_states_mask is not None:
+ attention_mask = self._build_joint_attention_mask(encoder_hidden_states_mask, img_len)
+
+ hidden_states = self.img_in(hidden_states)
+ timestep = timestep.to(hidden_states.dtype)
+
+ if self.multi_layer_encoder_feature:
+ normed = [
+ self.txt_norm[i](encoder_hidden_states[i])
+ for i in range(len(self.selected_layer_index))
+ ]
+ encoder_hidden_states = torch.cat(normed, dim=-1)
+ else:
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+ temb = self.time_text_embed(timestep, hidden_states)
+
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_len=text_seq_len, device=hidden_states.device)
+
+ block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
+ if attention_mask is not None:
+ block_attention_kwargs["attention_mask"] = attention_mask
+
+ for block in self.transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ None,
+ block_attention_kwargs,
+ )
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=block_attention_kwargs,
+ )
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+ @staticmethod
+ def _build_joint_attention_mask(
+ text_mask: torch.Tensor, img_len: int
+ ) -> torch.Tensor:
+ if text_mask.dtype != torch.bool:
+ text_mask = text_mask.bool()
+ bsz = text_mask.shape[0]
+ img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
+ joint = torch.cat([img_ones, text_mask], dim=1)
+ additive = torch.zeros_like(joint, dtype=torch.float32)
+ additive.masked_fill_(~joint, float("-inf"))
+ return additive[:, None, None, :]
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 720548e38fd4..45528e903915 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -324,6 +324,7 @@
]
)
_import_structure["latte"] = ["LattePipeline"]
+ _import_structure["lens"] = ["LensPipeline"]
_import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
_import_structure["ltx"] = [
"LTXPipeline",
@@ -783,6 +784,7 @@
)
from .latent_diffusion import LDMTextToImagePipeline
from .latte import LattePipeline
+ from .lens import LensPipeline
from .ledits_pp import (
LEditsPPDiffusionPipelineOutput,
LEditsPPInversionPipelineOutput,
diff --git a/src/diffusers/pipelines/lens/__init__.py b/src/diffusers/pipelines/lens/__init__.py
new file mode 100644
index 000000000000..31ad5fe7b960
--- /dev/null
+++ b/src/diffusers/pipelines/lens/__init__.py
@@ -0,0 +1,65 @@
+# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["LensPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_lens"] = ["LensGptOssEncoder", "LensPipeline"]
+
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_lens import LensGptOssEncoder, LensPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/lens/pipeline_lens.py b/src/diffusers/pipelines/lens/pipeline_lens.py
new file mode 100644
index 000000000000..4616e69c949f
--- /dev/null
+++ b/src/diffusers/pipelines/lens/pipeline_lens.py
@@ -0,0 +1,680 @@
+# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Sequence, Union
+
+import numpy as np
+import torch
+from transformers import PreTrainedTokenizerBase
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKLFlux2, LensTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import LensPipelineOutput
+from transformers.masking_utils import (
+ create_causal_mask,
+ create_sliding_window_causal_mask,
+)
+from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM
+
+
+
+
+logger = logging.get_logger(__name__)
+
+
+class LensGptOssEncoder(GptOssForCausalLM):
+ """`GptOssForCausalLM` subclass that exposes selected hidden states.
+
+ This text encoder extracts hidden states from specific intermediate layers of the
+ GPT-OSS model, enabling multi-layer text features for the Lens transformer. It
+ early-exits after the last selected layer to avoid unnecessary computation.
+
+ Call `set_selected_layers` before the first forward pass to configure which
+ layers to capture.
+ """
+
+ def set_selected_layers(self, layer_indices: Sequence[int]) -> None:
+ layers = [int(i) for i in layer_indices]
+ if not layers:
+ raise ValueError("layer_indices must be non-empty")
+ if len(set(layers)) != len(layers):
+ raise ValueError(f"layer_indices must be unique; got {layers}")
+ if min(layers) < 0 or max(layers) >= len(self.model.layers):
+ raise ValueError(
+ f"layer_indices out of range; got {layers}, "
+ f"model has {len(self.model.layers)} layers"
+ )
+ self._lens_selected_layers = layers
+ self._lens_max_layer = max(layers)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ ) -> list[torch.Tensor]:
+ """The [`LensGptOssEncoder`] forward method.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ Mask to avoid performing attention on padding token indices.
+
+ Returns:
+ `list[torch.Tensor]`: Hidden states at the configured selected layers.
+ """
+ if not hasattr(self, "_lens_selected_layers"):
+ raise RuntimeError("Call set_selected_layers(...) before forward().")
+
+ target_device = self.model.embed_tokens.weight.device
+ if target_device.type != "meta":
+ if input_ids is not None and input_ids.device != target_device:
+ input_ids = input_ids.to(target_device)
+ if attention_mask is not None and attention_mask.device != target_device:
+ attention_mask = attention_mask.to(target_device)
+
+ model = self.model
+ inputs_embeds = model.embed_tokens(input_ids)
+ position_ids = torch.arange(
+ inputs_embeds.shape[1], device=inputs_embeds.device
+ ).unsqueeze(0).expand_as(input_ids)
+
+ mask_kwargs = {
+ "config": model.config,
+ "inputs_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "past_key_values": None,
+ "position_ids": position_ids,
+ }
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ hidden_states = inputs_embeds
+ position_embeddings = model.rotary_emb(hidden_states, position_ids)
+
+ captured: list[torch.Tensor | None] = [None] * len(self._lens_selected_layers)
+ index_lookup = {idx: pos for pos, idx in enumerate(self._lens_selected_layers)}
+
+ for i, decoder_layer in enumerate(model.layers):
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[model.config.layer_types[i]],
+ position_embeddings=position_embeddings,
+ position_ids=position_ids,
+ past_key_values=None,
+ use_cache=False,
+ )
+ if i in index_lookup:
+ captured[index_lookup[i]] = hidden_states
+ if i == self._lens_max_layer:
+ break
+
+ for pos, layer_idx in enumerate(self._lens_selected_layers):
+ if captured[pos] is None:
+ raise RuntimeError(
+ f"Failed to capture hidden state for layer {layer_idx}"
+ )
+ return captured
+
+
+import transformers as _transformers
+
+if not hasattr(_transformers, "LensGptOssEncoder"):
+ _transformers.LensGptOssEncoder = LensGptOssEncoder
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import LensPipeline
+
+ >>> pipe = LensPipeline.from_pretrained("microsoft/Lens", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> image = pipe(prompt, num_inference_steps=50).images[0]
+ >>> image.save("lens.png")
+ ```
+"""
+
+
+RESOLUTION_BUCKETS = {
+ 1024: {
+ "1:2": (1472, 736),
+ "9:16": (1376, 768),
+ "2:3": (1248, 832),
+ "3:4": (1152, 864),
+ "1:1": (1024, 1024),
+ "4:3": (864, 1152),
+ "3:2": (832, 1248),
+ "16:9": (768, 1376),
+ "2:1": (736, 1472),
+ },
+ 1440: {
+ "1:2": (2080, 1040),
+ "9:16": (1936, 1088),
+ "2:3": (1760, 1168),
+ "3:4": (1616, 1216),
+ "1:1": (1440, 1440),
+ "4:3": (1216, 1616),
+ "3:2": (1168, 1760),
+ "16:9": (1088, 1936),
+ "2:1": (1040, 2080),
+ },
+}
+
+
+def resolve_resolution(base_resolution: int, aspect_ratio: str):
+ if base_resolution not in RESOLUTION_BUCKETS:
+ raise ValueError(
+ f"Unsupported base_resolution={base_resolution}. "
+ f"Supported: {tuple(RESOLUTION_BUCKETS.keys())}"
+ )
+ table = RESOLUTION_BUCKETS[base_resolution]
+ if aspect_ratio not in table:
+ raise ValueError(
+ f"Unsupported aspect_ratio={aspect_ratio!r}. "
+ f"Supported: {tuple(RESOLUTION_BUCKETS[1024].keys())}"
+ )
+ return table[aspect_ratio]
+
+
+def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
+ a1, b1 = 8.73809524e-05, 1.89833333
+ a2, b2 = 0.00016927, 0.45666666
+ if image_seq_len > 4300:
+ return float(a2 * image_seq_len + b2)
+ m_200 = a2 * image_seq_len + b2
+ m_10 = a1 * image_seq_len + b1
+ a = (m_200 - m_10) / 190.0
+ b = m_200 - 200.0 * a
+ return float(a * num_steps + b)
+
+
+CHAT_SYSTEM = (
+ "Describe the image by detailing the color, shape, size, texture, "
+ "quantity, text, spatial relationships of the objects and background."
+)
+CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
+DEFAULT_TXT_OFFSET = 97
+
+
+class LensPipeline(DiffusionPipeline):
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ @property
+ def guidance_scale(self) -> float:
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self) -> int:
+ return self._num_timesteps
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLFlux2,
+ text_encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ transformer: LensTransformer2DModel,
+ ):
+ super().__init__()
+ self.register_modules(
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ )
+ if self.tokenizer is not None and self.tokenizer.pad_token_id is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ if self.tokenizer is not None:
+ self.tokenizer.padding_side = "right"
+ self.vae_scale_factor = 16
+ self.latent_channels = self.transformer.config.in_channels if self.transformer is not None else None
+ self.txt_offset = DEFAULT_TXT_OFFSET
+
+ if self.text_encoder is not None and hasattr(self.text_encoder, "set_selected_layers"):
+ if self.transformer is not None:
+ selected_layers = list(self.transformer.config.selected_layer_index)
+ else:
+ selected_layers = [0]
+ self.text_encoder.set_selected_layers(selected_layers)
+ self.default_sample_size = 1024
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _build_chat_inputs(
+ self, prompts: list[str], max_sequence_length: int, device: torch.device
+ ):
+ rendered = []
+ for prompt in prompts:
+ conversation = [
+ {"role": "system", "content": CHAT_SYSTEM, "thinking": None},
+ {"role": "user", "content": prompt, "thinking": None},
+ {"role": "assistant", "thinking": CHAT_ASSISTANT_THINKING, "content": ""},
+ ]
+ text = self.tokenizer.apply_chat_template(
+ conversation, tokenize=False, add_generation_prompt=False
+ )
+ text = text.split("<|return|>")[0]
+ rendered.append(text)
+
+ encoded = self.tokenizer(
+ rendered,
+ padding=True,
+ truncation=True,
+ max_length=max_sequence_length,
+ return_tensors="pt",
+ add_special_tokens=True,
+ )
+ return encoded["input_ids"].to(device), encoded["attention_mask"].to(device)
+
+ def _get_text_embeddings(
+ self, prompts: list[str], max_sequence_length: int, device: torch.device
+ ):
+ input_ids, attn_mask = self._build_chat_inputs(prompts, max_sequence_length, device)
+ layer_outputs = self.text_encoder(input_ids, attention_mask=attn_mask)
+
+ offset = self.txt_offset
+ if input_ids.shape[1] > offset:
+ features = [feat[:, offset:, :].contiguous() for feat in layer_outputs]
+ mask = attn_mask[:, offset:].bool()
+ else:
+ zero_shape = (input_ids.shape[0], 0, layer_outputs[0].shape[-1])
+ features = [layer_outputs[0].new_zeros(zero_shape) for _ in layer_outputs]
+ mask = torch.zeros((input_ids.shape[0], 0), dtype=torch.bool, device=device)
+ return features, mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, list[str]],
+ negative_prompt: Union[str, list[str]] = "",
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[list[torch.Tensor]] = None,
+ prompt_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[list[torch.Tensor]] = None,
+ negative_prompt_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompts = None
+ if prompt_embeds is None:
+ prompts = [prompt] if isinstance(prompt, str) else list(prompt)
+ n = int(num_images_per_prompt)
+
+ negatives = None
+ if negative_prompt_embeds is None:
+ prompt_batch_size = len(prompts) if prompts is not None else prompt_embeds[0].shape[0]
+ if isinstance(negative_prompt, str):
+ negatives = [negative_prompt] * prompt_batch_size
+ else:
+ negatives = list(negative_prompt)
+ if len(negatives) == 1:
+ negatives = negatives * prompt_batch_size
+ if len(negatives) != prompt_batch_size:
+ raise ValueError(
+ "negative_prompt must be a string or a list of the same length as prompt"
+ )
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_mask = self._get_text_embeddings(prompts, max_sequence_length, device)
+ prompt_embeds, prompt_mask = self._repeat_for_n(prompt_embeds, prompt_mask, n)
+ elif prompt_mask is None:
+ raise ValueError("`prompt_mask` must be provided when passing `prompt_embeds`.")
+
+ if negative_prompt_embeds is None:
+ if all(isinstance(neg, str) and not neg.strip() for neg in negatives):
+ negative_prompt_embeds = [feat.new_zeros(feat.shape) for feat in prompt_embeds]
+ negative_prompt_mask = torch.zeros_like(prompt_mask, dtype=torch.bool)
+ else:
+ negative_prompt_embeds, negative_prompt_mask = self._get_text_embeddings(
+ negatives, max_sequence_length, device
+ )
+ negative_prompt_embeds, negative_prompt_mask = self._repeat_for_n(
+ negative_prompt_embeds, negative_prompt_mask, n
+ )
+ elif negative_prompt_mask is None:
+ raise ValueError(
+ "`negative_prompt_mask` must be provided when passing `negative_prompt_embeds`."
+ )
+ return prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask
+
+ @staticmethod
+ def _repeat_for_n(features: list[torch.Tensor], mask: torch.Tensor, n: int):
+ if n == 1:
+ return features, mask
+ features = [f.repeat_interleave(n, dim=0) for f in features]
+ mask = mask.repeat_interleave(n, dim=0)
+ return features, mask
+
+ @staticmethod
+ def _align_text_features(
+ pos_features: list[torch.Tensor],
+ pos_mask: torch.Tensor,
+ neg_features: list[torch.Tensor],
+ neg_mask: torch.Tensor,
+ ):
+ seq_pos = pos_features[0].shape[1]
+ seq_neg = neg_features[0].shape[1]
+ target = max(seq_pos, seq_neg)
+
+ def pad(features, cur):
+ if cur == target:
+ return features
+ pad_len = target - cur
+ return [
+ torch.cat([feat, feat.new_zeros((feat.shape[0], pad_len, feat.shape[-1]))], dim=1)
+ for feat in features
+ ]
+
+ def pad_mask(mask, cur):
+ if cur == target:
+ return mask
+ return torch.cat(
+ [mask, torch.zeros((mask.shape[0], target - cur), dtype=torch.bool, device=mask.device)], dim=1
+ )
+
+ pos_features = pad(pos_features, seq_pos)
+ neg_features = pad(neg_features, seq_neg)
+ pos_mask = pad_mask(pos_mask.bool(), seq_pos)
+ neg_mask = pad_mask(neg_mask.bool(), seq_neg)
+ return pos_features, pos_mask, neg_features, neg_mask
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ latent_h = height // self.vae_scale_factor
+ latent_w = width // self.vae_scale_factor
+ shape = (batch_size, latent_h * latent_w, num_channels_latents)
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+ return randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ ):
+ if height is None or width is None:
+ raise ValueError("height and width must be provided (or use base_resolution + aspect_ratio).")
+ if height % self.vae_scale_factor or width % self.vae_scale_factor:
+ raise ValueError(
+ f"height and width must be divisible by {self.vae_scale_factor}; got ({height}, {width})."
+ )
+ if prompt is None and prompt_embeds is None:
+ raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")
+ if callback_on_step_end_tensor_inputs is not None:
+ for k in callback_on_step_end_tensor_inputs:
+ if k not in self._callback_tensor_inputs:
+ raise ValueError(
+ f"callback_on_step_end_tensor_inputs entry {k!r} is not in {self._callback_tensor_inputs}."
+ )
+
+ @staticmethod
+ def _patchify_latents(latents: torch.Tensor) -> torch.Tensor:
+ b, c, h, w = latents.shape
+ latents = latents.view(b, c, h // 2, 2, w // 2, 2)
+ latents = latents.permute(0, 1, 3, 5, 2, 4)
+ return latents.reshape(b, c * 4, h // 2, w // 2)
+
+ @staticmethod
+ def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor:
+ b, c, h, w = latents.shape
+ latents = latents.reshape(b, c // 4, 2, 2, h, w)
+ latents = latents.permute(0, 1, 4, 2, 5, 3)
+ return latents.reshape(b, c // 4, h * 2, w * 2)
+
+ def _decode(self, latents: torch.Tensor, latent_h: int, latent_w: int):
+ b = latents.shape[0]
+ p1, p2 = 2, 2
+ h, w = latent_h, latent_w
+ c = latents.shape[-1] // (p1 * p2)
+ latents = latents.reshape(b, h, w, c, p1, p2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+ latents = latents.reshape(b, c, h * p1, w * p2)
+ latents = latents.to(self.vae.dtype)
+
+ bn = self.vae.bn
+ mean = bn.running_mean.view(1, -1, 1, 1)
+ var = bn.running_var.view(1, -1, 1, 1)
+ std = torch.sqrt(var + self.vae.config.batch_norm_eps)
+ shift = (-mean).to(device=latents.device, dtype=latents.dtype)
+ scale = (1.0 / std).to(device=latents.device, dtype=latents.dtype)
+ x = self._patchify_latents(latents)
+ x = x / scale - shift
+ x = self._unpatchify_latents(x)
+ return self.vae.decode(x).sample
+
+ @staticmethod
+ def _to_pil(image: torch.Tensor):
+ from PIL import Image
+
+ image = image.clamp(-1.0, 1.0)
+ image = (image + 1.0) * (255.0 / 2.0)
+ image = image.permute(0, 2, 3, 1).to(device="cpu", dtype=torch.uint8).numpy()
+ return [Image.fromarray(im) for im in image]
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, list[str]] = None,
+ negative_prompt: Union[str, list[str]] = "",
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ base_resolution: Optional[int] = None,
+ aspect_ratio: Optional[str] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[list[torch.Tensor]] = None,
+ prompt_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[list[torch.Tensor]] = None,
+ negative_prompt_mask: Optional[torch.Tensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[Callable] = None,
+ callback_on_step_end_tensor_inputs: list[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `list[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `list[str]`, *optional*, defaults to `""`):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.default_sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.default_sample_size):
+ The width in pixels of the generated image.
+ base_resolution (`int`, *optional*):
+ The base resolution for the resolution bucket. When provided together with `aspect_ratio`, overrides
+ `height` and `width`. Supported values: 1024, 1440.
+ aspect_ratio (`str`, *optional*):
+ The aspect ratio for the resolution bucket. When provided together with `base_resolution`, overrides
+ `height` and `width`. Supported values: "1:2", "9:16", "2:3", "3:4", "1:1", "4:3", "3:2", "16:9",
+ "2:1".
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a
+ latents tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`list[torch.Tensor]`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+ not provided, text embeddings will be generated from `prompt` input argument. Must be a list of
+ tensors, one per selected text encoder layer.
+ prompt_mask (`torch.Tensor`, *optional*):
+ Boolean mask for the prompt embeddings. Must be provided when passing `prompt_embeds`.
+ negative_prompt_embeds (`list[torch.Tensor]`, *optional*):
+ Pre-generated negative text embeddings. Must be a list of tensors, one per selected text encoder
+ layer.
+ negative_prompt_mask (`torch.Tensor`, *optional*):
+ Boolean mask for the negative prompt embeddings. Must be provided when passing
+ `negative_prompt_embeds`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.ndarray`) or `"latent"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.lens.LensPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising step during inference.
+ callback_on_step_end_tensor_inputs (`list`, *optional*, defaults to `["latents"]`):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length to use with the text encoder.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.lens.LensPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.lens.LensPipelineOutput`] is returned, otherwise a `tuple`
+ is returned where the first element is a list of the generated images.
+ """
+ if base_resolution is not None and aspect_ratio is not None:
+ height, width = resolve_resolution(base_resolution, aspect_ratio)
+ elif height is None or width is None:
+ height = width = self.default_sample_size
+
+ self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+ self._guidance_scale = guidance_scale
+
+ prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_mask=prompt_mask,
+ negative_prompt_embeds=negative_prompt_embeds,
+ negative_prompt_mask=negative_prompt_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self._align_text_features(
+ prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask
+ )
+
+ encoder_features = [
+ torch.cat([pf, nf], dim=0).to(dtype=dtype)
+ for pf, nf in zip(prompt_embeds, negative_prompt_embeds)
+ ]
+ encoder_mask = torch.cat([prompt_mask, negative_prompt_mask], dim=0)
+
+ batch_size = prompt_embeds[0].shape[0]
+ latent_h = height // self.vae_scale_factor
+ latent_w = width // self.vae_scale_factor
+ seq_len = latent_h * latent_w
+ latents = self.prepare_latents(
+ batch_size, self.latent_channels, height, width,
+ dtype=dtype, device=device, generator=generator, latents=latents,
+ )
+
+ mu = compute_empirical_mu(seq_len, num_inference_steps)
+ sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
+ self.scheduler.set_timesteps(sigmas=sigmas, device=device, mu=mu)
+ self._num_timesteps = len(self.scheduler.timesteps)
+
+ img_shapes = [(1, latent_h, latent_w)]
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(self.scheduler.timesteps):
+ timestep = t.expand(batch_size * 2).to(latents.dtype)
+ hidden_states = latents.repeat(2, 1, 1)
+
+ noise = self.transformer(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_features,
+ encoder_hidden_states_mask=encoder_mask,
+ timestep=timestep / 1000,
+ img_shapes=img_shapes,
+ return_dict=False,
+ )[0]
+
+ cond, uncond = noise.chunk(2)
+ comb = uncond + self._guidance_scale * (cond - uncond)
+ cond_norm = torch.norm(cond, dim=-1, keepdim=True)
+ comb_norm = torch.norm(comb, dim=-1, keepdim=True)
+ scale = torch.where(
+ comb_norm > 0,
+ cond_norm / comb_norm.clamp_min(1e-12),
+ torch.ones_like(comb_norm),
+ )
+ noise_pred = comb * scale
+
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ cb_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
+ cb_out = callback_on_step_end(self, i, t, cb_kwargs)
+ latents = cb_out.pop("latents", latents)
+
+ progress_bar.update()
+
+ if output_type == "latent":
+ images = latents
+ else:
+ decoded = self._decode(latents, latent_h, latent_w)
+ if output_type == "pil":
+ images = self._to_pil(decoded)
+ elif output_type == "np":
+ decoded = decoded.clamp(-1.0, 1.0)
+ decoded = (decoded + 1.0) * 0.5
+ images = decoded.permute(0, 2, 3, 1).to("cpu", torch.float32).numpy()
+ elif output_type == "pt":
+ images = decoded
+ else:
+ raise ValueError(f"output_type must be one of 'pil', 'np', 'pt', 'latent'; got {output_type!r}.")
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (images,)
+ return LensPipelineOutput(images=images)
diff --git a/src/diffusers/pipelines/lens/pipeline_output.py b/src/diffusers/pipelines/lens/pipeline_output.py
new file mode 100644
index 000000000000..1543ca0ec557
--- /dev/null
+++ b/src/diffusers/pipelines/lens/pipeline_output.py
@@ -0,0 +1,35 @@
+# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class LensPipelineOutput(BaseOutput):
+ """
+ Output class for Lens image generation pipelines.
+
+ Args:
+ images (`list[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
+ List of denoised PIL images of length `batch_size`, numpy array, or torch tensor of shape `(batch_size,
+ height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
+ pipeline. Torch tensors can represent either the denoised images or the intermediate latents.
+ """
+
+ images: list[PIL.Image.Image] | np.ndarray
diff --git a/tests/models/transformers/test_models_transformer_lens.py b/tests/models/transformers/test_models_transformer_lens.py
new file mode 100644
index 000000000000..0a91709e819e
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_lens.py
@@ -0,0 +1,192 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+from diffusers import LensTransformer2DModel
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..testing_utils import (
+ AttentionTesterMixin,
+ ContextParallelTesterMixin,
+ BaseModelTesterConfig,
+ MemoryTesterMixin,
+ ModelTesterMixin,
+ TorchCompileTesterMixin,
+ TrainingTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class LensTransformerTesterConfig(BaseModelTesterConfig):
+ @property
+ def model_class(self):
+ return LensTransformer2DModel
+
+ @property
+ def pretrained_model_name_or_path(self):
+ return ""
+
+ @property
+ def pretrained_model_kwargs(self):
+ return {"subfolder": "transformer"}
+
+ @property
+ def generator(self):
+ return torch.Generator("cpu").manual_seed(0)
+
+ @property
+ def main_input_name(self) -> str:
+ return "hidden_states"
+
+ @property
+ def model_split_percents(self) -> list:
+ return [0.1, 0.1, 0.1]
+
+ def get_init_dict(self) -> dict:
+ return {
+ "patch_size": 2,
+ "in_channels": 16,
+ "out_channels": 4,
+ "num_layers": 1,
+ "attention_head_dim": 20,
+ "num_attention_heads": 1,
+ "inner_dim": 20,
+ "enc_hidden_dim": 32,
+ "axes_dims_rope": [4, 8, 8],
+ "gate_mlp": True,
+ "rms_norm": True,
+ "multi_layer_encoder_feature": True,
+ "selected_layer_index": (0,),
+ }
+
+ def get_dummy_inputs(self, batch_size: int = 1) -> dict:
+ height = width = 4
+ num_latent_channels = 16
+ text_seq_len = 8
+ enc_hidden_dim = 32
+
+ return {
+ "hidden_states": randn_tensor(
+ (batch_size, height * width, num_latent_channels),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ ),
+ "encoder_hidden_states": [
+ randn_tensor(
+ (batch_size, text_seq_len, enc_hidden_dim),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ )
+ ],
+ "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype).expand(batch_size),
+ "img_shapes": [(1, height, width)],
+ }
+
+ @property
+ def input_shape(self) -> tuple:
+ return (16, 16)
+
+ @property
+ def output_shape(self) -> tuple:
+ return (16, 16)
+
+
+class TestLensTransformerModel(LensTransformerTesterConfig, ModelTesterMixin):
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
+ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
+ pytest.skip("Tolerance requirements too high for meaningful test")
+
+ def test_model_parallelism(self, tmp_path):
+ pytest.skip("Tiny Lens test config does not split meaningfully across multiple GPUs")
+
+
+class TestLensTransformerMemory(LensTransformerTesterConfig, MemoryTesterMixin):
+ def test_layerwise_casting_memory(self):
+ pytest.skip("Tiny Lens test config does not give a stable layerwise-casting memory ordering signal")
+
+ def test_cpu_offload(self, tmp_path):
+ pytest.skip("Tiny Lens test config does not split meaningfully for CPU offload")
+
+ def test_disk_offload_without_safetensors(self, tmp_path):
+ pytest.skip("Tiny Lens test config does not split meaningfully for disk offload")
+
+ def test_disk_offload_with_safetensors(self, tmp_path):
+ pytest.skip("Tiny Lens test config does not split meaningfully for disk offload")
+
+
+class TestLensTransformerTorchCompile(LensTransformerTesterConfig, TorchCompileTesterMixin):
+ @property
+ def different_shapes_for_compilation(self):
+ return [(4, 4), (4, 8), (8, 8)]
+
+ def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict:
+ num_latent_channels = 16
+ text_seq_len = 8
+ enc_hidden_dim = 32
+ batch_size = 1
+
+ return {
+ "hidden_states": randn_tensor(
+ (batch_size, height * width, num_latent_channels),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ ),
+ "encoder_hidden_states": [
+ randn_tensor(
+ (batch_size, text_seq_len, enc_hidden_dim),
+ generator=self.generator,
+ device=torch_device,
+ dtype=self.torch_dtype,
+ )
+ ],
+ "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype).expand(batch_size),
+ "img_shapes": [(1, height, width)],
+ }
+
+
+@pytest.mark.skip(reason="Tiny Lens test config is not suitable for context-parallel coverage")
+class TestLensTransformerContextParallel(LensTransformerTesterConfig, ContextParallelTesterMixin):
+ pass
+
+
+class TestLensTransformerTraining(LensTransformerTesterConfig, TrainingTesterMixin):
+ pass
+
+
+class TestLensTransformerAttention(LensTransformerTesterConfig, AttentionTesterMixin):
+ def test_fuse_unfuse_qkv_projections(self):
+ pytest.skip("LensJointAttention does not use the generic to_q/to_k/to_v fusion path")
+
+ def test_attention_processor_count_mismatch_raises_error(self):
+ init_dict = self.get_init_dict()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ if not hasattr(model, "set_attn_processor"):
+ pytest.skip("Model does not support setting attention processors.")
+
+ current_processors = model.attn_processors
+ if len(current_processors) <= 1:
+ pytest.skip("Lens test config exposes only one attention processor.")
+
+ return super().test_attention_processor_count_mismatch_raises_error()
diff --git a/tests/pipelines/lens/__init__.py b/tests/pipelines/lens/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/lens/test_pipeline_lens.py b/tests/pipelines/lens/test_pipeline_lens.py
new file mode 100644
index 000000000000..ca15de857ab5
--- /dev/null
+++ b/tests/pipelines/lens/test_pipeline_lens.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2025 Microsoft and HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from tokenizers import Tokenizer
+from tokenizers.decoders import WordPiece as WordPieceDecoder
+from tokenizers.models import WordPiece
+from tokenizers.pre_tokenizers import Whitespace
+from transformers import GptOssConfig, PreTrainedTokenizerFast
+
+from diffusers import (
+ AutoencoderKLFlux2,
+ FlowMatchEulerDiscreteScheduler,
+ LensPipeline,
+ LensTransformer2DModel,
+)
+from diffusers.pipelines.lens.pipeline_lens import LensGptOssEncoder
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+def _build_dummy_gptoss_tokenizer():
+ vocab = {
+ "[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3, "<|return|>": 4,
+ "hello": 5, "world": 6, "a": 7, "cat": 8, "the": 9, "is": 10,
+ "on": 11, "mat": 12, "painting": 13, "of": 14, "squirrel": 15,
+ "eating": 16, "burger": 17, "Describe": 18, "image": 19,
+ }
+ tok = Tokenizer(WordPiece(vocab, unk_token="[UNK]"))
+ tok.pre_tokenizer = Whitespace()
+ tok.decoder = WordPieceDecoder()
+ fast = PreTrainedTokenizerFast(tokenizer_object=tok)
+ fast.pad_token = "[PAD]"
+ fast.eos_token = "[SEP]"
+ fast.chat_template = (
+ "{% for message in messages %}"
+ "{{ message.content }}"
+ "{% if not loop.last %} {% endif %}"
+ "{% endfor %}"
+ )
+ return fast
+
+
+class LensPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = LensPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+
+ required_optional_params = PipelineTesterMixin.required_optional_params
+ test_layerwise_casting = True
+ test_group_offloading = True
+ supports_dduf = False
+
+ def get_dummy_components(self, num_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = LensTransformer2DModel(
+ patch_size=2,
+ in_channels=64,
+ out_channels=16,
+ num_layers=num_layers,
+ attention_head_dim=20,
+ num_attention_heads=1,
+ inner_dim=20,
+ enc_hidden_dim=32,
+ axes_dims_rope=(4, 8, 8),
+ gate_mlp=True,
+ rms_norm=True,
+ multi_layer_encoder_feature=True,
+ selected_layer_index=(0,),
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKLFlux2(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ latent_channels=16,
+ down_block_types=("DownEncoderBlock2D",),
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(32,),
+ layers_per_block=1,
+ norm_num_groups=32,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ )
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ config = GptOssConfig(
+ hidden_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ num_key_value_heads=2,
+ intermediate_size=64,
+ head_dim=16,
+ layer_types=["full_attention", "full_attention"],
+ vocab_size=1000,
+ )
+ text_encoder = LensGptOssEncoder(config)
+ tokenizer = _build_dummy_gptoss_tokenizer()
+
+ components = {
+ "transformer": transformer.eval(),
+ "vae": vae.eval(),
+ "scheduler": scheduler,
+ "text_encoder": text_encoder.eval(),
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 16,
+ "width": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 2, 2))
+ expected_image = torch.randn(3, 2, 2)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ for tensor_name, tensor_value in callback_kwargs.items():
+ assert tensor_name in pipe._callback_tensor_inputs
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+ for tensor_name, tensor_value in callback_kwargs.items():
+ assert tensor_name in pipe._callback_tensor_inputs
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+ self.assertIsNotNone(output)
+
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ self.assertIsNotNone(output)