Commit be67d33

Fix beit init

Niels Rogge authored and committed
1 parent: 9b570cd

2 files changed: +12 -10 lines

src/transformers/models/beit/configuration_beit.py

Lines changed: 4 additions & 0 deletions

@@ -59,6 +59,8 @@ class BeitConfig(PretrainedConfig):
             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
         attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        cls_token_initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing the `cls_token` parameter.
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):

@@ -127,6 +129,7 @@ def __init__(
         hidden_act="gelu",
         hidden_dropout_prob=0.0,
         attention_probs_dropout_prob=0.0,
+        cls_token_initializer_range=0.02,
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         image_size=224,

@@ -159,6 +162,7 @@ def __init__(
         self.hidden_act = hidden_act
         self.hidden_dropout_prob = hidden_dropout_prob
         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.cls_token_initializer_range = cls_token_initializer_range
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps

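Once merged, the new field behaves like any other BeitConfig argument. A minimal sketch of setting it (assuming a transformers build that includes this commit; the override value is illustrative, not from the diff):

from transformers import BeitConfig

# Override only the new cls_token std; everything else keeps its default.
config = BeitConfig(cls_token_initializer_range=0.01)
print(config.cls_token_initializer_range)  # 0.01
print(config.initializer_range)            # 0.02 (default, unchanged)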
src/transformers/models/beit/modeling_beit.py

Lines changed: 8 additions & 10 deletions

@@ -560,19 +560,17 @@ class BeitPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
-        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+        if isinstance(module, nn.Linear):
+            torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, BeitEmbeddings):
+            module.cls_token.data.normal_(mean=0.0, std=self.config.cls_token_initializer_range)
+            if module.mask_token is not None:
+                torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
+            if module.position_embeddings is not None:
+                torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, BeitEncoder):