
Conversation

@chrisway613

In the decoder position embedding matrix, the size of the first dim is the number of patches + 1, where the extra 1 accounts for ViT's cls_token. But when embedding the positions of the masked tokens, their indices are not shifted by 1, so they can collide with the position of the ViT's cls_token. (Although MAE does not use a cls_token, this limits extensibility if we want to use the cls_token later.)
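
To make the concern concrete, here is a minimal sketch of the off-by-one (names and shapes are illustrative, not the exact identifiers from the repository): if the table has num_patches + 1 rows and row 0 is reserved for the cls_token, looking up masked-token positions without shifting by 1 makes patch 0 share the cls_token's row.

    import torch
    import torch.nn as nn

    num_patches, decoder_dim = 16, 8

    # illustrative position table: num_patches + 1 rows, row 0 reserved for the cls_token
    decoder_pos_emb = nn.Embedding(num_patches + 1, decoder_dim)

    masked_indices = torch.tensor([0, 3, 7])

    # unshifted lookup: patch 0 reads row 0, the slot meant for the cls_token
    pos_unshifted = decoder_pos_emb(masked_indices)

    # shifted lookup: row 0 stays free for a future cls_token
    pos_shifted = decoder_pos_emb(masked_indices + 1)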

@lucidrains
Owner

@chrisway613 Hi Chris! While this is true, I think leaving untrained parameters in the wrapper class isn't elegant. You can always just concat the CLS token onto the decoder_pos_emb after you have finished training, something like

    decoder_cls_token = nn.Parameter(torch.randn(1, decoder_dim))
    pos_embs_with_cls_token = torch.cat((decoder_cls_token, self.decoder_pos_emb), dim = 0)
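
For context, a standalone sketch of how the concatenated table could then be indexed, with shapes and names chosen for illustration rather than taken from the repository; patch i ends up at row i + 1, leaving row 0 for the cls_token:

    import torch
    import torch.nn as nn

    num_patches, decoder_dim = 16, 8

    decoder_pos_emb = torch.randn(num_patches, decoder_dim)        # trained positional table
    decoder_cls_token = nn.Parameter(torch.randn(1, decoder_dim))  # added after training

    pos_embs_with_cls_token = torch.cat((decoder_cls_token, decoder_pos_emb), dim = 0)

    # patch i now lives at row i + 1; row 0 belongs to the cls_token
    patch_indices = torch.arange(num_patches)
    patch_pos = pos_embs_with_cls_token[patch_indices + 1]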
