Fix QwenImage txt_seq_lens handling #12702
Open
kashif wants to merge 50 commits into huggingface:main from kashif:txt_seq_lens
+503 −163

Changes from 1 commit (50 commits total)
- b547fcf Fix QwenImage txt_seq_lens handling (kashif)
- 72a80c6 formatting (kashif)
- 88cee8b formatting (kashif)
- ac5ac24 remove txt_seq_lens and use bool mask (kashif)
- 0477526 Merge branch 'main' into txt_seq_lens (kashif)
- 18efdde use compute_text_seq_len_from_mask (kashif)
- 6a549d4 add seq_lens to dispatch_attention_fn (kashif)
- 2d424e0 use joint_seq_lens (kashif)
- 30b5f98 remove unused index_block (kashif)
- 588dc04 Merge branch 'main' into txt_seq_lens (kashif)
- f1c2d99 WIP: Remove seq_lens parameter and use mask-based approach (kashif)
- ec52417 Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in… (kashif)
- beeb020 fix formatting (kashif)
- 5c6f8e3 undo sage changes (kashif)
- 5d434f6 xformers support (kashif)
- 71ba603 hub fix (kashif)
- babf490 Merge branch 'main' into txt_seq_lens (kashif)
- afad335 fix torch compile issues (kashif)
- 2d5ab16 Merge branch 'main' into txt_seq_lens (sayakpaul)
- c78a1e9 fix tests (kashif)
- d6d4b1d use _prepare_attn_mask_native (kashif)
- e999b76 proper deprecation notice (kashif)
- 8115f0b add deprecate to txt_seq_lens (kashif)
- 3b1510c Update src/diffusers/models/transformers/transformer_qwenimage.py (kashif)
- 3676d8e Update src/diffusers/models/transformers/transformer_qwenimage.py (kashif)
- 9ed0ffd Only create the mask if there's actual padding (kashif)
- abec461 Merge branch 'main' into txt_seq_lens (kashif)
- e26e7b3 fix order of docstrings (kashif)
- 59e3882 Adds performance benchmarks and optimization details for QwenImage (cdutr)
- 0cb2138 Merge branch 'main' into txt_seq_lens (kashif)
- 60bd454 rope_text_seq_len = text_seq_len (kashif)
- a5abbb8 rename to max_txt_seq_len (kashif)
- 8415c57 Merge branch 'main' into txt_seq_lens (kashif)
- afff5b7 Merge branch 'main' into txt_seq_lens (kashif)
- 8dc6c3f Merge branch 'main' into txt_seq_lens (kashif)
- 22cb03d removed deprecated args (kashif)
- 125a3a4 undo unrelated change (kashif)
- b5b6342 Updates QwenImage performance documentation (cdutr)
- 61f5265 Updates deprecation warnings for txt_seq_lens parameter (cdutr)
- 2ef38e2 fix compile (kashif)
- 270c63f Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in… (kashif)
- 35efa06 formatting (kashif)
- 50c4815 fix compile tests (kashif)
- c88bc06 Merge branch 'main' into txt_seq_lens (kashif)
- 1433783 rename helper (kashif)
- 8de799c remove duplicate (kashif)
- fc93747 smaller values (kashif)
- 8bb47d8 Merge branch 'main' into txt_seq_lens (kashif)
- b7c288a removed (kashif)
- 4700b7f Merge branch 'main' into txt_seq_lens (kashif)
Commit 2ef38e2c3457c2bc4dcd7f5d87d60c747985fb25: fix compile
```diff
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -165,12 +165,7 @@ def compute_text_seq_len_from_mask(
     active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
     has_active = encoder_hidden_states_mask.any(dim=1)
     per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
-
-    # For RoPE, we use the full text_seq_len (since per_sample_len.max() <= text_seq_len always)
-    # Keep as tensor to avoid graph breaks in torch.compile
-    rope_text_seq_len = torch.tensor(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
-
-    return rope_text_seq_len, per_sample_len, encoder_hidden_states_mask
+    return text_seq_len, per_sample_len, encoder_hidden_states_mask
 
 
 class QwenTimestepProjEmbeddings(nn.Module):
@@ -271,10 +266,6 @@ def forward(
         if max_txt_seq_len is None:
             raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.")
 
-        # Move to device unconditionally to avoid graph breaks in torch.compile
-        self.pos_freqs = self.pos_freqs.to(device)
-        self.neg_freqs = self.neg_freqs.to(device)
-
         # Validate batch inference with variable-sized images
         if isinstance(video_fhw, list) and len(video_fhw) > 1:
             # Check if all instances have the same size
@@ -297,25 +288,29 @@ def forward(
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
             # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
-            video_freq = self._compute_video_freqs(frame, height, width, idx)
-            video_freq = video_freq.to(device)
+            video_freq = self._compute_video_freqs(frame, height, width, idx, device)
             vid_freqs.append(video_freq)
 
             if self.scale_rope:
                 max_vid_index = max(height // 2, width // 2, max_vid_index)
             else:
                 max_vid_index = max(height, width, max_vid_index)
 
-        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...]
+        max_txt_seq_len_int = int(max_txt_seq_len)
+        # Create device-specific copy for text freqs without modifying self.pos_freqs
+        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)
 
         return vid_freqs, txt_freqs
 
     @functools.lru_cache(maxsize=128)
-    def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
+    def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None) -> torch.Tensor:
         seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+        freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
         freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
@@ -384,10 +379,6 @@ def forward(
         device: (`torch.device`, *optional*):
             The device on which to perform the RoPE computation.
         """
-        # Move to device unconditionally to avoid graph breaks in torch.compile
-        self.pos_freqs = self.pos_freqs.to(device)
-        self.neg_freqs = self.neg_freqs.to(device)
-
         # Validate batch inference with variable-sized images
         # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers
         if isinstance(video_fhw, list) and len(video_fhw) > 1:
```

> Review comment (Member): Cc: @naykun good for you?

```diff
@@ -412,11 +403,10 @@ def forward(
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
             if idx != layer_num:
-                video_freq = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self._compute_video_freqs(frame, height, width, idx, device)
             else:
                 ### For the condition image, we set the layer index to -1
-                video_freq = self._compute_condition_freqs(frame, height, width)
-                video_freq = video_freq.to(device)
+                video_freq = self._compute_condition_freqs(frame, height, width, device)
             vid_freqs.append(video_freq)
 
             if self.scale_rope:
@@ -425,16 +415,21 @@ def forward(
             else:
                 max_vid_index = max(height, width, max_vid_index)
 
         max_vid_index = max(max_vid_index, layer_num)
-        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_txt_seq_len, ...]
+        max_txt_seq_len_int = int(max_txt_seq_len)
+        # Create device-specific copy for text freqs without modifying self.pos_freqs
+        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)
 
         return vid_freqs, txt_freqs
 
     @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width, idx=0):
+    def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
         seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+        freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
         freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
@@ -450,10 +445,13 @@ def _compute_video_freqs(self, frame, height, width, idx=0):
         return freqs.clone().contiguous()
 
     @functools.lru_cache(maxsize=None)
-    def _compute_condition_freqs(self, frame, height, width):
+    def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
         seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+        freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
         freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
@@ -911,8 +909,8 @@ def forward(
                 "txt_seq_lens",
                 "0.37.0",
                 "Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. "
-                "Please use `txt_seq_len` instead (singular, not plural). "
-                "The new parameter accepts a single int or tensor value instead of a list.",
+                "Please use `encoder_hidden_states_mask` instead. "
+                "The mask-based approach is more flexible and supports variable-length sequences.",
                 standard_warn=False,
             )
         if attention_kwargs is not None:
```
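The heart of this PR is replacing the `txt_seq_lens` list with a mask-based length computation: per-sample valid text lengths are recovered from `encoder_hidden_states_mask` using the `torch.where` reduction shown in the first hunk. A minimal standalone sketch of that reduction (the function name `text_seq_len_from_mask` is ours for illustration; the library helper is `compute_text_seq_len_from_mask` and carries extra arguments):

```python
import torch

def text_seq_len_from_mask(mask: torch.Tensor) -> torch.Tensor:
    """Recover per-sample valid text lengths from a boolean padding mask.

    mask: (batch, text_seq_len) bool tensor, True at real-token positions.
    Returns a (batch,) tensor of per-sample lengths; an all-padding row
    falls back to the full text_seq_len, matching the diff's behavior.
    """
    text_seq_len = mask.shape[1]
    position_ids = torch.arange(text_seq_len, device=mask.device).expand_as(mask)
    # Zero out padded positions; the max surviving index + 1 is the length
    active_positions = torch.where(mask, position_ids, position_ids.new_zeros(()))
    has_active = mask.any(dim=1)
    return torch.where(
        has_active,
        active_positions.max(dim=1).values + 1,
        torch.as_tensor(text_seq_len, device=mask.device),
    )

# Two prompts padded to length 5: valid lengths 3 and 5
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True, True]])
print(text_seq_len_from_mask(mask).tolist())  # prints [3, 5]
```

Note this assumes right-padding (valid tokens form a prefix); with holes in the mask, the max-index reduction still returns the last active position plus one.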
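A second pattern worth noting from the "fix compile" commit: instead of mutating `self.pos_freqs` in place with `.to(device)` (which can break `torch.compile` and invalidate previously cached results), `device` is passed as an argument to the `lru_cache`-decorated helpers. Because `functools.lru_cache` keys on all arguments, each device then gets its own cache entry and module state stays untouched. A generic sketch of the idea with a hypothetical class (not the diffusers code; dimensions are made up):

```python
import functools
import torch

class FreqCache:
    def __init__(self, dim: int = 8):
        self.pos_freqs = torch.randn(16, dim)  # canonical copy, never mutated
        self.calls = 0  # counts actual (non-cached) computations

    @functools.lru_cache(maxsize=None)
    def _compute_freqs(self, length: int, device: torch.device = None) -> torch.Tensor:
        self.calls += 1
        # Copy to the requested device; self.pos_freqs keeps its original device
        pos = self.pos_freqs.to(device) if device is not None else self.pos_freqs
        return pos[:length].contiguous()

cache = FreqCache()
a = cache._compute_freqs(4, torch.device("cpu"))
b = cache._compute_freqs(4, torch.device("cpu"))  # same args: cache hit
assert cache.calls == 1
assert cache.pos_freqs.device.type == "cpu"  # module state untouched
```

The trade-off is that tensors for every (args, device) combination stay cached, which is also true of the original code's caching; the gain is that the cached entries can no longer be silently moved to a different device between calls.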