Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
b547fcf
Fix QwenImage txt_seq_lens handling
kashif Nov 23, 2025
72a80c6
formatting
kashif Nov 23, 2025
88cee8b
formatting
kashif Nov 23, 2025
ac5ac24
remove txt_seq_lens and use bool mask
kashif Nov 29, 2025
0477526
Merge branch 'main' into txt_seq_lens
kashif Nov 29, 2025
18efdde
use compute_text_seq_len_from_mask
kashif Nov 30, 2025
6a549d4
add seq_lens to dispatch_attention_fn
kashif Nov 30, 2025
2d424e0
use joint_seq_lens
kashif Nov 30, 2025
30b5f98
remove unused index_block
kashif Nov 30, 2025
588dc04
Merge branch 'main' into txt_seq_lens
kashif Dec 6, 2025
f1c2d99
WIP: Remove seq_lens parameter and use mask-based approach
kashif Dec 6, 2025
ec52417
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 6, 2025
beeb020
fix formatting
kashif Dec 7, 2025
5c6f8e3
undo sage changes
kashif Dec 7, 2025
5d434f6
xformers support
kashif Dec 7, 2025
71ba603
hub fix
kashif Dec 8, 2025
babf490
Merge branch 'main' into txt_seq_lens
kashif Dec 8, 2025
afad335
fix torch compile issues
kashif Dec 8, 2025
2d5ab16
Merge branch 'main' into txt_seq_lens
sayakpaul Dec 9, 2025
c78a1e9
fix tests
kashif Dec 9, 2025
d6d4b1d
use _prepare_attn_mask_native
kashif Dec 9, 2025
e999b76
proper deprecation notice
kashif Dec 9, 2025
8115f0b
add deprecate to txt_seq_lens
kashif Dec 9, 2025
3b1510c
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
3676d8e
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
9ed0ffd
Only create the mask if there's actual padding
kashif Dec 10, 2025
abec461
Merge branch 'main' into txt_seq_lens
kashif Dec 10, 2025
e26e7b3
fix order of docstrings
kashif Dec 10, 2025
59e3882
Adds performance benchmarks and optimization details for QwenImage
cdutr Dec 11, 2025
0cb2138
Merge branch 'main' into txt_seq_lens
kashif Dec 12, 2025
60bd454
rope_text_seq_len = text_seq_len
kashif Dec 12, 2025
a5abbb8
rename to max_txt_seq_len
kashif Dec 12, 2025
8415c57
Merge branch 'main' into txt_seq_lens
kashif Dec 15, 2025
afff5b7
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
8dc6c3f
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
22cb03d
removed deprecated args
kashif Dec 17, 2025
125a3a4
undo unrelated change
kashif Dec 17, 2025
b5b6342
Updates QwenImage performance documentation
cdutr Dec 17, 2025
61f5265
Updates deprecation warnings for txt_seq_lens parameter
cdutr Dec 17, 2025
2ef38e2
fix compile
kashif Dec 17, 2025
270c63f
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 17, 2025
35efa06
formatting
kashif Dec 17, 2025
50c4815
fix compile tests
kashif Dec 17, 2025
c88bc06
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
1433783
rename helper
kashif Dec 17, 2025
8de799c
remove duplicate
kashif Dec 17, 2025
fc93747
smaller values
kashif Dec 18, 2025
8bb47d8
Merge branch 'main' into txt_seq_lens
kashif Dec 19, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/diffusers/models/controlnets/controlnet_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ def forward(
Used to indicate denoising step.
block_controlnet_hidden_states (`list` of `torch.Tensor`):
A list of tensors that, if specified, are added to the residuals of the transformer blocks.
txt_seq_lens (`List[int]`, *optional*):
Optional per-sample text sequence lengths, one entry per batch element. If omitted, or if the largest
value is shorter than the encoder hidden states length, the model derives the length from the encoder
hidden states (or their mask).
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
Expand Down Expand Up @@ -244,7 +247,15 @@ def forward(

temb = self.time_text_embed(timestep, hidden_states)

image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
batch_size, text_seq_len = encoder_hidden_states.shape[:2]
if txt_seq_lens is not None:
if len(txt_seq_lens) != batch_size:
raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.")
text_seq_len = max(text_seq_len, max(txt_seq_lens))
elif encoder_hidden_states_mask is not None:
text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item()))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works if the attention mask has the form [True, True, True, ..., False, False, False], i.e. all valid tokens come first and the padding is contiguous at the end. That is the most common layout for text attention masks, but it is not guaranteed.

If the mask is [True, False, True, False, True, False], the sum is 3 even though the last valid token sits at position 5, so self.pos_embed receives an incorrect sequence length.


image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device)

timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
Expand Down
45 changes: 39 additions & 6 deletions src/diffusers/models/transformers/transformer_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,15 @@ def rope_params(self, index, dim, theta=10000):
def forward(
self,
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
txt_seq_lens: List[int],
txt_seq_len: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
A `(frame, height, width)` triple, or a list of such triples, representing the shape of the video.
txt_seq_lens (`List[int]`):
A list of integers of length batch_size representing the length of each text prompt.
txt_seq_len (`int`):
The length of the text sequence. This should match the encoder hidden states length.
device (`torch.device`):
The device on which to perform the RoPE computation.
"""
Expand All @@ -232,8 +232,7 @@ def forward(
else:
max_vid_index = max(height, width, max_vid_index)

max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_seq_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)

return vid_freqs, txt_freqs
Expand Down Expand Up @@ -330,6 +329,29 @@ def __call__(
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)

# If an encoder_hidden_states_mask is provided, turn it into a broadcastable attention mask.
if encoder_hidden_states_mask is not None and attention_mask is None:
batch_size, image_seq_len = hidden_states.shape[:2]
text_seq_len = encoder_hidden_states.shape[1]

if encoder_hidden_states_mask.shape[0] != batch_size:
raise ValueError(
f"`encoder_hidden_states_mask` batch size ({encoder_hidden_states_mask.shape[0]}) "
f"must match hidden_states batch size ({batch_size})."
)
if encoder_hidden_states_mask.shape[1] != text_seq_len:
raise ValueError(
f"`encoder_hidden_states_mask` sequence length ({encoder_hidden_states_mask.shape[1]}) "
f"must match encoder_hidden_states sequence length ({text_seq_len})."
)

text_attention_mask = encoder_hidden_states_mask.to(dtype=torch.bool, device=hidden_states.device)
image_attention_mask = torch.ones(
(batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
)
joint_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
attention_mask = joint_attention_mask[:, None, None, :]

# Compute joint attention
joint_hidden_states = dispatch_attention_fn(
joint_query,
Expand Down Expand Up @@ -588,6 +610,9 @@ def forward(
Mask of the input conditions.
timestep (`torch.LongTensor`):
Used to indicate denoising step.
txt_seq_lens (`List[int]`, *optional*):
Optional per-sample text sequence lengths, one entry per batch element. If not provided, or if the largest
provided value is shorter than the encoder hidden states length, the model falls back to the encoder
hidden states length.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
Expand Down Expand Up @@ -621,6 +646,14 @@ def forward(
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)

batch_size, text_seq_len = encoder_hidden_states.shape[:2]
if txt_seq_lens is not None:
if len(txt_seq_lens) != batch_size:
raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.")
text_seq_len = max(text_seq_len, max(txt_seq_lens))
elif encoder_hidden_states_mask is not None:
text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item()))

if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000

Expand All @@ -630,7 +663,7 @@ def forward(
else self.time_text_embed(timestep, guidance, hidden_states)
)

image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device)

for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
Expand Down
41 changes: 0 additions & 41 deletions src/diffusers/modular_pipelines/qwenimage/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,18 +525,6 @@ def intermediate_outputs(self) -> List[OutputParam]:
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
OutputParam(
name="txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
),
OutputParam(
name="negative_txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
),
]

def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
Expand All @@ -551,14 +539,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
)
]
] * block_state.batch_size
block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)

self.set_block_state(state, block_state)

Expand Down Expand Up @@ -592,18 +572,6 @@ def intermediate_outputs(self) -> List[OutputParam]:
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
OutputParam(
name="txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
),
OutputParam(
name="negative_txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
),
]

def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
Expand All @@ -626,15 +594,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
]
] * block_state.batch_size

block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)

self.set_block_state(state, block_state)

return components, state
Expand Down
11 changes: 1 addition & 10 deletions src/diffusers/modular_pipelines/qwenimage/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def inputs(self) -> List[InputParam]:
kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs for the denoiser. "
"It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
"It should contain prompt_embeds/negative_prompt_embeds."
),
),
]
Expand All @@ -176,7 +176,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
img_shapes=block_state.img_shapes,
encoder_hidden_states=block_state.prompt_embeds,
encoder_hidden_states_mask=block_state.prompt_embeds_mask,
txt_seq_lens=block_state.txt_seq_lens,
return_dict=False,
)

Expand Down Expand Up @@ -247,10 +246,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
"txt_seq_lens": (
getattr(block_state, "txt_seq_lens", None),
getattr(block_state, "negative_txt_seq_lens", None),
),
}

components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
Expand Down Expand Up @@ -345,10 +340,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
getattr(block_state, "prompt_embeds_mask", None),
getattr(block_state, "negative_prompt_embeds_mask", None),
),
"txt_seq_lens": (
getattr(block_state, "txt_seq_lens", None),
getattr(block_state, "negative_txt_seq_lens", None),
),
}

components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
Expand Down
7 changes: 0 additions & 7 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,11 +672,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}

txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)

# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
Expand All @@ -695,7 +690,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Expand All @@ -709,7 +703,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,6 @@ def __call__(
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)

Expand All @@ -920,7 +919,6 @@ def __call__(
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
Expand All @@ -935,7 +933,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,6 @@ def __call__(
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)

Expand All @@ -863,7 +862,6 @@ def __call__(
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
Expand All @@ -878,7 +876,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
controlnet_block_samples=controlnet_block_samples,
attention_kwargs=self.attention_kwargs,
return_dict=False,
Expand Down
7 changes: 0 additions & 7 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,11 +793,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}

txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)

# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
Expand All @@ -821,7 +816,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Expand All @@ -836,7 +830,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1008,11 +1008,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}

txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)

# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
Expand All @@ -1035,7 +1030,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Expand All @@ -1050,7 +1044,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -777,11 +777,6 @@ def __call__(
if self.attention_kwargs is None:
self._attention_kwargs = {}

txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)

# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
Expand All @@ -805,7 +800,6 @@ def __call__(
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Expand All @@ -820,7 +814,6 @@ def __call__(
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Expand Down
Loading