@@ -206,6 +206,9 @@ def forward(fn, input, hx, weight, output, hy):
206206 fn .seq_length , fn .mini_batch , fn .input_size = input .size ()
207207 hidden_size = _hidden_size (fn )
208208 output_size = _output_size (fn )
209+
210+ assert hx .is_contiguous ()
211+ assert cx is None or cx .is_contiguous ()
209212 x = input .contiguous ()
210213 output .resize_ (* output_size )
211214 hy .resize_ (* hidden_size )
@@ -319,6 +322,8 @@ def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_inpu
319322 hidden_size = _hidden_size (fn )
320323 output_size = _output_size (fn )
321324
325+ assert hx .is_contiguous ()
326+ assert cx is None or cx .is_contiguous ()
322327 x = input .contiguous ()
323328 dy = grad_output .contiguous ()
324329 y = output
@@ -397,6 +402,7 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
397402 hx , cx = hx
398403 else :
399404 cx = None
405+
400406 if fn .batch_first :
401407 input = input .transpose (0 , 1 )
402408 output = output .transpose (0 , 1 )
@@ -409,12 +415,12 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
409415 if tuple (input .size ()) != input_size :
410416 raise RuntimeError ('Expected input size {}, got {}' .format (
411417 input_size , tuple (input .size ())))
412- if not fn .train :
413- raise RuntimeError ('backward_weight can only be called when training!' )
414418 if tuple (hx .size ()) != hidden_size :
415419 raise RuntimeError ('Expected input size {}, got {}' .format (
416420 hidden_size , hx .size ()))
417421
422+ assert hx .is_contiguous ()
423+ assert cx is None or cx .is_contiguous ()
418424 x = input .contiguous ()
419425 y = output
420426 dw = fn .weight_buf .new ().resize_as_ (fn .weight_buf ).zero_ ()