Skip to content

Commit 8ca7bf2

Browse files
authored
Check argument types in 'checkTypes' (pytorch#1363)
Fixes pytorch#1357
1 parent 41705ce commit 8ca7bf2

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

test/test_nn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,14 @@ def test_assignments(get_list, a, b, c):
12991299
self.assertIn('buf', l.state_dict())
13001300
self.assertIs(l.state_dict()['buf'], buf)
13011301

1302+
def test_Conv2d_inconsistent_types(self):
1303+
inputs = Variable(torch.randn(4, 1, 7, 7).float())
1304+
weights = Variable(torch.randn(1, 1, 3, 3).double())
1305+
# inconsistent types should raise an exception
1306+
self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
1307+
# but it should work with the same type
1308+
nn.functional.conv2d(inputs.float(), weights.float())
1309+
13021310
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
13031311
def test_Conv2d_large_workspace(self):
13041312
# These sizes require huge cuDNN workspaces. Make sure we choose a

torch/csrc/nn/THNN_generic.inc.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ void checkTypes(bool isCuda, thpp::Type type, ...) {
5050
if (tensor->isCuda() != isCuda) {
5151
throw invalid_tensor(isCuda ? "CUDA" : "CPU", tensor->isCuda() ? "CUDA" : "CPU");
5252
}
53+
if (tensor->type() != type) {
54+
throw invalid_tensor(thpp::toString(type), thpp::toString(tensor->type()));
55+
}
5356
}
5457
}
5558

0 commit comments

Comments
 (0)