
Commit d51cd61

alykhantejani authored and soumith committed
add checks for input, weight and bias types when using cudnn conv2d (pytorch#1689)

1 parent 447fe95 commit d51cd61

File tree

2 files changed: +44 −2 lines
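Before the diffs, a quick illustration of the behavior being added: with mismatched tensor types, the cuDNN convolution path now raises a RuntimeError that names both types instead of failing deeper inside cuDNN. The sketch below is a hypothetical driver script, not part of the patch, and assumes a CUDA-enabled build of the PyTorch of this era (when tensors still had to be wrapped in Variable):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    inputs = Variable(torch.randn(4, 1, 7, 7).float().cuda())    # float tensor
    weights = Variable(torch.randn(1, 1, 3, 3).double().cuda())  # double tensor

    # With this commit, the cuDNN path raises a RuntimeError whose message
    # names both types, along the lines of:
    #   Input type (...) and weight type (...) should be the same
    F.conv2d(inputs, weights)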

test/test_nn.py

Lines changed: 29 additions & 0 deletions

@@ -1317,6 +1317,35 @@ def test_Conv2d_inconsistent_types(self):
         # but it should work with the same type
         nn.functional.conv2d(inputs.float(), weights.float())
 
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self):
+        inputs = Variable(torch.randn(4, 1, 7, 7).float().cuda())
+        weights = Variable(torch.randn(1, 1, 3, 3).double().cuda())
+        bias = Variable(torch.randn(1).double().cuda())
+
+        torch.backends.cudnn.enabled = False
+        # inconsistent types should raise an exception
+        self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
+        self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights.float(), bias))
+
+        # but it should work with the same type
+        nn.functional.conv2d(inputs.float(), weights.float(), bias.float())
+
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+    def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
+        inputs = Variable(torch.randn(4, 1, 7, 7).float().cuda())
+        weights = Variable(torch.randn(1, 1, 3, 3).double().cuda())
+        bias = Variable(torch.randn(1).double().cuda())
+
+        torch.backends.cudnn.enabled = True
+        # inconsistent types should raise an exception
+        self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
+        self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights.float(), bias))
+
+        # but it should work with the same type
+        nn.functional.conv2d(inputs.float(), weights.float(), bias.float())
+
     def test_Conv2d_missing_argument(self):
         c = nn.Conv2d(3, 3, 3)
         self.assertRaises(RuntimeError, lambda: c(None))
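One caveat in the tests above: both flip torch.backends.cudnn.enabled globally and never restore it, so they depend on running in a fixed order. A context manager along these lines (hypothetical, not part of the patch) would make the toggle self-cleaning:

    import contextlib
    import torch.backends.cudnn as cudnn

    @contextlib.contextmanager
    def cudnn_flag(enabled):
        # Remember the current backend flag, apply the requested one,
        # and restore the original on exit, even if the body raises.
        saved = cudnn.enabled
        cudnn.enabled = enabled
        try:
            yield
        finally:
            cudnn.enabled = saved

With it, each test body becomes `with cudnn_flag(False): ...` and the global state is untouched afterward.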

torch/csrc/autograd/functions/convolution.cpp

Lines changed: 15 additions & 2 deletions

@@ -1,3 +1,5 @@
+#include <sstream>
+
 #include "convolution.h"
 
 #include "torch/csrc/autograd/variable.h"
@@ -6,6 +8,8 @@
 #include "torch/csrc/nn/THNN_generic.h"
 #include "torch/csrc/utils/auto_gpu.h"
 
+#include "THPP/Type.hpp"
+
 #ifdef WITH_CUDNN
 #include "torch/csrc/cudnn/Conv.h"
 #include "torch/csrc/cudnn/Handles.h"
@@ -120,11 +124,11 @@ static auto view3d(const Tensor& tensor) -> std::unique_ptr<Tensor> {
   return result;
 }
 
+
 auto ConvForward::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("ConvNd", inputs, 3, 2);
   if (is_padding_neg()) throw std::runtime_error("negative padding is not supported");
   if (is_output_padding_neg()) throw std::runtime_error("negative output_padding is not supported");
-
   AutoGPU guard(inputs[0]->data->getDevice());
   auto input = inputs[0]->data->contiguous();
   std::unique_ptr<Tensor> weight(inputs[1]->data->clone_shallow());
@@ -147,6 +151,16 @@ auto ConvForward::apply(const variable_list& inputs) -> variable_list {
 
   if (use_cudnn(*input)) {
 #ifdef WITH_CUDNN
+    if (input->type() != weight->type()){
+      std::stringstream ss;
+      ss << "Input type (" << thpp::toString(input->type()) << ") and weight type (" << thpp::toString(weight->type()) << ") should be the same";
+      throw std::runtime_error(ss.str());
+    }
+    if (bias.get() != NULL && input->type() != bias->type()){
+      std::stringstream ss;
+      ss << "Input type (" << thpp::toString(input->type()) << ") and bias type (" << thpp::toString(bias->type()) << ") should be the same";
+      throw std::runtime_error(ss.str());
+    }
     output = input->newTensor();
     output->resize(output_size(*input, *weight));
     if (transposed) {
@@ -203,7 +217,6 @@ auto ConvBackward::apply(const variable_list& grad_outputs) -> variable_list {
   check_input_variables("ConvNdBackward", grad_outputs, 1);
   if (is_padding_neg()) throw std::runtime_error("negative padding is not supported");
   if (is_output_padding_neg()) throw std::runtime_error("negative output_padding is not supported");
-
   auto input = input_.unpack_data();
   AutoGPU guard(input->getDevice());
   input = input->contiguous();
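To see the new C++ checks fire from Python, either mismatched pair works; the bias branch, for instance, is hit with matching input and weight types but a double bias. Again a hypothetical snippet, assuming a CUDA + cuDNN build:

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    torch.backends.cudnn.enabled = True
    inputs = Variable(torch.randn(4, 1, 7, 7).float().cuda())
    weights = Variable(torch.randn(1, 1, 3, 3).float().cuda())
    bias = Variable(torch.randn(1).double().cuda())  # only the bias mismatches

    try:
        F.conv2d(inputs, weights, bias)
    except RuntimeError as e:
        print(e)  # expect: Input type (...) and bias type (...) should be the same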
