@@ -560,19 +560,17 @@ class BeitPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
-        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+        if isinstance(module, nn.Linear):
+            torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, BeitEmbeddings):
+            module.cls_token.data.normal_(mean=0.0, std=self.config.cls_token_initializer_range)
+            if module.mask_token is not None:
+                torch.nn.init.trunc_normal_(module.mask_token.data, std=self.config.initializer_range)
+            if module.position_embeddings is not None:
+                torch.nn.init.trunc_normal_(module.position_embeddings.data, std=self.config.initializer_range)
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, BeitEncoder):
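
For context, a minimal sketch of what the rewritten nn.Linear branch does in isolation. The value 0.02 for initializer_range is an assumption (the usual BeitConfig default); the assertion relies only on trunc_normal_'s default clipping bounds a=-2.0, b=2.0:

    import torch
    import torch.nn as nn

    initializer_range = 0.02  # assumed BeitConfig.initializer_range default

    linear = nn.Linear(768, 768)
    # Unlike plain normal_, trunc_normal_ samples from a normal distribution
    # truncated to [a, b] (defaults: a=-2.0, b=2.0).
    torch.nn.init.trunc_normal_(linear.weight.data, std=initializer_range)

    assert linear.weight.abs().max().item() <= 2.0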