Skip to content

Commit 9b570cd

Browse files
Niels RoggeNiels Rogge
authored and committed
Improve initialization
1 parent e26b8c1 commit 9b570cd

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

src/transformers/models/deit/configuration_deit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ class DeiTConfig(PretrainedConfig):
6666
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
6767
cls_token_initializer_range (`float`, *optional*, defaults to 1e-6):
6868
The standard deviation of the truncated_normal_initializer for initializing the `cls_token` parameter.
69+
distillation_token_initializer_range (`float`, *optional*, defaults to 0.02):
70+
The standard deviation of the truncated_normal_initializer for initializing the `distillation_token`
71+
parameter.
6972
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
7073
The epsilon used by the layer normalization layers.
7174
image_size (`int`, *optional*, defaults to `224`):
@@ -95,7 +98,6 @@ class DeiTConfig(PretrainedConfig):
9598
```"""
9699
model_type = "deit"
97100

98-
# Copied from transformers.models.vit.configuration_vit.ViTConfig.__init__
99101
def __init__(
100102
self,
101103
hidden_size=768,
@@ -106,6 +108,7 @@ def __init__(
106108
hidden_dropout_prob=0.0,
107109
attention_probs_dropout_prob=0.0,
108110
cls_token_initializer_range=1e-6,
111+
distillation_token_initializer_range=0.02,
109112
initializer_range=0.02,
110113
layer_norm_eps=1e-12,
111114
image_size=224,
@@ -125,6 +128,7 @@ def __init__(
125128
self.hidden_dropout_prob = hidden_dropout_prob
126129
self.attention_probs_dropout_prob = attention_probs_dropout_prob
127130
self.cls_token_initializer_range = cls_token_initializer_range
131+
self.distillation_token_initializer_range = distillation_token_initializer_range
128132
self.initializer_range = initializer_range
129133
self.layer_norm_eps = layer_norm_eps
130134
self.image_size = image_size

src/transformers/models/deit/modeling_deit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ def custom_forward(*inputs):
387387
)
388388

389389

390-
# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->DeiT all-casing
391390
class DeiTPreTrainedModel(PreTrainedModel):
392391
"""
393392
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -415,6 +414,9 @@ def _init_weights(self, module) -> None:
415414
module.weight.data.fill_(1.0)
416415
elif isinstance(module, DeiTEmbeddings):
417416
module.cls_token.data.normal_(mean=0.0, std=self.config.cls_token_initializer_range)
417+
torch.nn.init.trunc_normal_(
418+
module.distillation_token.data, std=self.config.distillation_token_initializer_range
419+
)
418420
if module.mask_token is not None:
419421
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
420422
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

0 commit comments

Comments (0)