Skip to content

Commit e8d5bbd

Browse files
cuevasclementeapaszke
authored andcommitted
Command Line Interface backwards compatible fix for models.py (#85)
* commandline backwards compatible fix for models.py * changes formatting to accomodate a 120 char width
1 parent 409a726 commit e8d5bbd

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

word_language_model/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@ class RNNModel(nn.Module):
77
def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers):
88
super(RNNModel, self).__init__()
99
self.encoder = nn.Embedding(ntoken, ninp)
10-
self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, bias=False)
10+
if rnn_type in ['LSTM', 'GRU']:
11+
self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, bias=False)
12+
else:
13+
try:
14+
nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
15+
except KeyError:
16+
raise ValueError( """An invalid option for `--model` was supplied,
17+
options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
18+
self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, bias=False)
1119
self.decoder = nn.Linear(nhid, ntoken)
1220

1321
self.init_weights()

0 commit comments

Comments
 (0)