
Conversation

@patil-suraj (Contributor) commented Jul 15, 2022

What does this PR do?

This PR adds the DalleMega model from dalle-mini for text-to-image generation.
The VQGAN model required for converting the generated tokens back to an image is added in PR #18150.

  • override the sample method for classifier-free guidance.
  • port and upload weights on the hub
  • add tests
  • add docs
  • boom!
            heads.
    """

    last_hidden_state: torch.FloatTensor = None
    last_hidden_state_unconditional: Optional[Tuple[torch.FloatTensor]] = None
Contributor Author:

last_hidden_state_unconditional is the unconditional encoder output required for superconditioning (guidance).
It is added here so that it can easily be passed from the encoder to the decoder.
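For readers following along, here is a minimal sketch of the idea behind the field (the dataclass and helper below are simplified stand-ins, not the code in this PR):

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class EncoderOutputSketch:
    # conditional encoder states, shape (batch, seq_len, hidden)
    last_hidden_state: Optional[torch.FloatTensor] = None
    # unconditional encoder states, same shape; only populated when superconditioning
    last_hidden_state_unconditional: Optional[torch.FloatTensor] = None

def split_encoder_states(hidden_states: torch.FloatTensor, do_superconditioning: bool) -> EncoderOutputSketch:
    # when superconditioning, the encoder ran on a doubled batch (conditional + unconditional),
    # so chunk(2) splits it back along the batch dimension
    if do_superconditioning:
        cond, uncond = hidden_states.chunk(2)
        return EncoderOutputSketch(cond, uncond)
    return EncoderOutputSketch(hidden_states, None)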

Comment on lines 781 to 794
if do_superconditioning:
    input_ids_uncond = torch.ones(input_shape, dtype=torch.long, device=inputs_embeds.device) * self.config.pad_token_id
    attention_mask_uncond = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)

    inputs_embeds_unconditional = self.embed_tokens(input_ids_uncond)

    # concatenate the embeddings of the conditioned and unconditioned inputs
    inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_unconditional], dim=0)

    # concatenate the attention masks of the conditioned and unconditioned inputs
    # if attention_mask is None, create an all-ones mask
    if attention_mask is None:
        attention_mask = torch.ones(input_shape, dtype=torch.long, device=inputs_embeds.device)
    attention_mask = torch.cat([attention_mask, attention_mask_uncond], dim=0)
Contributor Author:

Here we extend inputs_embeds with inputs_embeds_unconditional so that the same encoder pass also produces the unconditional hidden states.
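A quick shape illustration of that trick (toy sizes, not taken from the PR):

import torch

batch, seq_len, hidden = 2, 8, 16
inputs_embeds = torch.randn(batch, seq_len, hidden)                # embeddings of the text prompt
inputs_embeds_unconditional = torch.randn(batch, seq_len, hidden)  # embeddings of the pad-only "empty" prompt

# stacking along the batch dimension lets a single encoder pass produce both sets of hidden states
doubled = torch.cat([inputs_embeds, inputs_embeds_unconditional], dim=0)
assert doubled.shape == (2 * batch, seq_len, hidden)

# chunk(2) later recovers the conditional and unconditional halves
cond, uncond = doubled.chunk(2)
assert torch.equal(cond, inputs_embeds) and torch.equal(uncond, inputs_embeds_unconditional)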

Contributor:

Good for me!

Comment on lines 855 to 861
# filter out the unconditional hidden states from encoder_states
if do_superconditioning and output_hidden_states:
    encoder_states = (state.chunk(2)[0] for state in encoder_states)

# filter out the unconditional attentions from all_attentions
if do_superconditioning and output_attentions:
    all_attentions = (attn.chunk(2)[0] for attn in all_attentions)
Contributor Author:

Filter out the unconditional hidden_states and attentions, since we don't want to return those.

Comment on lines +863 to +865
hidden_states_uncond = None
if do_superconditioning:
    hidden_states, hidden_states_uncond = hidden_states.chunk(2)
Contributor Author:

Separate the last hidden states for the conditional and unconditional inputs.

Comment on lines 1026 to 1030
if do_superconditioning:
    encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_uncond], dim=0)
    input_ids = input_ids.repeat(2, 1)
    attention_mask = attention_mask.repeat(2, 1)

Contributor Author:

The encoder_hidden_states_uncond will be passed from DalleMegaModel. We concatenate the two sets of encoder states here
and repeat the decoder inputs accordingly.

Contributor:

Cool! Maybe we can also raise a nice error message if encoder_hidden_states_uncond is in the wrong format or None while do_superconditioning is True?
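One possible shape for that check, sketched here with a made-up helper name (the exact message and placement would be up to the PR):

import torch
from typing import Optional

def _check_superconditioning_inputs(do_superconditioning: bool, encoder_hidden_states_uncond: Optional[torch.Tensor]) -> None:
    # fail early with a readable message instead of a shape error deep inside cross-attention
    if do_superconditioning and encoder_hidden_states_uncond is None:
        raise ValueError(
            "Superconditioning is enabled (superconditioning_scale > 1) but "
            "`encoder_hidden_states_uncond` is None. Run the encoder with superconditioning "
            "enabled so that the unconditional hidden states are returned."
        )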

input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_hidden_states_unconditional=encoder_outputs[1] if do_superconditioning else None,
Contributor Author:

Pass the unconditional encoder hidden states to the decoder.

Comment on lines 1490 to 1493
do_superconditioning = superconditioning_scale > 1 and (encoder_outputs.last_hidden_state_unconditional is not None) and (not self.training)
if do_superconditioning:
    lm_logits, lm_logits_uncond = lm_logits.chunk(2)
    lm_logits = lm_logits + superconditioning_scale * (lm_logits - lm_logits_uncond)
Contributor Author:

Do the actual superconditioning.
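For context, this is the classifier-free guidance combination rule applied to the logits; a toy numeric sketch with invented values:

import torch

superconditioning_scale = 10.0

# logits for the doubled batch: first row conditional, second row unconditional (toy vocab of size 2)
lm_logits = torch.tensor([[2.0, 0.5],
                          [1.0, 1.0]])

lm_logits_cond, lm_logits_uncond = lm_logits.chunk(2)

# push the prediction away from the unconditional distribution, in the direction of the conditional one
guided = lm_logits_cond + superconditioning_scale * (lm_logits_cond - lm_logits_uncond)
print(guided)  # tensor([[12.0000, -4.5000]]): the cond/uncond gap is amplified by the scale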

Contributor:

Looks good to me! Just out of curiosity: can the superconditioning_scale not be between 0 and 1?

"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
"superconditioning_scale": superconditioning_scale
Contributor Author:

This makes sure that we always pass superconditioning_scale to the model's forward.

output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
superconditioning_scale = superconditioning_scale if superconditioning_scale is not None else self.config.superconditioning_scale
Contributor:

I would be in favor of not adding the superconditioning_scale parameter to the config, as it's something the user would want to change at forward time, and a specific value is not necessarily attached to a trained model.
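Concretely, the config fallback could be replaced by a plain default on forward; a rough sketch of that alternative (the signature is heavily trimmed and the default value is an assumption, since scale <= 1 disables superconditioning):

import torch
from typing import Optional

def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    superconditioning_scale: float = 1.0,  # plain default instead of self.config.superconditioning_scale; 1.0 keeps it disabled
    # ... remaining arguments unchanged
):
    ...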

)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
superconditioning_scale = superconditioning_scale if superconditioning_scale is not None else self.config.superconditioning_scale
Contributor:

Same comment here: let's not have a fallback to the config here.

def set_output_embeddings(self, new_embeddings):
    self.lm_head = new_embeddings

def tie_weights(self):
Contributor:

Why is this here? Can't this just use the standard tie_weights function in modeling_utils.py?

def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    superconditioning_scale=None,
Contributor:

Let's maybe not have it in the 2nd position but a bit later, after the tensor arguments? Also, a type hint (: int) would be nice here.
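i.e., something along these lines (the surrounding argument list is abbreviated and partly assumed here, not copied from the PR):

from typing import Optional

def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    past=None,                                      # tensor / cache arguments first (list abbreviated)
    attention_mask=None,
    encoder_outputs=None,
    use_cache=None,
    superconditioning_scale: Optional[int] = None,  # moved after the tensors, with the suggested type hint
    **kwargs,
):
    ...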

@patrickvonplaten (Contributor):

@patil-suraj - I can take over the PR if you want :-)

@github-actions (bot):

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions (bot) closed this Sep 26, 2022
@patrickvonplaten added the WIP (work in progress) label Sep 27, 2022
Labels: WIP

2 participants