Flux followup #9074
Merged
Changes from all commits (20 commits)
efc7ed9 edit (yiyixuxu)
9b8f8c7 refactor rotary embeds (yiyixuxu)
1b4d1c5 Merge branch 'main' into flux-followup (yiyixuxu)
1887bda Update src/diffusers/models/transformers/transformer_flux.py (yiyixuxu)
a9cdfcc fix (yiyixuxu)
de66c58 remove the batch dimension in ids (yiyixuxu)
abad854 keep transformer timesteps input same (yiyixuxu)
463b910 add freqs_dtype, allow torch.float64 and make adjustment for mps device (yiyixuxu)
f23cb1b deprecate flux single attn processor (yiyixuxu)
568884a Merge branch 'main' into flux-followup (yiyixuxu)
ab3a550 deprecate 2d ids inputs to flux transformer (yiyixuxu)
079cb33 Merge branch 'flux-followup' of github.com:huggingface/diffusers into… (yiyixuxu)
4161d93 use FluxPosEmbed in flux controlnet too (yiyixuxu)
89e0ccc apply same change to controlnet (yiyixuxu)
0ff2266 add a test for deprecated flux tranformers inputs: txt and img ids as… (yiyixuxu)
40e94e0 up (yiyixuxu)
293bcd8 Merge branch 'main' into flux-followup (sayakpaul)
72d1cf0 adding jsmidt as co-author of this PR for https://github.com/huggingf… (jsmidt)
f0301b2 Merge branch 'flux-followup' of github.com:huggingface/diffusers into… (yiyixuxu)
95b0a55 Merge branch 'main' into flux-followup (sayakpaul)
@@ -446,6 +446,7 @@ def get_1d_rotary_pos_embed(
     linear_factor=1.0,
     ntk_factor=1.0,
     repeat_interleave_real=True,
+    freqs_dtype=torch.float32,  #  torch.float32 (hunyuan, stable audio), torch.float64 (flux)
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -468,6 +469,8 @@
         repeat_interleave_real (`bool`, *optional*, defaults to `True`):
             If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
             Otherwise, they are concateanted with themselves.
+        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+            the dtype of the frequency tensor.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
@@ -476,19 +479,19 @@
     if isinstance(pos, int):
         pos = np.arange(pos)
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
     t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
-    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    freqs = torch.outer(t, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
         return freqs_cos, freqs_sin
     else:
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float()  # complex64  # [S, D/2]
         return freqs_cis


@@ -540,6 +543,31 @@ def apply_rotary_emb(
     return x_out.type_as(x)


+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.squeeze().float().cpu().numpy()
+        is_mps = ids.device.type == "mps"
+        freqs_dtype = torch.float32 if is_mps else torch.float64
(Review comment on the line above) @sayakpaul the results for flux are identical with this refactor
(Reply) Aye, thanks!
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
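A quick way to exercise the new freqs_dtype argument in isolation is sketched below. This snippet is not part of the PR; the import path and the example values for dim and the number of positions are assumptions made for illustration.

import numpy as np
import torch

from diffusers.models.embeddings import get_1d_rotary_pos_embed  # import path assumed

pos = np.arange(16)  # 16 positions along a single axis
cos, sin = get_1d_rotary_pos_embed(
    dim=56,                     # per-axis rotary dimension, example value only
    pos=pos,
    use_real=True,
    repeat_interleave_real=True,
    freqs_dtype=torch.float64,  # higher-precision frequencies, as used for Flux on non-MPS devices
)
# Frequencies are computed in float64, but the returned cos/sin tables are cast back to float32.
print(cos.shape, cos.dtype)  # torch.Size([16, 56]) torch.float32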
(Review comment on the FluxPosEmbed class) Maybe a reference to the original BFL inference code?
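For context on how the new module is meant to be driven, here is a minimal usage sketch of FluxPosEmbed. It is not taken from the PR; the import path is assumed, and theta=10000 with axes_dim=[16, 56, 56] are example values chosen to resemble the per-axis rotary dims Flux typically uses.

import torch

from diffusers.models.embeddings import FluxPosEmbed  # import path assumed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])  # example/assumed configuration

# Per-token positional ids: one row per token, one column per axis,
# with no batch dimension (matching the "remove the batch dimension in ids" commit above).
height, width = 4, 4
img_ids = torch.zeros(height * width, 3)
img_ids[:, 1] = torch.arange(height).repeat_interleave(width)
img_ids[:, 2] = torch.arange(width).repeat(height)

cos, sin = pos_embed(img_ids)
print(cos.shape, sin.shape)  # torch.Size([16, 128]) each, since sum(axes_dim) == 128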