Skip to content
7 changes: 5 additions & 2 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed(
pos = torch.from_numpy(pos) # type: ignore # [S]

theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
freqs = freqs.to(pos.device)
freqs = (
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sayakpaul
this is the only code change from #9321

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit): We could add a comment on emphasizing the impact of creating torch.arange() on device so as to prevent syncs. This will be helpful references for us going forward. What do you think?

We could do similar for the changes introduced in https://github.com/huggingface/diffusers/pull/9307/files as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this did not cause the sync, though

1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
Expand Down
Loading