Changes from 1 commit
Commits (50)
b547fcf
Fix QwenImage txt_seq_lens handling
kashif Nov 23, 2025
72a80c6
formatting
kashif Nov 23, 2025
88cee8b
formatting
kashif Nov 23, 2025
ac5ac24
remove txt_seq_lens and use bool mask
kashif Nov 29, 2025
0477526
Merge branch 'main' into txt_seq_lens
kashif Nov 29, 2025
18efdde
use compute_text_seq_len_from_mask
kashif Nov 30, 2025
6a549d4
add seq_lens to dispatch_attention_fn
kashif Nov 30, 2025
2d424e0
use joint_seq_lens
kashif Nov 30, 2025
30b5f98
remove unused index_block
kashif Nov 30, 2025
588dc04
Merge branch 'main' into txt_seq_lens
kashif Dec 6, 2025
f1c2d99
WIP: Remove seq_lens parameter and use mask-based approach
kashif Dec 6, 2025
ec52417
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 6, 2025
beeb020
fix formatting
kashif Dec 7, 2025
5c6f8e3
undo sage changes
kashif Dec 7, 2025
5d434f6
xformers support
kashif Dec 7, 2025
71ba603
hub fix
kashif Dec 8, 2025
babf490
Merge branch 'main' into txt_seq_lens
kashif Dec 8, 2025
afad335
fix torch compile issues
kashif Dec 8, 2025
2d5ab16
Merge branch 'main' into txt_seq_lens
sayakpaul Dec 9, 2025
c78a1e9
fix tests
kashif Dec 9, 2025
d6d4b1d
use _prepare_attn_mask_native
kashif Dec 9, 2025
e999b76
proper deprecation notice
kashif Dec 9, 2025
8115f0b
add deprecate to txt_seq_lens
kashif Dec 9, 2025
3b1510c
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
3676d8e
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
9ed0ffd
Only create the mask if there's actual padding
kashif Dec 10, 2025
abec461
Merge branch 'main' into txt_seq_lens
kashif Dec 10, 2025
e26e7b3
fix order of docstrings
kashif Dec 10, 2025
59e3882
Adds performance benchmarks and optimization details for QwenImage
cdutr Dec 11, 2025
0cb2138
Merge branch 'main' into txt_seq_lens
kashif Dec 12, 2025
60bd454
rope_text_seq_len = text_seq_len
kashif Dec 12, 2025
a5abbb8
rename to max_txt_seq_len
kashif Dec 12, 2025
8415c57
Merge branch 'main' into txt_seq_lens
kashif Dec 15, 2025
afff5b7
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
8dc6c3f
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
22cb03d
removed deprecated args
kashif Dec 17, 2025
125a3a4
undo unrelated change
kashif Dec 17, 2025
b5b6342
Updates QwenImage performance documentation
cdutr Dec 17, 2025
61f5265
Updates deprecation warnings for txt_seq_lens parameter
cdutr Dec 17, 2025
2ef38e2
fix compile
kashif Dec 17, 2025
270c63f
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 17, 2025
35efa06
formatting
kashif Dec 17, 2025
50c4815
fix compile tests
kashif Dec 17, 2025
c88bc06
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
1433783
rename helper
kashif Dec 17, 2025
8de799c
remove duplicate
kashif Dec 17, 2025
fc93747
smaller values
kashif Dec 18, 2025
8bb47d8
Merge branch 'main' into txt_seq_lens
kashif Dec 19, 2025
b7c288a
removed
kashif Dec 20, 2025
4700b7f
Merge branch 'main' into txt_seq_lens
kashif Dec 20, 2025
fix compile
kashif committed Dec 17, 2025
commit 2ef38e2c3457c2bc4dcd7f5d87d60c747985fb25
62 changes: 30 additions & 32 deletions src/diffusers/models/transformers/transformer_qwenimage.py
@@ -165,12 +165,7 @@ def compute_text_seq_len_from_mask(
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))

# For RoPE, we use the full text_seq_len (since per_sample_len.max() <= text_seq_len always)
# Keep as tensor to avoid graph breaks in torch.compile
rope_text_seq_len = torch.tensor(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)

return rope_text_seq_len, per_sample_len, encoder_hidden_states_mask
return text_seq_len, per_sample_len, encoder_hidden_states_mask


class QwenTimestepProjEmbeddings(nn.Module):
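For reference, a minimal sketch of the mask-based helper after this change. It assumes `encoder_hidden_states` has shape (batch, text_seq_len, dim) and the mask has shape (batch, text_seq_len); the RoPE length is now returned as a plain Python int rather than a 0-dim tensor, so the downstream slicing stays static under torch.compile. This mirrors the hunk above, not the exact library code.

import torch

def compute_text_seq_len_from_mask(encoder_hidden_states, encoder_hidden_states_mask):
    # Sketch only: follows the logic shown in the diff above.
    text_seq_len = encoder_hidden_states.shape[1]
    if encoder_hidden_states_mask is None:
        # No padding information: every text position is treated as valid.
        return text_seq_len, None, None
    mask = encoder_hidden_states_mask.bool()
    position_ids = torch.arange(text_seq_len, device=mask.device).unsqueeze(0)
    active_positions = torch.where(mask, position_ids, position_ids.new_zeros(()))
    has_active = mask.any(dim=1)
    per_sample_len = torch.where(
        has_active,
        active_positions.max(dim=1).values + 1,
        torch.as_tensor(text_seq_len, device=mask.device),
    )
    # The full text_seq_len is an upper bound on per_sample_len, so it is safe to use for RoPE.
    return text_seq_len, per_sample_len, mask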
@@ -271,10 +266,6 @@ def forward(
if max_txt_seq_len is None:
raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.")

# Move to device unconditionally to avoid graph breaks in torch.compile
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)

# Validate batch inference with variable-sized images
if isinstance(video_fhw, list) and len(video_fhw) > 1:
# Check if all instances have the same size
@@ -297,25 +288,29 @@
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
video_freq = self._compute_video_freqs(frame, height, width, idx)
video_freq = video_freq.to(device)
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
vid_freqs.append(video_freq)

if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)

txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...]
max_txt_seq_len_int = int(max_txt_seq_len)
# Create device-specific copy for text freqs without modifying self.pos_freqs
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)

return vid_freqs, txt_freqs
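The hunk above replaces the in-place `self.pos_freqs = self.pos_freqs.to(device)` reassignment with a device-local copy and converts `max_txt_seq_len` to a Python int before slicing. A small sketch of the pattern, with `pos_freqs`, `max_vid_index`, and `max_txt_seq_len` as stand-in values rather than the real module state:

import torch

pos_freqs = torch.randn(1024, 64)   # stand-in for the module's precomputed buffer
max_vid_index = 32                  # stand-in value
max_txt_seq_len = torch.tensor(77)  # may arrive as a 0-dim tensor or an int
device = torch.device("cpu")

# Copy to the target device instead of reassigning the module attribute, and slice with a
# Python int so the bound is a static value rather than a traced tensor.
max_txt_seq_len_int = int(max_txt_seq_len)
txt_freqs = pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]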

@functools.lru_cache(maxsize=128)
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None) -> torch.Tensor:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs

freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
@@ -384,10 +379,6 @@ def forward(
device: (`torch.device`, *optional*):
The device on which to perform the RoPE computation.
"""
# Move to device unconditionally to avoid graph breaks in torch.compile
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)

# Validate batch inference with variable-sized images
# In Layer3DRope, the outer list represents batch, inner list/tuple represents layers
if isinstance(video_fhw, list) and len(video_fhw) > 1:
Review comment (Member): Cc: @naykun good for you?

@@ -412,11 +403,10 @@ def forward(
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
video_freq = self._compute_video_freqs(frame, height, width, idx)
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
else:
### For the condition image, we set the layer index to -1
video_freq = self._compute_condition_freqs(frame, height, width)
video_freq = video_freq.to(device)
video_freq = self._compute_condition_freqs(frame, height, width, device)
vid_freqs.append(video_freq)

if self.scale_rope:
@@ -425,16 +415,21 @@ def forward(
max_vid_index = max(height, width, max_vid_index)

max_vid_index = max(max_vid_index, layer_num)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...]
max_txt_seq_len_int = int(max_txt_seq_len)
# Create device-specific copy for text freqs without modifying self.pos_freqs
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)

return vid_freqs, txt_freqs

@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs

freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
@@ -450,10 +445,13 @@ def _compute_video_freqs(self, frame, height, width, idx=0):
return freqs.clone().contiguous()
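Because `_compute_video_freqs` is decorated with `functools.lru_cache`, the new `device` argument becomes part of the cache key, so frequencies requested on different devices are cached separately and the source buffers are never moved in place. A standalone sketch of the idea (hypothetical free function; `pos_freqs` stands in for the precomputed frequency buffer):

import functools
import torch

pos_freqs = torch.randn(1024, 64)  # stand-in for the precomputed frequency buffer

@functools.lru_cache(maxsize=128)
def freqs_slice(start: int, length: int, device: torch.device = None) -> torch.Tensor:
    # `device` participates in the cache key: CPU and CUDA callers get separate entries,
    # and the buffer is copied to the requested device rather than reassigned.
    freqs = pos_freqs.to(device) if device is not None else pos_freqs
    return freqs[start : start + length].clone().contiguous()

# Identical (start, length, device) calls hit the cache.
cached = freqs_slice(0, 16, torch.device("cpu"))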

@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width):
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs

freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
@@ -911,8 +909,8 @@ def forward(
"txt_seq_lens",
"0.37.0",
"Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. "
"Please use `txt_seq_len` instead (singular, not plural). "
"The new parameter accepts a single int or tensor value instead of a list.",
"Please use `encoder_hidden_states_mask` instead. "
"The mask-based approach is more flexible and supports variable-length sequences.",
standard_warn=False,
)
if attention_kwargs is not None:
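For callers, the deprecation above amounts to swapping the `txt_seq_lens` list for a boolean `encoder_hidden_states_mask` marking the valid text tokens. A hedged migration sketch; `model` and the remaining arguments (hidden_states, timestep, img_shapes, ...) are assumed to be set up as in this PR's tests:

import torch

batch_size, text_seq_len = 2, 7
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len, dtype=torch.bool)
encoder_hidden_states_mask[1, 5:] = False  # second sample has only 5 valid text tokens

# Deprecated (removal planned for 0.37.0):
#   out = model(..., txt_seq_lens=[7, 5])
# Replacement:
#   out = model(..., encoder_hidden_states_mask=encoder_hidden_states_mask)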
95 changes: 91 additions & 4 deletions tests/models/transformers/test_models_transformer_qwenimage.py
@@ -103,9 +103,8 @@ def test_infers_text_seq_len_from_mask(self):
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)

# Verify rope_text_seq_len is returned as a tensor (for torch.compile compatibility)
self.assertIsInstance(rope_text_seq_len, torch.Tensor)
self.assertEqual(rope_text_seq_len.ndim, 0) # Should be scalar tensor
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
self.assertIsInstance(rope_text_seq_len, int)

# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
self.assertIsInstance(per_sample_len, torch.Tensor)
@@ -116,7 +115,7 @@ def test_infers_text_seq_len_from_mask(self):
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values

# Verify rope_text_seq_len is at least the sequence length
self.assertGreaterEqual(int(rope_text_seq_len.item()), inputs["encoder_hidden_states"].shape[1])
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])

# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -142,6 +141,7 @@ def test_infers_text_seq_len_from_mask(self):
inputs["encoder_hidden_states"], None
)
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(rope_text_seq_len_none, int)
self.assertIsNone(per_sample_len_none)
self.assertIsNone(normalized_mask_none)

@@ -162,6 +162,7 @@ def test_non_contiguous_attention_mask(self):
)
self.assertEqual(int(per_sample_len.max().item()), 5)
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(inferred_rope_len, int)
self.assertTrue(normalized_mask.dtype == torch.bool)

inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -171,6 +172,92 @@

self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

def test_txt_seq_lens_deprecation(self):
"""Test that passing txt_seq_lens raises a deprecation warning."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)

# Prepare inputs with txt_seq_lens (deprecated parameter)
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]

# Remove encoder_hidden_states_mask to use the deprecated path
inputs_with_deprecated = inputs.copy()
inputs_with_deprecated.pop("encoder_hidden_states_mask")
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens

# Test that deprecation warning is raised
with self.assertWarns(FutureWarning) as warning_context:
with torch.no_grad():
output = model(**inputs_with_deprecated)

# Verify the warning message mentions the deprecation
warning_message = str(warning_context.warning)
self.assertIn("txt_seq_lens", warning_message)
self.assertIn("deprecated", warning_message)
self.assertIn("encoder_hidden_states_mask", warning_message)

# Verify the model still works correctly despite the deprecation
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

def test_layered_model_with_mask(self):
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
# Create layered model config
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 16,
"num_layers": 2,
"attention_head_dim": 128,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"use_layer3d_rope": True, # Enable layered RoPE
}

model = self.model_class(**init_dict).to(torch_device)

# Verify the model uses QwenEmbedLayer3DRope
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope

self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)

# Test single generation with layered structure
batch_size = 1
text_seq_len = 7
img_h, img_w = 4, 4
layers = 4

# For layered model: (layers + 1) because we have N layers + 1 combined image
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)

# Create mask with some padding
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens

timestep = torch.tensor([1.0]).to(torch_device)

# Layer structure: 4 layers + 1 condition image
img_shapes = [
[
(1, img_h, img_w), # layer 0
(1, img_h, img_w), # layer 1
(1, img_h, img_w), # layer 2
(1, img_h, img_w), # layer 3
(1, img_h, img_w), # condition image (last one gets special treatment)
]
]

with torch.no_grad():
output = model(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
timestep=timestep,
img_shapes=img_shapes,
)

self.assertEqual(output.sample.shape[1], hidden_states.shape[1])


class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel