@@ -382,10 +382,10 @@ class ASTPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True

# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights
# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._init_weights with ViT->AST
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
@@ -396,6 +396,11 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ASTEmbeddings):
module.cls_token.data.normal_(mean=0.0, std=self.config.cls_token_initializer_range)
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None:
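A note on the `fp32` upcast kept in the hunk above: per the inline comment, `nn.init.trunc_normal_` dispatches to `trunc_normal_cpu`, which has no half-precision kernel, so the weight is initialized in `float32` and cast back to the module's dtype. A minimal, dependency-free sketch of that pattern (the shape, std, and dtype below are illustrative, not taken from the model):

```python
import torch
from torch import nn

# Hypothetical half-precision weight, e.g. from a Linear layer built under torch.float16.
weight = torch.empty(768, 768, dtype=torch.float16)

# Calling trunc_normal_ directly on a half tensor fails on CPU, so: upcast,
# initialize in float32, then cast back to the original dtype.
weight.data = nn.init.trunc_normal_(
    weight.data.to(torch.float32), mean=0.0, std=0.02
).to(weight.dtype)

print(weight.dtype)  # torch.float16 after the round-trip
```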
4 changes: 4 additions & 0 deletions src/transformers/models/beit/configuration_beit.py
@@ -59,6 +59,8 @@ class BeitConfig(PretrainedConfig):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
cls_token_initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the normal initializer for initializing the `cls_token` parameter.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
@@ -127,6 +129,7 @@ def __init__(
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
cls_token_initializer_range=0.02,
initializer_range=0.02,
layer_norm_eps=1e-12,
image_size=224,
@@ -159,6 +162,7 @@ def __init__(
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.cls_token_initializer_range = cls_token_initializer_range
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps

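A quick sketch of the new config knob from a user's point of view (it assumes a transformers build that already contains this change; the values are just the documented defaults):

```python
from transformers import BeitConfig

config = BeitConfig(
    cls_token_initializer_range=0.02,  # std of the plain normal init for `cls_token`
    initializer_range=0.02,            # std of the truncated normal init for everything else
)
print(config.cls_token_initializer_range, config.initializer_range)
```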
18 changes: 8 additions & 10 deletions src/transformers/models/beit/modeling_beit.py
@@ -560,19 +560,17 @@ class BeitPreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, nn.Linear):
torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, BeitEmbeddings):
module.cls_token.data.normal_(mean=0.0, std=self.config.cls_token_initializer_range)
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
if module.position_embeddings is not None:
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BeitEncoder):
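The `BeitEmbeddings` branch above mixes two initializers: a plain normal for `cls_token`, scaled by `cls_token_initializer_range`, and a truncated normal for `mask_token` and `position_embeddings`, scaled by `initializer_range`. A dependency-free sketch of the same logic on bare parameters (shapes and std values are illustrative):

```python
import torch
from torch import nn

hidden_size, num_patches = 768, 196
cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))

cls_token_initializer_range = 0.02
initializer_range = 0.02

# cls_token: plain (untruncated) normal draw
cls_token.data.normal_(mean=0.0, std=cls_token_initializer_range)
# mask_token / position_embeddings: truncated normal (resampled into [-2, 2] by default)
torch.nn.init.trunc_normal_(mask_token.data, std=initializer_range)
torch.nn.init.trunc_normal_(position_embeddings.data, std=initializer_range)

print(cls_token.std().item(), position_embeddings.std().item())  # both roughly 0.02
```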
18 changes: 8 additions & 10 deletions src/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -573,19 +573,17 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, nn.Linear):
torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, Data2VecVisionEmbeddings):
module.cls_token.data.normal_(mean=0.0, std=self.config.cls_token_initializer_range)
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
if module.position_embeddings is not None:
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, Data2VecVisionEncoder):
9 changes: 9 additions & 0 deletions src/transformers/models/deit/configuration_deit.py
@@ -64,6 +64,11 @@ class DeiTConfig(PretrainedConfig):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
cls_token_initializer_range (`float`, *optional*, defaults to 1e-6):
The standard deviation of the normal initializer for initializing the `cls_token` parameter.
distillation_token_initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing the `distillation_token`
parameter.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
image_size (`int`, *optional*, defaults to `224`):
@@ -102,6 +107,8 @@ def __init__(
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
cls_token_initializer_range=1e-6,
distillation_token_initializer_range=0.02,
initializer_range=0.02,
layer_norm_eps=1e-12,
image_size=224,
@@ -120,6 +127,8 @@ def __init__(
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.cls_token_initializer_range = cls_token_initializer_range
self.distillation_token_initializer_range = distillation_token_initializer_range
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
13 changes: 10 additions & 3 deletions src/transformers/models/deit/modeling_deit.py
@@ -387,7 +387,6 @@ def custom_forward(*inputs):
)


# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->DeiT all-casing
class DeiTPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
Expand All @@ -400,9 +399,9 @@ class DeiTPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
def _init_weights(self, module) -> None:
Reviewer comment (Contributor): Can you make the equivalent initialization updates in the TF model?

"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
@@ -413,6 +412,14 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DeiTEmbeddings):
module.cls_token.data.normal_(mean=0.0, std=self.config.cls_token_initializer_range)
torch.nn.init.trunc_normal_(
module.distillation_token.data, std=self.config.distillation_token_initializer_range
)
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None:
if isinstance(module, DeiTEncoder):
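One way to see the DeiT-specific branch at work is to build a randomly initialized model from a config and inspect the embedding parameters. This is a sketch under two assumptions: that `_init_weights` runs at construction time, and that `DeiTModel.embeddings` exposes `cls_token` and `distillation_token` as attributes:

```python
from transformers import DeiTConfig, DeiTModel

config = DeiTConfig(
    cls_token_initializer_range=1e-6,           # tight normal init for cls_token
    distillation_token_initializer_range=0.02,  # truncated normal for distillation_token
)
model = DeiTModel(config)  # random init from the config, nothing is downloaded

emb = model.embeddings
print(emb.cls_token.std().item())            # expected on the order of 1e-6
print(emb.distillation_token.std().item())   # expected on the order of 0.02
print(emb.position_embeddings.std().item())  # ~config.initializer_range (0.02)
```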
11 changes: 7 additions & 4 deletions src/transformers/models/donut/modeling_donut_swin.py
@@ -794,17 +794,20 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True

def _init_weights(self, module):
def _init_weights(self, module) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DonutSwinEmbeddings):
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
if module.position_embeddings is not None:
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, DonutSwinEncoder):
11 changes: 7 additions & 4 deletions src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -739,17 +739,20 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True

def _init_weights(self, module):
def _init_weights(self, module) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, MaskFormerSwinEmbeddings):
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
if module.position_embeddings is not None:
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, MaskFormerSwinEncoder):
11 changes: 7 additions & 4 deletions src/transformers/models/swin/modeling_swin.py
@@ -857,17 +857,20 @@ class SwinPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True

def _init_weights(self, module):
def _init_weights(self, module) -> None:
Reviewer comment (Contributor): Can you make the equivalent initialization updates in the TF model?

"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, SwinEmbeddings):
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
if module.position_embeddings is not None:
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, SwinEncoder):
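The Swin-family hunks (Donut, MaskFormer, Swin, Swinv2) all make the same change: `normal_` becomes `trunc_normal_` for Linear/Conv weights. The practical difference is the tail behavior: `trunc_normal_` keeps samples inside `[a, b]`, which defaults to `[-2.0, 2.0]` in absolute units rather than multiples of `std`. A small sketch of the cutoff, using `std=1.0` so the default bounds coincide with ±2σ:

```python
import torch

plain = torch.empty(1_000_000).normal_(mean=0.0, std=1.0)
truncated = torch.nn.init.trunc_normal_(torch.empty(1_000_000), mean=0.0, std=1.0)

print(plain.abs().max().item())      # typically around 4-5 for a million samples
print(truncated.abs().max().item())  # never exceeds 2.0, the default truncation bound
```

With the usual `std=0.02`, those default bounds sit far out in the tail, so in practice the switch only changes the drawn values for much larger initializer ranges.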
11 changes: 7 additions & 4 deletions src/transformers/models/swinv2/modeling_swinv2.py
@@ -937,17 +937,20 @@ class Swinv2PreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True

def _init_weights(self, module):
def _init_weights(self, module) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, Swinv2Embeddings):
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
if module.position_embeddings is not None:
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, Swinv2Encoder):
4 changes: 4 additions & 0 deletions src/transformers/models/vit/configuration_vit.py
@@ -59,6 +59,8 @@ class ViTConfig(PretrainedConfig):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
cls_token_initializer_range (`float`, *optional*, defaults to 1e-6):
The standard deviation of the normal initializer for initializing the `cls_token` parameter.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
@@ -99,6 +101,7 @@ def __init__(
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
cls_token_initializer_range=1e-6,
initializer_range=0.02,
layer_norm_eps=1e-12,
image_size=224,
@@ -117,6 +120,7 @@ def __init__(
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.cls_token_initializer_range = cls_token_initializer_range
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
9 changes: 7 additions & 2 deletions src/transformers/models/vit/modeling_vit.py
@@ -448,9 +448,9 @@ class ViTPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
def _init_weights(self, module) -> None:
Reviewer comment (Contributor): Can you make the equivalent updates for the TF model?

"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
@@ -461,6 +461,11 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ViTEmbeddings):
module.cls_token.data.normal_(mean=0.0, std=self.config.cls_token_initializer_range)
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
if isinstance(module, ViTEncoder):
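An end-to-end check of the `mask_token` branch on ViT. This is a sketch that assumes the `use_mask_token` keyword on the current `ViTModel` constructor (which makes `embeddings.mask_token` non-None so the `if module.mask_token is not None:` path runs) and the new `cls_token_initializer_range=1e-6` default from the config change above:

```python
from transformers import ViTConfig, ViTModel

model = ViTModel(ViTConfig(), use_mask_token=True)  # random init, no download

emb = model.embeddings
print(f"cls_token std:           {emb.cls_token.std().item():.2e}")            # ~1e-6
print(f"mask_token std:          {emb.mask_token.std().item():.2e}")           # ~2e-2
print(f"position_embeddings std: {emb.position_embeddings.std().item():.2e}")  # ~2e-2
```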
9 changes: 7 additions & 2 deletions src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
@@ -472,9 +472,9 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
def _init_weights(self, module) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
@@ -485,6 +485,11 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ViTHybridEmbeddings):
module.cls_token.data.normal_(mean=0.0, std=self.config.cls_token_initializer_range)
if module.mask_token is not None:
torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)

def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None:
if isinstance(module, ViTHybridEncoder):