Flux followup #9074
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Could you please let us know when the related PRs for the Flux transformer block and pipeline will be completed and merged? We are currently working on adapting and training ControlNet and other related plugins on the old version. Once this PR is completed, we will reorganize our code to match the new coding style, and then submit a new PR for ControlNet and IP-Adapter. prompt = "A girl with green long hair. she is wearing a yellow suit. half body, background is sky and cloud. anime style"
@wangqixun let's not wait for this refactor to be done for your PR! We can refactor the IP-Adapter and ControlNet together once the PR is in.
Superb!
So, IIUC, the batched `txt_ids` and `img_ids` are no longer a necessity because of how we're doing the RoPE in the refactored class (`FluxPosEmbed`)? Or are there any additional differences to be aware of?
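For intuition on why the batch dimension is redundant here, a toy comparison (illustrative only; the shapes are made up for the example):

```python
import torch

# The positional ids are identical for every sample in a batch, so a 2d
# (seq_len, 3) tensor carries all the information of the old 3d
# (batch, seq_len, 3) layout.
seq_len, batch_size = 4, 2
text_ids_2d = torch.zeros(seq_len, 3)                        # new layout: no batch dim
text_ids_3d = text_ids_2d[None].expand(batch_size, -1, -1)   # old batched layout
assert torch.equal(text_ids_3d[0], text_ids_2d)
```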
```python
def __init__(self):
    deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
    deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
```
Let's maybe deprecate it earlier? Not a strong opinion, though.
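For context on what `deprecate` does here: as far as I can tell, diffusers' helper emits a `FutureWarning`, so instantiating the old processor warns but still works. A minimal sketch, assuming `FluxSingleAttnProcessor2_0` is importable from `diffusers.models.attention_processor`:

```python
import warnings

from diffusers.models.attention_processor import FluxSingleAttnProcessor2_0

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    processor = FluxSingleAttnProcessor2_0()  # emits the deprecation warning

# The warning message names the deprecated class and its replacement.
assert any("FluxSingleAttnProcessor2_0" in str(w.message) for w in caught)
```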
```python
return x_out.type_as(x)


class FluxPosEmbed(nn.Module):
```
Maybe a reference to the original BFL inference code?
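For readers following along, a simplified sketch of what a module like `FluxPosEmbed` does, based on the snippets in this diff: compute one rotary embedding per id axis with the existing `get_1d_rotary_pos_embed` helper and concatenate the results. The `theta`/`axes_dim` defaults below are illustrative assumptions, not necessarily the merged values, and the `freqs_dtype` keyword is assumed to be the one added around this change:

```python
import torch
import torch.nn as nn

from diffusers.models.embeddings import get_1d_rotary_pos_embed


class FluxPosEmbedSketch(nn.Module):
    # Sketch only: one rotary embedding per id axis, concatenated along
    # the frequency dimension. theta/axes_dim are assumed example values.
    def __init__(self, theta: int = 10000, axes_dim=(16, 56, 56)):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor):
        # ids: (seq_len, n_axes) positional ids, no batch dimension
        pos = ids.squeeze().float().cpu().numpy()
        is_mps = ids.device.type == "mps"
        freqs_dtype = torch.float32 if is_mps else torch.float64
        cos_out, sin_out = [], []
        for i in range(ids.shape[-1]):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        # (seq_len, sum(axes_dim)) cos/sin tables consumed by the attention RoPE
        return torch.cat(cos_out, dim=-1).to(ids.device), torch.cat(sin_out, dim=-1).to(ids.device)
```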
```python
sin_out = []
pos = ids.squeeze().float().cpu().numpy()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
```
@sayakpaul the results for Flux are identical with this refactor.
The only other difference is here, where we downcast the dtype for MPS; see #9133 for more details.
Aye, thanks!
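A small illustration of the dtype choice discussed above (my own example, not code from the PR): MPS lacks float64 support, so the rotary frequency computation falls back to float32 there and keeps float64 elsewhere for precision.

```python
import torch


def pick_freqs_dtype(device: torch.device) -> torch.dtype:
    # float64 is unsupported on Apple MPS, so downcast the rotary
    # frequency dtype there; keep float64 elsewhere for precision.
    return torch.float32 if device.type == "mps" else torch.float64


print(pick_freqs_dtype(torch.device("cpu")))  # torch.float64
print(pick_freqs_dtype(torch.device("mps")))  # torch.float32
```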
```python
unscale_lora_layers(self.text_encoder_2, lora_scale)

dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
```
Nice 👍🏽
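Related context for the snippet above (my own illustration, not code from this diff): `text_ids` are all zeros because text tokens carry no spatial position, while `img_ids` encode per-patch (row, column) coordinates in their last two channels. A hypothetical minimal construction:

```python
import torch

# Hypothetical example: image ids carry per-patch (row, column)
# coordinates, so RoPE can give each latent patch a 2d position while
# text tokens (all-zero ids) get a neutral one.
h, w = 4, 4  # latent patches along each axis, made up for the example
img_ids = torch.zeros(h, w, 3)
img_ids[..., 1] += torch.arange(h)[:, None]  # row coordinate
img_ids[..., 2] += torch.arange(w)[None, :]  # column coordinate
img_ids = img_ids.reshape(h * w, 3)          # flatten: one id row per patch
print(img_ids.shape)  # torch.Size([16, 3])
```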
LGTM 👍🏽
* refactor rotary embeds
* adding jsmidt as co-author of this PR for #9133

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Joseph Smidt <josephsmidt@gmail.com>
to-do:
- remove `EmbedND`, `apply_rope`, and `rope`; add a `FluxPosEmbed` class that uses diffusers' existing `get_1d_rotary_pos_embed` method, which is already used by Hunyuan-DiT, Lumina, and Stable Audio
- `img_ids` and `txt_ids`: currently, these are 3d tensors with a batch dimension; I changed them to 2d since these are just positional ids used to encode the rotary embedding, so we do not need a batch dimension here, and the `get_1d_rotary_pos_embed` method does not accept batched positional ids, so we remove it to be consistent across the library. Note that this is a breaking change, so I made sure to deprecate it and also added a test to make sure the previous inputs will still work
- deprecate `FluxSingleAttnProcessor2_0`
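The breaking 3d-to-2d change described above implies a compatibility shim along these lines (a sketch of the kind of handling the deprecation suggests, not necessarily the exact pipeline code; the helper name is made up):

```python
import torch
from diffusers.utils import deprecate


def normalize_ids(txt_ids: torch.Tensor) -> torch.Tensor:
    # Sketch: accept the old batched (batch, seq_len, 3) layout, warn via
    # diffusers' deprecate helper, and drop the batch dimension, since the
    # positional ids are identical for every sample in the batch.
    if txt_ids.ndim == 3:
        deprecate(
            "txt_ids",
            "1.0.0",
            "Passing `txt_ids` as a 3d tensor is deprecated; pass a 2d tensor of shape (seq_len, 3) instead.",
        )
        txt_ids = txt_ids[0]
    return txt_ids
```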