There was an error while loading. Please reload this page.
1 parent 4c474a9 commit 5b10411Copy full SHA for 5b10411
torch/nn/modules/rnn.py
@@ -441,7 +441,7 @@ class LSTMCell(RNNCellBase):
441
>>> cx = Variable(torch.randn(3, 20))
442
>>> output = []
443
>>> for i in range(6):
444
- ... hx, cx = rnn(input, (hx, cx))
+ ... hx, cx = rnn(input[i], (hx, cx))
445
... output.append(hx)
446
"""
447
@@ -510,8 +510,8 @@ class GRUCell(RNNCellBase):
510
>>> hx = Variable(torch.randn(3, 20))
511
512
513
- ... hx = rnn(input, hx)
514
- ... output[i] = hx
+ ... hx = rnn(input[i], hx)
+ ... output.append(hx)
515
516
517
def __init__(self, input_size, hidden_size, bias=True):
0 commit comments