Skip to content

Commit 2b56711

Browse files
csarofeen authored and soumith committed
Indexing fix for fused GRU/LSTM kernels when not all tensors are contiguous. (pytorch#1325)
1 parent 2fa3365 commit 2b56711

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

torch/nn/_functions/thnn/rnnFusedPointwise.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ def forward(self, input_gate, hidden_gate, hx, ibias=None, hbias=None):
1111
if self.backend is None:
1212
self.backend = type2backend[type(input_gate)]
1313
hy = input_gate.new()
14+
if ibias is not None:
15+
if ibias.dim() == 1:
16+
ibias.unsqueeze_(0)
17+
if hbias.dim() == 1:
18+
hbias.unsqueeze_(0)
19+
1420
self.backend.GRUFused_updateOutput(
1521
self.backend.library_state,
1622
input_gate, hidden_gate, ibias, hbias, hx, hy)
@@ -44,6 +50,11 @@ def forward(self, input_gate, hidden_gate, cx, ibias=None, hbias=None):
4450
self.backend = type2backend[type(input_gate)]
4551
hy = input_gate.new()
4652
cy = input_gate.new()
53+
if ibias is not None:
54+
if ibias.dim() == 1:
55+
ibias.unsqueeze_(0)
56+
if hbias.dim() == 1:
57+
hbias.unsqueeze_(0)
4758
self.backend.LSTMFused_updateOutput(
4859
self.backend.library_state,
4960
input_gate, hidden_gate,

0 commit comments

Comments
 (0)