
Commit 3dda0be

Update LSTM_Attn.py
1 parent 5c196e1 commit 3dda0be

File tree

1 file changed: +100 -0 lines changed


models/LSTM_Attn.py

Lines changed: 100 additions & 0 deletions
@@ -1 +1,101 @@
# _*_ coding: utf-8 _*_

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F
import numpy as np


class AttentionModel(torch.nn.Module):
    def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights):
        super(AttentionModel, self).__init__()

        """
        Arguments
        ---------
        batch_size : Size of each batch, the same as the batch_size of the data returned by the TorchText BucketIterator
        output_size : 2 = (pos, neg)
        hidden_size : Size of the hidden_state of the LSTM
        vocab_size : Size of the vocabulary containing unique words
        embedding_length : Embedding dimension of the GloVe word embeddings
        weights : Pre-trained GloVe word embeddings used to build the word-embedding look-up table
        --------
        """

        self.batch_size = batch_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embedding_length = embedding_length

        self.word_embeddings = nn.Embedding(vocab_size, embedding_length)
        self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)
        self.lstm = nn.LSTM(embedding_length, hidden_size)
        self.label = nn.Linear(hidden_size, output_size)
        # self.attn_fc_layer = nn.Linear()

    def attention_net(self, lstm_output, final_state):

        """
        This method adds an attention mechanism on top of the LSTM. It computes a soft alignment score between each
        hidden state in lstm_output and the final hidden state of the LSTM, using torch.bmm for the batched matrix
        multiplication.

        Arguments
        ---------
        lstm_output : Output of the LSTM containing the hidden states for every time step of each sequence in the batch.
        final_state : Final time-step hidden state (h_n) of the LSTM
        ---------

        Returns : The new hidden state, obtained by first computing attention weights over the time steps in
        lstm_output and then taking the attention-weighted sum of those hidden states.

        Tensor Size :
            hidden.size() = (batch_size, hidden_size)
            attn_weights.size() = (batch_size, num_seq)
            soft_attn_weights.size() = (batch_size, num_seq)
            new_hidden_state.size() = (batch_size, hidden_size)

        """

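        # Illustrative shape walkthrough (example sizes only, e.g. batch_size = 32, num_seq = 50, hidden_size = 256):
        #   final_state (1, 32, 256) --squeeze(0)--> hidden (32, 256)
        #   bmm(lstm_output (32, 50, 256), hidden.unsqueeze(2) (32, 256, 1)) -> (32, 50, 1) --squeeze(2)--> attn_weights (32, 50)
        #   softmax over dim 1 leaves the shape unchanged: soft_attn_weights (32, 50)
        #   bmm(lstm_output.transpose(1, 2) (32, 256, 50), soft_attn_weights.unsqueeze(2) (32, 50, 1)) -> (32, 256, 1) --squeeze(2)--> new_hidden_state (32, 256)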
        hidden = final_state.squeeze(0)
        attn_weights = torch.bmm(lstm_output, hidden.unsqueeze(2)).squeeze(2)
        soft_attn_weights = F.softmax(attn_weights, 1)
        new_hidden_state = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)

        return new_hidden_state

    def forward(self, input_sentences, batch_size=None):

        """
        Parameters
        ----------
        input_sentences : Tensor of token indices with shape (batch_size, num_sequences)
        batch_size : default = None. Used only for prediction on a single sentence after training (batch_size = 1)

        Returns
        -------
        final_output : Output of the linear layer, i.e. the logits for the pos & neg classes, computed from the
        new_hidden_state returned by the attention network.
        final_output.shape = (batch_size, output_size)

        """

        input = self.word_embeddings(input_sentences)
        input = input.permute(1, 0, 2)
        if batch_size is None:
            h_0 = Variable(torch.zeros(1, self.batch_size, self.hidden_size).cuda())
            c_0 = Variable(torch.zeros(1, self.batch_size, self.hidden_size).cuda())
        else:
            h_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())
            c_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())

        output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))  # final_hidden_state.size() = (1, batch_size, hidden_size)
        output = output.permute(1, 0, 2)  # output.size() = (batch_size, num_seq, hidden_size)

        attn_output = self.attention_net(output, final_hidden_state)
        logits = self.label(attn_output)

        return logits
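
A minimal usage sketch (illustrative, not part of this commit), assuming the repository root is on the Python path so that models.LSTM_Attn is importable, a CUDA device is available (forward hard-codes .cuda()), and random weights stand in for real pre-trained GloVe vectors:

import torch
from models.LSTM_Attn import AttentionModel

batch_size, output_size, hidden_size = 32, 2, 256
vocab_size, embedding_length = 10000, 300
weights = torch.randn(vocab_size, embedding_length)  # stand-in for pre-trained GloVe vectors

model = AttentionModel(batch_size, output_size, hidden_size, vocab_size, embedding_length, weights).cuda()

# dummy batch of token indices, shape (batch_size, num_sequences)
input_sentences = torch.randint(0, vocab_size, (batch_size, 50)).cuda()
logits = model(input_sentences)  # logits.shape == (batch_size, output_size), i.e. torch.Size([32, 2])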
