dalle mega #18152
Conversation
        heads.
    """

    last_hidden_state: torch.FloatTensor = None
    last_hidden_state_unconditional: Optional[Tuple[torch.FloatTensor]] = None
last_hidden_state_unconditional is the unconditional encoder output required for superconditioning (guidance).
Added it here so that it can easily be passed from the encoder to the decoder.
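A minimal sketch of such an output class (the class name and docstring below are illustrative, not the PR's exact definitions): it simply adds one optional field next to the usual last_hidden_state.

```python
from dataclasses import dataclass
from typing import Optional

import torch

from transformers.utils import ModelOutput


@dataclass
class EncoderOutputWithUnconditional(ModelOutput):  # hypothetical name, for illustration only
    """
    Encoder output that also carries the hidden states obtained from an all-pad
    ("unconditional") copy of the input, needed later for superconditioning.
    """

    last_hidden_state: torch.FloatTensor = None
    last_hidden_state_unconditional: Optional[torch.FloatTensor] = None
```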
if do_superconditioning:
    input_ids_uncond = torch.ones(input_shape, dtype=torch.long, device=inputs_embeds.device) * self.config.pad_token_id
    attention_mask_uncond = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)

    inputs_embeds_unconditional = self.embed_tokens(input_ids_uncond)

    # concatenate the embeddings of the conditioned and unconditioned inputs
    inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_unconditional], dim=0)

    # concatenate the attention masks of the conditioned and unconditioned inputs
    # if attention_mask is None, create an all-ones mask
    if attention_mask is None:
        attention_mask = torch.ones(input_shape, dtype=torch.long, device=inputs_embeds.device)
    attention_mask = torch.cat([attention_mask, attention_mask_uncond], dim=0)
Here we extend inputs_embeds with inputs_embeds_unconditional so that the same encoder pass also produces the unconditional hidden states.
Good for me!
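As a toy illustration of the batch-doubling trick (all shapes and values below are made up), the conditional and unconditional inputs are stacked along the batch dimension so a single encoder pass produces both sets of hidden states:

```python
import torch
from torch import nn

batch, seq_len, hidden, vocab = 2, 5, 8, 16
pad_token_id = 1  # illustrative value

embed_tokens = nn.Embedding(vocab, hidden)
input_ids = torch.randint(2, vocab, (batch, seq_len))          # conditional prompt tokens
input_ids_uncond = torch.full((batch, seq_len), pad_token_id)  # "empty" prompt made of pad tokens

inputs_embeds = embed_tokens(input_ids)
inputs_embeds_uncond = embed_tokens(input_ids_uncond)
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
attention_mask_uncond = torch.zeros(batch, seq_len, dtype=torch.long)

# stack conditional and unconditional along the batch dimension: one encoder pass handles both
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_uncond], dim=0)
attention_mask = torch.cat([attention_mask, attention_mask_uncond], dim=0)

print(inputs_embeds.shape)   # torch.Size([4, 5, 8])
print(attention_mask.shape)  # torch.Size([4, 5])
```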
# filter out the unconditional hidden states from encoder_states
# (materialize as tuples so the returned outputs stay indexable)
if do_superconditioning and output_hidden_states:
    encoder_states = tuple(state.chunk(2)[0] for state in encoder_states)

# filter out the unconditional attentions from all_attentions
if do_superconditioning and output_attentions:
    all_attentions = tuple(attn.chunk(2)[0] for attn in all_attentions)
filter out the unconditional hidden_states and attentions since we don't want to return those.
hidden_states_uncond = None
if do_superconditioning:
    hidden_states, hidden_states_uncond = hidden_states.chunk(2)
separate the last_hidden_states for conditional and unconditional inputs.
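Continuing the toy shapes above, chunk(2) splits along the batch dimension in concatenation order, so the first half is the conditional output and the second half the unconditional one:

```python
import torch

# pretend this is the encoder's last_hidden_state for the doubled batch from the toy example above
hidden_states = torch.randn(4, 5, 8)

# chunk(2) splits dim 0: first half = conditional, second half = unconditional
hidden_states, hidden_states_uncond = hidden_states.chunk(2)
print(hidden_states.shape, hidden_states_uncond.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])
```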
if do_superconditioning:
    encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_uncond], dim=0)
    input_ids = input_ids.repeat(2, 1)
    attention_mask = attention_mask.repeat(2, 1)
The encoder_hidden_states_uncond will be passed from DalleMegaModel. We concatenate the conditional and unconditional encoder hidden states here and repeat the decoder inputs to match the doubled batch.
Cool! Maybe we can also raise a nice error message if the encoder_hidden_states_uncond are in the wrong format or None but do_superconditioning is True?
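One possible shape of that check, written as a hypothetical helper (the function name is made up; the exact wording and placement are up to the PR author):

```python
from typing import Optional

import torch


def _check_superconditioning_inputs(
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_uncond: Optional[torch.Tensor],
    do_superconditioning: bool,
) -> None:
    """Validate the unconditional encoder states before concatenating them (illustrative helper)."""
    if not do_superconditioning:
        return
    if encoder_hidden_states_uncond is None:
        raise ValueError(
            "`do_superconditioning` is True but `encoder_hidden_states_uncond` is None. "
            "Run the encoder with superconditioning enabled so the unconditional hidden states are returned."
        )
    if encoder_hidden_states_uncond.shape != encoder_hidden_states.shape:
        raise ValueError(
            f"`encoder_hidden_states_uncond` has shape {tuple(encoder_hidden_states_uncond.shape)}, "
            f"expected the same shape as `encoder_hidden_states` {tuple(encoder_hidden_states.shape)}."
        )
```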
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_hidden_states_unconditional=encoder_outputs[1] if do_superconditioning else None,
pass the unconditional encoder hidden states to decoder.
do_superconditioning = superconditioning_scale > 1 and (encoder_outputs.last_hidden_state_unconditional is not None) and (not self.training)
if do_superconditioning:
    lm_logits, lm_logits_uncond = lm_logits.chunk(2)
    lm_logits = lm_logits + superconditioning_scale * (lm_logits - lm_logits_uncond)
do the actual superconditioning.
Looks good to me! Just out of curiosity, can the scale not also be between 0 and 1?
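For reference, the update above is the usual classifier-free guidance formula; a tiny numeric sketch (made-up values) shows the form used in the PR and its equivalent (1 + s) * cond - s * uncond rearrangement:

```python
import torch

superconditioning_scale = 10.0  # example value
lm_logits_cond = torch.tensor([2.0, 0.5, -1.0])
lm_logits_uncond = torch.tensor([1.0, 0.5, 0.0])

# push the logits away from the unconditional prediction, in the direction of the conditional one
guided = lm_logits_cond + superconditioning_scale * (lm_logits_cond - lm_logits_uncond)
# equivalent rearrangement: (1 + s) * cond - s * uncond
guided_alt = (1 + superconditioning_scale) * lm_logits_cond - superconditioning_scale * lm_logits_uncond

print(guided)  # -> 12.0, 0.5, -11.0  (e.g. 2 + 10 * (2 - 1) = 12)
assert torch.allclose(guided, guided_alt)
```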
| "decoder_head_mask": decoder_head_mask, | ||
| "cross_attn_head_mask": cross_attn_head_mask, | ||
| "use_cache": use_cache, # change this to avoid caching (presumably for debugging) | ||
| "superconditioning_scale": superconditioning_scale |
this makes sure that we always pass superconditioning_scale to the model's forward.
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
superconditioning_scale = superconditioning_scale if superconditioning_scale is not None else self.config.superconditioning_scale
I would be in favor of not adding the superconditioning_scale parameter to the config, as it's something the user would want to change per forward call, and a specific value is not necessarily attached to a trained model.
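If that suggestion is adopted, one way it could look is a plain keyword default with no config fallback. The toy module and argument list below are only a sketch of the reviewer's suggestion, not the PR's actual code:

```python
from typing import Optional

import torch
from torch import nn


class DecoderSketch(nn.Module):  # toy stand-in, only to illustrate the suggested signature
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        superconditioning_scale: float = 1.0,  # plain default chosen per call, no `self.config` fallback
    ):
        # values <= 1 simply disable superconditioning, mirroring the `> 1` check in the PR
        do_superconditioning = superconditioning_scale > 1 and not self.training
        return do_superconditioning
```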
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
superconditioning_scale = superconditioning_scale if superconditioning_scale is not None else self.config.superconditioning_scale
Same comment here - let's not fall back to the config.
def set_output_embeddings(self, new_embeddings):
    self.lm_head = new_embeddings

def tie_weights(self):
Why is this here? Can't this just use the standard tie_weights function in modeling_utils.py?
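For context, the inherited machinery usually suffices once the embedding getters/setters are defined: PreTrainedModel.tie_weights ties the input and output embeddings whenever config.tie_word_embeddings is set. A self-contained toy example (all names below are made up, not from this PR) showing that no custom tie_weights is needed:

```python
from torch import nn

from transformers import PretrainedConfig, PreTrainedModel


class TinyConfig(PretrainedConfig):
    model_type = "tiny-sketch"  # illustrative config, not a real model type

    def __init__(self, vocab_size=10, hidden_size=4, tie_word_embeddings=True, **kwargs):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size


class TinyModel(PreTrainedModel):
    config_class = TinyConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()  # runs the inherited tie_weights()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings


model = TinyModel(TinyConfig())
assert model.lm_head.weight is model.embed_tokens.weight  # tied by the standard machinery
```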
def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    superconditioning_scale=None,
Let's maybe not put it in the 2nd position but a bit later, after the tensor arguments? Also, a type hint (: int) would be nice here.
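The suggested reordering could look roughly like this; the surrounding argument names are assumed from the usual Bart-style prepare_inputs_for_generation signature (they may differ in this PR), and the body is elided:

```python
from typing import Optional

import torch


def prepare_inputs_for_generation(  # illustrative signature fragment only
    self,
    decoder_input_ids,
    past=None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    decoder_head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    encoder_outputs=None,
    superconditioning_scale: Optional[int] = None,  # moved after the tensor arguments, with a type hint
    **kwargs,
):
    ...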
@patil-suraj - I can take over the PR if you want :-)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
This PR adds the DalleMega model from dalle-mini for text-to-image generation.
The VQGAN model required for converting the generated tokens to an image is added in PR #18150.
It also adds superconditioning (classifier-free guidance) support to the sample method.
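For readers wondering how this would be used once merged, here is a hypothetical end-to-end sketch; the class name, checkpoint id, and the wiring of superconditioning_scale through generate are assumptions based on this PR rather than a documented API:

```python
import torch
from transformers import AutoTokenizer

# placeholder names -- the real class and checkpoint identifiers come from this PR / dalle-mini
from transformers import DalleMegaForConditionalGeneration  # hypothetical import

tokenizer = AutoTokenizer.from_pretrained("dalle-mega")  # placeholder checkpoint id
model = DalleMegaForConditionalGeneration.from_pretrained("dalle-mega")  # placeholder class name
model.eval()

inputs = tokenizer("an armchair in the shape of an avocado", return_tensors="pt")

with torch.no_grad():
    # superconditioning_scale > 1 enables classifier-free guidance during sampling;
    # the resulting image tokens would then be decoded to pixels by the VQGAN from PR #18150
    image_tokens = model.generate(**inputs, do_sample=True, superconditioning_scale=10)
```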