Commit 5b52c92

Transformer (OpenAI version) update
- "Block" definition is implemented (Ref: huggingface and openai) - LanguageModelHead is implemented. - This part of the project still under construction!
1 parent 1510acb commit 5b52c92

File tree

1 file changed: +113 -12 lines

models/Transformer_OpenAI.py

Lines changed: 113 additions & 12 deletions
@@ -11,17 +11,6 @@
 from utils.utils import clones


-class LayerNorm(nn.Module):
-    def __init__(self, features, epsilon=1e-5):
-        self.a_2 = nn.Parameter(torch.ones(features))
-        self.b_2 = nn.Parameter(torch.zeros(features))
-        self.epsilon = epsilon
-
-    def forward(self, x):
-        mean = x.mean(-1, keepdim=True)
-        std = x.std(-1, keepdim=True)
-        return self.a_2 * (x - mean) / torch.sqrt(std + self.epsilon) + self.b_2
-

 class Embeddings(nn.Module):
     def __init__(self, embed_dim, vocab_size, keep_prob, padding_id, use_pretrained_embed, pretrained_weights):
@@ -38,9 +27,121 @@ def forward(self, input):
         return out


+class LayerNorm(nn.Module):
+    def __init__(self, features, epsilon=1e-5):
+        self.a_2 = nn.Parameter(torch.ones(features))
+        self.b_2 = nn.Parameter(torch.zeros(features))
+        self.epsilon = epsilon
+
+    def forward(self, x):
+        mean = x.mean(-1, keepdim=True)
+        std = x.std(-1, keepdim=True)
+        return self.a_2 * (x - mean) / torch.sqrt(std + self.epsilon) + self.b_2
+
+
+class MultiLayerPerceptron(nn.Module):
+    def __init__(self, num_state, embed_dim, keep_prob):
+        self.fc = nn.Conv1d(num_state, 1, embed_dim)
+        self.proj = nn.Conv1d(embed_dim, 1, num_state)
+        self.activation = nn.ReLU()
+        self.dropout = nn.Dropout(keep_prob)
+
+    def forward(self, input):
+        x = self.activation(self.fc(input))
+        x = self.dropout(self.proj(x))
+        return x
+
+
+class ModifiedMultiHeadedAttention(nn.Module):
+    def __init__(self, num_state, n_ctx, num_heads, keep_prob_attention, keep_prob_residual, scale=False):
+        assert num_state % num_heads == 0
+        self.bias = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)
+        self.num_heads = num_heads
+        self.split_size = num_state
+        self.scale = scale
+        self.attn = nn.Conv1d(num_state * 3, 1, num_state)
+        self.proj = nn.Conv1d(num_state, 1, num_state)
+        self.attn_dropout = nn.Dropout(keep_prob_attention)
+        self.residual_dropout = nn.Dropout(keep_prob_residual)
+
+    def attention(self, query, key, value):
+        weight = torch.matmul(query, key)
+        if self.scale:
+            weight = weight / math.sqrt(value.size(-1))
+
+        # Mask attention weights
+        bias = self.bias[:, :, :weight.size(-2), :weight.size(-1)]
+        weight = weight * bias - 1e9 * (1 - bias)
+
+        p_attn = F.softmax(weight, dim=-1)
+        if self.attn_dropout is not None:
+            p_attn = self.attn_dropout(p_attn)
+        return torch.matmul(p_attn, value)
+
+    # Direct c/p from huggingface, which is the equivalent of the original tensorflow implementation.
+    def merge_heads(self, x):
+        x = x.permute(0, 2, 1, 3)
+        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+        return x.view(*new_x_shape)
+
+    # Direct c/p from huggingface, which is the equivalent of the original tensorflow implementation.
+    def split_heads(self, x, is_key=False):
+        new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads)
+        x = x.view(*new_x_shape)
+        if is_key:
+            return x.permute(0, 2, 3, 1)
+        else:
+            return x.permute(0, 2, 1, 3)
+
+    def forward(self, input):
+        x = self.attn(input)
+        query, key, value = x.split(self.split_size, dim=2)
+        query = self.split_heads(query)
+        key = self.split_heads(key, is_key=True)
+        value = self.split_heads(value)
+        out = self.proj(self.merge_heads(self.attention(query, key, value)))
+        return self.residual_dropout(out)
+
+
+class Block(nn.Module):
+    def __init__(self, embed_dim, num_heads, keep_prob_attention, keep_prob_residual, keep_prob_mlp, n_ctx=512,
+                 scale=False, use_builtin_mha=False):
+        if use_builtin_mha:
+            self.attention = nn.MultiheadAttention(embed_dim=embed_dim,
+                                                   num_heads=num_heads,
+                                                   dropout=keep_prob_attention)
+        else:
+            self.attention = ModifiedMultiHeadedAttention(num_state=embed_dim,
+                                                          n_ctx=n_ctx,
+                                                          num_heads=num_heads,
+                                                          keep_prob_attention=keep_prob_attention,
+                                                          keep_prob_residual=keep_prob_residual,
+                                                          scale=scale)
+        self.layer_norm1 = LayerNorm(embed_dim)
+        self.mlp = MultiLayerPerceptron(4 * embed_dim, embed_dim, keep_prob_mlp)
+        self.layer_norm2 = LayerNorm(embed_dim)
+
+    def forward(self, input):
+        x = self.attention(input)
+        x_hat = self.layer_norm1(input + x)
+        x = self.mlp(x_hat)
+        x = self.layer_norm2(x_hat + x)
+        return x
+
 class LanguageModelHead(nn.Module):
-    def __init__(self):
+    def __init__(self, embedding, embed_dim):
         super(LanguageModelHead, self).__init__()
+        self.embed_dim = embed_dim
+        self.decoder = nn.Linear(in_features=embedding.embedding.weight.shape[1],
+                                 out_features=embedding.embedding.weight.shape[0],
+                                 bias=True)
+        self.decoder.weight = embedding.embedding.weight
+
+    def forward(self, input):
+        # Remove last token
+        x = input[:, :-1].view(-1, self.embed_dim)
+        x = self.decoder(x)
+        return x


 class TransformerOpenAI:
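
A minimal standalone sketch (not part of the commit) of the masking trick used in ModifiedMultiHeadedAttention.attention above: scores at allowed positions are kept by the lower-triangular bias, while future positions are pushed to -1e9 before the softmax, so they receive effectively zero probability. Sizes here are illustrative.

import torch
import torch.nn.functional as F

n_ctx = 5                                        # illustrative context length
bias = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)

weight = torch.randn(1, 1, n_ctx, n_ctx)         # fake scores: (batch, heads, query, key)
masked = weight * bias - 1e9 * (1 - bias)        # keep past/current positions, bury future ones
p_attn = F.softmax(masked, dim=-1)

print(p_attn[0, 0])                              # each row attends only to itself and earlier positions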
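
LanguageModelHead ties the decoder weight to the embedding matrix (self.decoder.weight = embedding.embedding.weight), so one parameter both embeds tokens and projects hidden states back to vocabulary logits. A minimal standalone sketch of that weight tying, with illustrative sizes and a plain nn.Embedding standing in for the project's Embeddings class:

import torch
import torch.nn as nn

vocab_size, embed_dim = 100, 16                  # illustrative sizes
embedding = nn.Embedding(vocab_size, embed_dim)

decoder = nn.Linear(in_features=embed_dim, out_features=vocab_size, bias=True)
decoder.weight = embedding.weight                # shared (vocab_size, embed_dim) parameter

tokens = torch.randint(0, vocab_size, (2, 7))    # (batch, seq_len)
logits = decoder(embedding(tokens))              # (batch, seq_len, vocab_size)
print(logits.shape)

Because the two modules share a single Parameter, gradients from the input embedding and the output projection update the same matrix.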

0 commit comments
