
Commit 8262920

ezyang authored and soumith committed
Add ATen overload to AutoGPU. (pytorch#2234)
* Add ATen overload to AutoGPU.
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Use new AutoGPU overload.
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
1 parent 0cd149f commit 8262920
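
In short: each call site previously computed the target device by hand with an isCuda()/get_device() ternary, and the new at::Tensor constructor overload moves that logic into AutoGPU itself. The before/after pattern, taken from the hunks below:

// Before: device-selection ternary repeated at every call site.
AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);

// After: the overload added in this commit encapsulates it.
AutoGPU guard(input);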

File tree

6 files changed: +19 −13 lines

torch/csrc/autograd/functions/accumulate_grad.cpp

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ auto AccumulateGrad::acc_inplace(std::shared_ptr<Variable>& grad,
                                  std::shared_ptr<Variable>& new_grad) -> void {
   auto& grad_data = grad->data;
   auto& new_grad_data = new_grad->data;
-  AutoGPU guard(grad_data.type().isCuda() ? grad_data.get_device() : -1);
+  AutoGPU guard(grad_data);
 
   if (grad_data.type().isSparse() && !new_grad_data.type().isSparse()) {
     grad->data = new_grad_data + grad_data;

torch/csrc/autograd/functions/basic_ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ auto Add::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("Add", inputs, 2);
   auto& input1 = inputs[0]->data;
   auto& input2 = inputs[1]->data;
-  AutoGPU guard(input1.type().isCuda() ? input1.get_device() : -1);
+  AutoGPU guard(input1);
 
   at::Tensor output;
   if (input1.type().isSparse()) {

torch/csrc/autograd/functions/batch_normalization.cpp

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
   auto& input = inputs[0];
   auto& weight = inputs[1];
   auto& bias = inputs[2];
-  AutoGPU guard(input->data.type().isCuda() ? input->data.get_device() : -1);
+  AutoGPU guard(input->data);
 
   auto num_features = input->data.sizes()[1];
   check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
@@ -117,7 +117,7 @@ auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_list {
   auto weight = weight_var ? weight_var->data : at::Tensor();
   auto bias = bias_var ? bias_var->data : at::Tensor();
 
-  AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);
+  AutoGPU guard(input);
 
   bool use_cudnn = false;
 #ifdef WITH_CUDNN

torch/csrc/autograd/functions/convolution.cpp

Lines changed: 2 additions & 2 deletions
@@ -148,7 +148,7 @@ auto ConvForward::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("ConvNd", inputs, 3, 2);
   if (is_padding_neg()) throw std::runtime_error("negative padding is not supported");
   if (is_output_padding_neg()) throw std::runtime_error("negative output_padding is not supported");
-  AutoGPU guard(inputs[0]->data.type().isCuda() ? inputs[0]->data.get_device() : -1);
+  AutoGPU guard(inputs[0]->data);
   auto input = inputs[0]->data.contiguous();
   auto weight = inputs[1]->data;
   auto bias = inputs[2] ? inputs[2]->data : at::Tensor();
@@ -249,7 +249,7 @@ auto ConvBackward::apply(const variable_list& grad_outputs) -> variable_list {
   auto weight = weight_var->data;
   auto bias = bias_var ? bias_var->data : at::Tensor();
 
-  AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);
+  AutoGPU guard(input);
 
   input = input.contiguous();
   auto grad_output = grad_outputs[0]->data.contiguous();

torch/csrc/autograd/functions/tensor.cpp

Lines changed: 7 additions & 7 deletions
@@ -14,7 +14,7 @@ auto Identity::apply(const variable_list& inputs) -> variable_list {
 auto Clone::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("Clone", inputs, 1);
   auto& input = inputs[0]->data;
-  AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);
+  AutoGPU guard(input);
 
   at::Tensor output = input.clone();
 
@@ -26,7 +26,7 @@ auto Clone::apply(const variable_list& inputs) -> variable_list {
 auto Contiguous::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("Contiguous", inputs, 1);
   auto& input = inputs[0]->data;
-  AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);
+  AutoGPU guard(input);
 
   at::Tensor output = input.contiguous();
 
@@ -39,7 +39,7 @@ auto Transpose::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("Transpose", inputs, 1);
 
   auto& input = inputs[0]->data;
-  AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);
+  AutoGPU guard(input);
 
   at::Tensor output = input.transpose(dim1, dim2);
 
@@ -52,7 +52,7 @@ auto View::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("View", inputs, 1);
 
   auto& input = inputs[0]->data;
-  AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);
+  AutoGPU guard(input);
 
   at::Tensor output = input.view(size);
 
@@ -65,7 +65,7 @@ auto Expand::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("Expand", inputs, 1);
 
   auto& input = inputs[0]->data;
-  AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);
+  AutoGPU guard(input);
 
   at::Tensor output = input.expand(size);
 
@@ -78,7 +78,7 @@ auto Narrow::apply(const variable_list& inputs) -> variable_list {
   check_input_variables("Narrow", inputs, 1);
 
   auto& input = inputs[0]->data;
-  AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);
+  AutoGPU guard(input);
 
   at::Tensor output = input.narrow(dim, start, size);
 
@@ -94,7 +94,7 @@ auto Cat::apply(const variable_list& inputs) -> variable_list {
   }
 
   auto& input = inputs[0]->data;
-  AutoGPU guard(input.type().isCuda() ? input.get_device() : -1);
+  AutoGPU guard(input);
 
   std::vector<at::Tensor> tensors(num_inputs);
   for (int i = 0; i < num_inputs; ++i) {

torch/csrc/utils/auto_gpu.h

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,8 @@
 #include <string>
 #include <stdexcept>
 
+#include <ATen/ATen.h>
+
 #ifdef WITH_CUDA
 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -15,6 +17,10 @@ struct AutoGPU {
     setDevice(device);
   }
 
+  explicit AutoGPU(const at::Tensor& t) {
+    setDevice(t.type().isCuda() ? t.get_device() : -1);
+  }
+
   ~AutoGPU() {
 #ifdef WITH_CUDA
     if (original_device != -1) {
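
For context, here is a minimal self-contained sketch of how the guard behaves with the new overload. This is an approximation reconstructed from the fragments visible above; the real header additionally wraps the CUDA calls in WITH_CUDA and checks their return codes, which is omitted here for brevity.

#include <ATen/ATen.h>
#include <cuda_runtime.h>

struct AutoGPU {
  explicit AutoGPU(int device = -1) {
    setDevice(device);
  }

  // Added in this commit: derive the device from a tensor directly.
  // CPU tensors yield -1, which leaves the current device untouched.
  explicit AutoGPU(const at::Tensor& t) {
    setDevice(t.type().isCuda() ? t.get_device() : -1);
  }

  ~AutoGPU() {
    if (original_device != -1) {
      cudaSetDevice(original_device);   // restore the caller's device
    }
  }

  void setDevice(int device) {
    if (device == -1) return;           // nothing to switch to
    if (original_device == -1) {
      cudaGetDevice(&original_device);  // remember the caller's device once
    }
    cudaSetDevice(device);
  }

  int original_device = -1;             // -1 means "never switched"
};

With this in place, an autograd function can write AutoGPU guard(input->data); and allocations in the rest of the scope land on input's device, with the previous device restored when guard goes out of scope.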
