Flux followup #9074
```diff
@@ -446,6 +446,7 @@ def get_1d_rotary_pos_embed(
     linear_factor=1.0,
     ntk_factor=1.0,
     repeat_interleave_real=True,
+    freqs_dtype=torch.float32,  # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -468,6 +469,8 @@ def get_1d_rotary_pos_embed(
         repeat_interleave_real (`bool`, *optional*, defaults to `True`):
             If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
             Otherwise, they are concatenated with themselves.
+        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+            the dtype of the frequency tensor.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
```
```diff
@@ -476,19 +479,19 @@ def get_1d_rotary_pos_embed(
     if isinstance(pos, int):
         pos = np.arange(pos)
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
     t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
-    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    freqs = torch.outer(t, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
         return freqs_cos, freqs_sin
     else:
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float()  # complex64  # [S, D/2]
         return freqs_cis
```
```diff
@@ -540,6 +543,30 @@ def apply_rotary_emb(
     return x_out.type_as(x)
 
 
+class FluxPosEmbed(nn.Module):
```
**Member:** Maybe a reference to the original BFL inference code?
```diff
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.squeeze().float().cpu().numpy()
+        is_mps = ids.device.type == "mps"
+        freqs_dtype = torch.float32 if is_mps else torch.float64
```
**Collaborator (author):** @sayakpaul the results for flux are identical with this refactor

**Member:** Aye, thanks!
```diff
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
```
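For illustration, a toy call to the new module. The shapes and ids layout here are hypothetical (a single 8×8 latent grid), though `axes_dim=[16, 56, 56]` matches the Flux transformer's 128-dim attention heads:

```python
import torch

from diffusers.models.embeddings import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])  # dims sum to the head dim

# Toy ids for an 8x8 latent grid: channel 0 stays zero, channels 1 and 2
# hold the row and column index of each of the 64 tokens.
ids = torch.zeros(64, 3)
ids[:, 1] = torch.arange(64) // 8
ids[:, 2] = torch.arange(64) % 8

cos, sin = pos_embed(ids)
print(cos.shape, sin.shape)  # torch.Size([64, 128]) torch.Size([64, 128])
```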
**Member:** Let's maybe deprecate it earlier? Not a strong opinion, though.