# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F
import numpy as np

class AttentionModel(torch.nn.Module):
    def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights):
        super(AttentionModel, self).__init__()
        """
        Arguments
        ---------
        batch_size : Size of the batch, the same as the batch_size of the data returned by the TorchText BucketIterator
        output_size : 2 = (pos, neg)
        hidden_size : Size of the hidden_state of the LSTM
        vocab_size : Size of the vocabulary containing unique words
        embedding_length : Embedding dimension of the GloVe word embeddings
        weights : Pre-trained GloVe word embeddings used to build our word-embedding look-up table
        """
        self.batch_size = batch_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embedding_length = embedding_length

        self.word_embeddings = nn.Embedding(vocab_size, embedding_length)  # word-embedding look-up table
        # Initialise the look-up table with the pre-trained GloVe vectors and keep them frozen.
        # (nn.Embedding.from_pretrained(weights, freeze=True) would be an equivalent alternative.)
        self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)
        self.lstm = nn.LSTM(embedding_length, hidden_size)
        self.label = nn.Linear(hidden_size, output_size)
        # self.attn_fc_layer = nn.Linear()

    def attention_net(self, lstm_output, final_state):
        """
        Now we incorporate an attention mechanism into our LSTM model. Attention computes a soft alignment score
        between each hidden state of the LSTM and its last hidden state. We use torch.bmm for the batch matrix
        multiplication.

        Arguments
        ---------
        lstm_output : Output of the LSTM containing the hidden-state output for every time step of the sequence.
        final_state : Final time-step hidden state (h_n) of the LSTM.

        Returns
        -------
        new_hidden_state : Result of the attention mechanism, obtained by first computing a weight for each time step
                           in lstm_output and then taking the weighted sum of the LSTM outputs.

        Tensor sizes
        ------------
        hidden.size() = (batch_size, hidden_size)
        attn_weights.size() = (batch_size, num_seq)
        soft_attn_weights.size() = (batch_size, num_seq)
        new_hidden_state.size() = (batch_size, hidden_size)
        """
        hidden = final_state.squeeze(0)  # (batch_size, hidden_size)
        attn_weights = torch.bmm(lstm_output, hidden.unsqueeze(2)).squeeze(2)  # (batch_size, num_seq)
        soft_attn_weights = F.softmax(attn_weights, 1)  # normalise the scores over the time dimension
        new_hidden_state = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)  # (batch_size, hidden_size)

        return new_hidden_state
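
    # Worked shape example (illustrative sizes only, not from the original code): with batch_size = 32,
    # num_seq = 20 and hidden_size = 256, lstm_output is (32, 20, 256) and final_state is (1, 32, 256):
    #   hidden            = final_state.squeeze(0)                       -> (32, 256)
    #   attn_weights      = bmm((32, 20, 256), (32, 256, 1)).squeeze(2)  -> (32, 20)
    #   soft_attn_weights = softmax(attn_weights, dim=1)                 -> (32, 20)
    #   new_hidden_state  = bmm((32, 256, 20), (32, 20, 1)).squeeze(2)   -> (32, 256)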

    def forward(self, input_sentences, batch_size=None):
        """
        Parameters
        ----------
        input_sentences : Input batch of shape (batch_size, num_sequences) containing token indices.
        batch_size : Default None. Used only for prediction on a single sentence after training (batch_size = 1).

        Returns
        -------
        logits : Output of the linear layer, containing the logits for the pos and neg classes. The linear layer takes
                 as input the new_hidden_state produced by the attention network.
                 logits.shape = (batch_size, output_size)
        """
        input = self.word_embeddings(input_sentences)  # (batch_size, num_seq, embedding_length)
        input = input.permute(1, 0, 2)  # (num_seq, batch_size, embedding_length), the layout expected by nn.LSTM
        # The zero initial states are allocated with .cuda(), so this model assumes a GPU is available.
        if batch_size is None:
            h_0 = Variable(torch.zeros(1, self.batch_size, self.hidden_size).cuda())
            c_0 = Variable(torch.zeros(1, self.batch_size, self.hidden_size).cuda())
        else:
            h_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())
            c_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())

        output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))  # final_hidden_state.size() = (1, batch_size, hidden_size)
        output = output.permute(1, 0, 2)  # output.size() = (batch_size, num_seq, hidden_size)

        attn_output = self.attention_net(output, final_hidden_state)
        logits = self.label(attn_output)

        return logits
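
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: the sizes and the random "GloVe" matrix below are
    # illustrative placeholders; in the real training script the weights come from the TorchText vocabulary's
    # pre-trained GloVe vectors.
    vocab_size, embedding_length, hidden_size, output_size, batch_size = 1000, 300, 256, 2, 32
    dummy_weights = torch.randn(vocab_size, embedding_length)
    model = AttentionModel(batch_size, output_size, hidden_size, vocab_size, embedding_length, dummy_weights)

    # forward() builds its initial LSTM states with .cuda(), so running it requires a GPU.
    if torch.cuda.is_available():
        model = model.cuda()
        dummy_batch = torch.randint(0, vocab_size, (batch_size, 20)).cuda()  # 32 sentences of 20 token ids each
        logits = model(dummy_batch)
        print(logits.shape)  # expected: torch.Size([32, 2])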