Skip to content

Commit 72c1982

Browse files
committed
Add some more asserts to cuDNN RNN
1 parent 0de2ea3 commit 72c1982

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

torch/backends/cudnn/rnn.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ def forward(fn, input, hx, weight, output, hy):
206206
fn.seq_length, fn.mini_batch, fn.input_size = input.size()
207207
hidden_size = _hidden_size(fn)
208208
output_size = _output_size(fn)
209+
210+
assert hx.is_contiguous()
211+
assert cx is None or cx.is_contiguous()
209212
x = input.contiguous()
210213
output.resize_(*output_size)
211214
hy.resize_(*hidden_size)
@@ -319,6 +322,8 @@ def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_inpu
319322
hidden_size = _hidden_size(fn)
320323
output_size = _output_size(fn)
321324

325+
assert hx.is_contiguous()
326+
assert cx is None or cx.is_contiguous()
322327
x = input.contiguous()
323328
dy = grad_output.contiguous()
324329
y = output
@@ -397,6 +402,7 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
397402
hx, cx = hx
398403
else:
399404
cx = None
405+
400406
if fn.batch_first:
401407
input = input.transpose(0, 1)
402408
output = output.transpose(0, 1)
@@ -409,12 +415,12 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
409415
if tuple(input.size()) != input_size:
410416
raise RuntimeError('Expected input size {}, got {}'.format(
411417
input_size, tuple(input.size())))
412-
if not fn.train:
413-
raise RuntimeError('backward_weight can only be called when training!')
414418
if tuple(hx.size()) != hidden_size:
415419
raise RuntimeError('Expected input size {}, got {}'.format(
416420
hidden_size, hx.size()))
417421

422+
assert hx.is_contiguous()
423+
assert cx is None or cx.is_contiguous()
418424
x = input.contiguous()
419425
y = output
420426
dw = fn.weight_buf.new().resize_as_(fn.weight_buf).zero_()

0 commit comments

Comments
 (0)