Conversation

@yiyixuxu yiyixuxu commented Aug 3, 2024

to-do

  • refactor rotary embedding
    • removed the Flux-specific rotary embedding helpers, i.e. EmbedND, apply_rope, and rope;
    • created a FluxPosEmbed that uses diffusers' existing get_1d_rotary_pos_embed method, which is already used by HunyuanDiT, Lumina, and Stable Audio
    • changed the flux transformer inputs img_ids and txt_ids from 3d tensors with a batch dimension to 2d tensors. These are just positional ids used to encode the rotary embedding, so the batch dimension is unnecessary, and get_1d_rotary_pos_embed does not accept batched positional ids, so removing it keeps things consistent across the library. Note that this is a breaking change, so I made sure to deprecate the old shape and added a test to make sure the previous inputs still work (see the sketch after this list)
  • refactor attention processor (combine into one, deprecating FluxSingleAttnProcessor2_0)
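To illustrate the img_ids / txt_ids shape change mentioned above, here is a minimal sketch (not the actual diffusers implementation; normalize_ids is a hypothetical helper) of how the deprecated 3d inputs map onto the new 2d ones:

```python
import torch

batch_size, seq_len = 2, 512

# old (deprecated) style: 3d positional ids with a batch dimension
txt_ids_3d = torch.zeros(batch_size, seq_len, 3)

# new style: 2d positional ids, shared across the batch
txt_ids_2d = torch.zeros(seq_len, 3)


def normalize_ids(ids: torch.Tensor) -> torch.Tensor:
    """Accept either shape and return 2d ids, mimicking the deprecation path."""
    if ids.ndim == 3:
        # batched ids are identical across the batch, so keep the first entry
        # (the real code emits a deprecation warning at this point)
        ids = ids[0]
    return ids


assert normalize_ids(txt_ids_3d).shape == normalize_ids(txt_ids_2d).shape == (seq_len, 3)
```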
```python
# flux unit test for rotary embedding refactor
import torch
from diffusers import FluxPipeline

model_path = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

branch = ""  # set to "_main" when running on the main branch so the outputs can be compared

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save(f"yiyi_test_4_out{branch}.png")
```
| main | this PR |
| --- | --- |
| yiyi_test_4_out_main | yiyi_test_4_out |
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


wangqixun commented Aug 6, 2024

Could you please let us know when the related PRs for the Flux transformer block and pipeline will be completed and merged? We are currently working on adapting and training ControlNet and other related plugins on the old version. Once this PR is completed, we will reorganize the code to match the new coding style and then submit a new PR for ControlNet and IP-Adapter.

prompt = "A girl with green long hair. she is wearing a yellow suit. half body, background is sky and cloud. anime style"

[image: image_demo]


yiyixuxu commented Aug 6, 2024

@wangqixun let's not wait for this refactor to be done for the PR! we can refactor the ip-adapter and controlnet together once the PR is in

@yiyixuxu yiyixuxu requested review from DN6 and sayakpaul August 18, 2024 09:15

@sayakpaul sayakpaul left a comment

Superb!

So, IIUC, the batched txt_ids and img_ids are no longer a necessity because of how we're doing the RoPE in the refactored class (FluxPosEmbed)? Or are there any additional differences to be aware of?


```python
def __init__(self):
    deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
    deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
```

Let's maybe deprecate it earlier? Not a strong opinion, though.
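For context, explicitly opting into the combined processor would look roughly like the sketch below (a hedged example, assuming the transformer exposes the usual set_attn_processor API and `pipe` is the FluxPipeline from the test script above):

```python
from diffusers.models.attention_processor import FluxAttnProcessor2_0

# switch to the combined processor explicitly instead of relying on the
# deprecated FluxSingleAttnProcessor2_0 fallback
pipe.transformer.set_attn_processor(FluxAttnProcessor2_0())
```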

```python
    return x_out.type_as(x)


class FluxPosEmbed(nn.Module):
```

Maybe a reference to the original BFL inference code?

```python
    sin_out = []
    pos = ids.squeeze().float().cpu().numpy()
    is_mps = ids.device.type == "mps"
    freqs_dtype = torch.float32 if is_mps else torch.float64
```

@sayakpaul the results for flux are identical with this refactor.
The only other difference is here, where we downcast the dtype for mps; see #9133 for more details.
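
Putting the pieces together, here is a rough sketch of what the refactored FluxPosEmbed forward pass looks like, built around the snippet above and diffusers' get_1d_rotary_pos_embed (an illustrative sketch, not a verbatim copy of the PR; axes_dim and theta are assumed to come from the Flux transformer config):

```python
import torch
import torch.nn as nn
from diffusers.models.embeddings import get_1d_rotary_pos_embed


class FluxPosEmbedSketch(nn.Module):
    def __init__(self, theta: int, axes_dim: list):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor):
        # ids: 2d positional ids of shape (seq_len, n_axes), no batch dimension
        n_axes = ids.shape[-1]
        cos_out, sin_out = [], []
        pos = ids.squeeze().float().cpu().numpy()
        is_mps = ids.device.type == "mps"
        # float64 is not supported on mps, so downcast the frequencies there (see #9133)
        freqs_dtype = torch.float32 if is_mps else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
```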


Aye, thanks!

```python
        unscale_lora_layers(self.text_encoder_2, lora_scale)

    dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
```

Nice 👍🏽

@DN6 DN6 left a comment


LGTM 👍🏽

@yiyixuxu yiyixuxu merged commit c291617 into main Aug 21, 2024
@yiyixuxu yiyixuxu deleted the flux-followup branch August 21, 2024 18:45
@yiyixuxu yiyixuxu mentioned this pull request Aug 21, 2024
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* refactor rotary embeds

* adding jsmidt as co-author of this PR for #9133

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Joseph Smidt <josephsmidt@gmail.com>