from utils.utils import clones


-class LayerNorm(nn.Module):
-    def __init__(self, features, epsilon=1e-5):
-        self.a_2 = nn.Parameter(torch.ones(features))
-        self.b_2 = nn.Parameter(torch.zeros(features))
-        self.epsilon = epsilon
-
-    def forward(self, x):
-        mean = x.mean(-1, keepdim=True)
-        std = x.std(-1, keepdim=True)
-        return self.a_2 * (x - mean) / torch.sqrt(std + self.epsilon) + self.b_2
-

class Embeddings(nn.Module):
    def __init__(self, embed_dim, vocab_size, keep_prob, padding_id, use_pretrained_embed, pretrained_weights):
@@ -38,9 +27,121 @@ def forward(self, input):
        return out


+class LayerNorm(nn.Module):
+    def __init__(self, features, epsilon=1e-5):
+        super(LayerNorm, self).__init__()
+        self.a_2 = nn.Parameter(torch.ones(features))
+        self.b_2 = nn.Parameter(torch.zeros(features))
+        self.epsilon = epsilon
+
+    def forward(self, x):
+        # Normalize over the last (feature) dimension, then rescale and shift.
+        mean = x.mean(-1, keepdim=True)
+        std = x.std(-1, keepdim=True)
+        return self.a_2 * (x - mean) / (std + self.epsilon) + self.b_2
+
+
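As a quick illustrative check, the normalized output should have roughly zero mean and unit standard deviation along the feature dimension:

# Illustrative check only; sizes are placeholders.
import torch

ln = LayerNorm(features=768)
x = torch.randn(2, 10, 768)
y = ln(x)
print(y.mean(-1).abs().max().item(), y.std(-1).mean().item())  # ~0 and ~1
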
+class MultiLayerPerceptron(nn.Module):
+    def __init__(self, num_state, embed_dim, keep_prob):
+        super(MultiLayerPerceptron, self).__init__()
+        self.fc = nn.Conv1d(num_state, 1, embed_dim)
+        self.proj = nn.Conv1d(embed_dim, 1, num_state)
+        self.activation = nn.ReLU()
+        self.dropout = nn.Dropout(keep_prob)
+
+    def forward(self, input):
+        # Position-wise feed-forward sub-layer: expand, activate, project back, dropout.
+        x = self.activation(self.fc(input))
+        x = self.dropout(self.proj(x))
+        return x
+
+
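Note that the three-argument nn.Conv1d calls above and in the attention class below follow the (nf, rf, nx) argument pattern of the custom Conv1D layer from the original OpenAI GPT / early huggingface code rather than torch's (in_channels, out_channels, kernel_size). A minimal sketch of that custom layer, assuming rf=1 (a position-wise affine map), for comparison:

# Hedged sketch of the Conv1D(nf, rf, nx) layer from the original OpenAI GPT code;
# not part of this commit. With rf == 1 it reduces to a per-position linear map.
import torch
import torch.nn as nn

class Conv1D(nn.Module):
    def __init__(self, nf, rf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        self.rf = rf
        if rf == 1:
            w = torch.empty(nx, nf)
            nn.init.normal_(w, std=0.02)
            self.weight = nn.Parameter(w)
            self.bias = nn.Parameter(torch.zeros(nf))
        else:
            raise NotImplementedError

    def forward(self, x):
        # x: (..., nx) -> (..., nf)
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        return x.view(*size_out)
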
+class ModifiedMultiHeadedAttention(nn.Module):
+    def __init__(self, num_state, n_ctx, num_heads, keep_prob_attention, keep_prob_residual, scale=False):
+        super(ModifiedMultiHeadedAttention, self).__init__()
+        assert num_state % num_heads == 0
+        # Causal mask: lower-triangular ones, broadcastable over batch and heads.
+        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+        self.num_heads = num_heads
+        self.split_size = num_state
+        self.scale = scale
+        self.attn = nn.Conv1d(num_state * 3, 1, num_state)
+        self.proj = nn.Conv1d(num_state, 1, num_state)
+        self.attn_dropout = nn.Dropout(keep_prob_attention)
+        self.residual_dropout = nn.Dropout(keep_prob_residual)
+
+    def attention(self, query, key, value):
+        weight = torch.matmul(query, key)
+        if self.scale:
+            weight = weight / math.sqrt(value.size(-1))
+
+        # Mask attention weights so a position cannot attend to later positions.
+        bias = self.bias[:, :, :weight.size(-2), :weight.size(-1)]
+        weight = weight * bias - 1e9 * (1 - bias)
+
+        p_attn = F.softmax(weight, dim=-1)
+        if self.attn_dropout is not None:
+            p_attn = self.attn_dropout(p_attn)
+        return torch.matmul(p_attn, value)
+
+    # Direct c/p from huggingface, which is equivalent to the original tensorflow implementation.
+    def merge_heads(self, x):
+        x = x.permute(0, 2, 1, 3).contiguous()
+        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+        return x.view(*new_x_shape)
+
+    # Direct c/p from huggingface, which is equivalent to the original tensorflow implementation.
+    def split_heads(self, x, is_key=False):
+        new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads)
+        x = x.view(*new_x_shape)
+        if is_key:
+            return x.permute(0, 2, 3, 1)
+        else:
+            return x.permute(0, 2, 1, 3)
+
+    def forward(self, input):
+        x = self.attn(input)
+        query, key, value = x.split(self.split_size, dim=2)
+        query = self.split_heads(query)
+        key = self.split_heads(key, is_key=True)
+        value = self.split_heads(value)
+        out = self.proj(self.merge_heads(self.attention(query, key, value)))
+        return self.residual_dropout(out)
+
+
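The masking arithmetic in attention() can be sanity-checked in isolation; a minimal illustrative sketch (tensors and sizes below are placeholders, not from the commit):

# Illustrative only: the causal-mask trick used in attention() above.
import torch
import torch.nn.functional as F

n_ctx = 4
bias = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)
weight = torch.zeros(1, 1, n_ctx, n_ctx)    # pretend raw attention scores
masked = weight * bias - 1e9 * (1 - bias)   # future positions pushed to -1e9
p_attn = F.softmax(masked, dim=-1)
print(p_attn[0, 0])                         # row i attends uniformly to positions 0..i
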
+class Block(nn.Module):
+    def __init__(self, embed_dim, num_heads, keep_prob_attention, keep_prob_residual, keep_prob_mlp, n_ctx=512,
+                 scale=False, use_builtin_mha=False):
+        super(Block, self).__init__()
+        self.use_builtin_mha = use_builtin_mha
+        if use_builtin_mha:
+            self.attention = nn.MultiheadAttention(embed_dim=embed_dim,
+                                                   num_heads=num_heads,
+                                                   dropout=keep_prob_attention)
+        else:
+            self.attention = ModifiedMultiHeadedAttention(num_state=embed_dim,
+                                                          n_ctx=n_ctx,
+                                                          num_heads=num_heads,
+                                                          keep_prob_attention=keep_prob_attention,
+                                                          keep_prob_residual=keep_prob_residual,
+                                                          scale=scale)
+        self.layer_norm1 = LayerNorm(embed_dim)
+        self.mlp = MultiLayerPerceptron(4 * embed_dim, embed_dim, keep_prob_mlp)
+        self.layer_norm2 = LayerNorm(embed_dim)
+
+    def forward(self, input):
+        # Self-attention sub-layer, then feed-forward sub-layer,
+        # each wrapped in a residual connection and layer normalization.
+        if self.use_builtin_mha:
+            x, _ = self.attention(input, input, input)
+        else:
+            x = self.attention(input)
+        x_hat = self.layer_norm1(input + x)
+        x = self.mlp(x_hat)
+        x = self.layer_norm2(x_hat + x)
+        return x
+
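A hedged sketch of how blocks like this are typically stacked with the clones helper imported at the top of this file (the helper's assumed behaviour and all hyperparameter values below are placeholders):

# Illustrative only: stacking several Block instances.
import copy
import torch.nn as nn

def clones(module, n):
    # Assumed behaviour of utils.utils.clones: n deep copies in a ModuleList.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

block = Block(embed_dim=768, num_heads=12,
              keep_prob_attention=0.1, keep_prob_residual=0.1, keep_prob_mlp=0.1)
layers = clones(block, 12)

def encode(x):
    for layer in layers:
        x = layer(x)
    return x
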
class LanguageModelHead(nn.Module):
-    def __init__(self):
+    def __init__(self, embedding, embed_dim):
        super(LanguageModelHead, self).__init__()
+        self.embed_dim = embed_dim
+        # Tie the decoder weights to the input embedding matrix (weight sharing).
+        self.decoder = nn.Linear(in_features=embedding.embedding.weight.shape[1],
+                                 out_features=embedding.embedding.weight.shape[0],
+                                 bias=True)
+        self.decoder.weight = embedding.embedding.weight
+
+    def forward(self, input):
+        # Drop the last position's hidden state and flatten to (batch * (seq_len - 1), embed_dim).
+        x = input[:, :-1].contiguous().view(-1, self.embed_dim)
+        x = self.decoder(x)
+        return x
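
A hedged sketch of using the weight-tied head for next-token prediction (the toy embedding wrapper and all sizes below are illustrative, not part of the commit):

# Illustrative only: next-token-prediction loss with the weight-tied head above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class _ToyEmbeddings(nn.Module):
    # Stand-in exposing the .embedding attribute the head expects.
    def __init__(self, vocab_size, embed_dim):
        super(_ToyEmbeddings, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)

vocab_size, embed_dim = 40000, 768
lm_head = LanguageModelHead(_ToyEmbeddings(vocab_size, embed_dim), embed_dim)

hidden = torch.randn(8, 128, embed_dim)             # transformer outputs: (batch, seq_len, embed_dim)
targets = torch.randint(0, vocab_size, (8, 128))    # token ids
logits = lm_head(hidden)                            # (8 * 127, vocab_size); last position dropped
loss = F.cross_entropy(logits, targets[:, 1:].contiguous().view(-1))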


class TransformerOpenAI: