
Commit 700109e

zasdfgbnm authored and facebook-github-bot committed
set stream every time when we get a cuDNN handle (pytorch#31541)
Summary: cuDNN version of pytorch#31537. pytorch#31532 is a quick fix and this is the bigger change; it would deprecate pytorch#31532, but we could also merge pytorch#31532 first as a quick fix and then work on this later.

Pull Request resolved: pytorch#31541

Differential Revision: D19206753

Pulled By: ngimel

fbshipit-source-id: 3352f923d13a9baf0971f64f8b7ce03e9a8b42b1
1 parent b5bbec7 commit 700109e
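
In effect, the stream is now bound inside getCudnnHandle() instead of at every cuDNN call site. Below is a minimal, hypothetical sketch of that pattern, not the actual ATen code: the real implementation keeps handles in a per-thread, per-device pool (see the aten/src/ATen/cudnn/Handle.cpp hunk below), while this sketch uses a single thread_local handle and omits error checking; cudnnSomeOp is a placeholder for any cuDNN API call.

#include <cudnn.h>
#include <c10/cuda/CUDAStream.h>

namespace sketch {

// Simplified stand-in for ATen's pooled getCudnnHandle(): one handle per
// thread instead of the real per-thread, per-device pool.
inline cudnnHandle_t getCudnnHandle() {
  static thread_local cudnnHandle_t handle = [] {
    cudnnHandle_t h = nullptr;
    cudnnCreate(&h);  // status checking omitted in this sketch
    return h;
  }();
  // The point of this commit: re-bind the handle to the caller's current
  // CUDA stream every time the handle is fetched, so individual call sites
  // no longer need setCuDNNStreamToCurrent().
  cudnnSetStream(handle, c10::cuda::getCurrentCUDAStream());
  return handle;
}

} // namespace sketch

// A call site then shrinks from
//     setCuDNNStreamToCurrent();
//     AT_CUDNN_CHECK(cudnnSomeOp(getCudnnHandle(), ...));
// to just
//     AT_CUDNN_CHECK(cudnnSomeOp(getCudnnHandle(), ...));
// (cudnnSomeOp is a placeholder, not a real cuDNN function.)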

9 files changed (+4 lines, -28 lines)

aten/src/ATen/cudnn/Descriptors.h

Lines changed: 0 additions & 2 deletions
@@ -191,7 +191,6 @@ struct TORCH_CUDA_API DropoutDescriptor
     AT_ASSERT(options.device().type() == kCUDA);
     AT_ASSERT(options.dtype() == kByte);
     state = at::empty({static_cast<int64_t>(state_size)}, options);
-    setCuDNNStreamToCurrent();
     AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
   }

@@ -202,7 +201,6 @@ struct TORCH_CUDA_API DropoutDescriptor
     void *state_ptr = state.data_ptr();
     size_t state_size = state.size(0);
     // NB: The seed doesn't actually matter, so we give a dummy value
-    setCuDNNStreamToCurrent();
     AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
   }

aten/src/ATen/cudnn/Handle.cpp

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
 #include <ATen/cudnn/Handle.h>
 #include <ATen/cuda/detail/DeviceThreadHandles.h>
+#include <c10/cuda/CUDAStream.h>

 namespace at { namespace native {
 namespace {
@@ -40,7 +41,9 @@ cudnnHandle_t getCudnnHandle()
   if (!myPoolWindow)
     myPoolWindow.reset(pool.newPoolWindow());

-  return myPoolWindow->reserve(device);
+  auto handle = myPoolWindow->reserve(device);
+  AT_CUDNN_CHECK(cudnnSetStream(handle, c10::cuda::getCurrentCUDAStream()));
+  return handle;
 }

 }} // namespace at::native

aten/src/ATen/cudnn/Utils.h

Lines changed: 0 additions & 5 deletions
@@ -8,11 +8,6 @@

 namespace at { namespace native {

-inline void setCuDNNStreamToCurrent() {
-  // TODO: Should getCurrentStream be a method on Context?
-  AT_CUDNN_CHECK(cudnnSetStream(getCudnnHandle(), at::cuda::getCurrentCUDAStream()));
-}
-
 // cuDNN has a buggy check for tensor being contiguous (that is, it does
 // not ignore stride for dimension that is equal to 0). This function
 // makes tensors which have zero stride contiguous, by setting the

aten/src/ATen/native/cudnn/AffineGridGenerator.cpp

Lines changed: 0 additions & 4 deletions
@@ -52,8 +52,6 @@ Tensor cudnn_affine_grid_generator_forward(
     const Tensor& theta_t,
     int64_t N, int64_t C, int64_t H, int64_t W)
 {
-  setCuDNNStreamToCurrent();
-
   TensorArg theta{ theta_t.contiguous(), "theta", 1 };
   CheckedFrom c = "cudnn_affine_grid_generator_forward";
   checkContiguous(c, theta);
@@ -75,8 +73,6 @@ Tensor cudnn_affine_grid_generator_backward(
     const Tensor& grad_grid_t,
     int64_t N, int64_t C, int64_t H, int64_t W)
 {
-  setCuDNNStreamToCurrent();
-
   TensorArg grad_grid{ grad_grid_t.contiguous(), "grad_grid", 1 };
   CheckedFrom c = "cudnn_affine_grid_generator_backward";
   checkContiguous(c, grad_grid);

aten/src/ATen/native/cudnn/BatchNorm.cpp

Lines changed: 0 additions & 2 deletions
@@ -60,7 +60,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
             running_mean{ running_mean_t, "running_mean", 4 },
             running_var{ running_var_t, "running_var", 5 };
   CheckedFrom c = "cudnn_batch_norm";
-  setCuDNNStreamToCurrent();

   checkAllDefined(c, {input, weight, bias});
   if (!training) {
@@ -233,7 +232,6 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
             save_var{ save_var_t, "save_var", 5 },
             reserve{ reserveSpace, "reserve_space", 6 };
   CheckedFrom c = "cudnn_batch_norm_backward";
-  setCuDNNStreamToCurrent();

   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
   checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});

aten/src/ATen/native/cudnn/Conv.cpp

Lines changed: 0 additions & 7 deletions
@@ -766,7 +766,6 @@ void cudnn_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const T
 // responsibility:
 //   - Things that happen in at::Tensor
 //   - TensorArg allocation
-//   - setCuDNNStreamToCurrent
 //   - Things that happen in TensorArg
 //   - Check arguments (type, GPU, shape)
 //
@@ -918,7 +917,6 @@ Tensor cudnn_convolution(
   TensorArg input { input_t, "input", 1 },
             weight { weight_t, "weight", 2 },
             bias { bias_t, "bias", 3 };
-  setCuDNNStreamToCurrent();
   CheckedFrom c = "cudnn_convolution";
   auto output_t = cudnn_convolution_forward(
     c, input, weight, padding, stride, dilation, groups, benchmark, deterministic);
@@ -937,7 +935,6 @@ Tensor cudnn_convolution_transpose_backward_input(
 {
   TensorArg grad_output { grad_output_t, "grad_output", 1 },
             weight { weight_t, "weight", 2 };
-  setCuDNNStreamToCurrent();
   return cudnn_convolution_forward(
     "cudnn_convolution_transpose_backward_input",
     grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
@@ -1062,7 +1059,6 @@ Tensor cudnn_convolution_backward_input(
 {
   TensorArg grad_output{ grad_output_t, "grad_output", 1 },
             weight{ weight_t, "weight", 2 };
-  setCuDNNStreamToCurrent();
   return cudnn_convolution_backward_input(
     "cudnn_convolution_backward_input",
     input_size, grad_output, weight,
@@ -1192,7 +1188,6 @@ Tensor cudnn_convolution_backward_weight(
 {
   TensorArg grad_output{ grad_output_t, "grad_output", 1 },
             input{ input_t, "input", 2 };
-  setCuDNNStreamToCurrent();
   return cudnn_convolution_backward_weight(
     "cudnn_convolution_backward_weight",
     weight_size, grad_output, input,
@@ -1208,7 +1203,6 @@ Tensor cudnn_convolution_transpose_backward_weight(
 {
   TensorArg grad_output{ grad_output_t, "grad_output", 1 },
             input{ input_t, "input", 2 };
-  setCuDNNStreamToCurrent();
   return cudnn_convolution_backward_weight(
     "cudnn_convolution_backward_weight",
     weight_size, input, grad_output,
@@ -1225,7 +1219,6 @@ Tensor cudnn_convolution_backward_bias(
     const Tensor& grad_output_t)
 {
   TensorArg grad_output{ grad_output_t, "grad_output", 1 };
-  setCuDNNStreamToCurrent();

   auto grad_bias_t = at::empty(
     { grad_output->size(output_channels_dim) }, grad_output->options());

aten/src/ATen/native/cudnn/GridSampler.cpp

Lines changed: 0 additions & 2 deletions
@@ -69,7 +69,6 @@ Tensor cudnn_grid_sampler_forward(
   TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
             grid{ grid_t.contiguous(), "grid", 2 };
   CheckedFrom c = "cudnn_grid_sampler_forward";
-  setCuDNNStreamToCurrent();
   checkAllSameGPU(c, {input, grid});
   checkAllSameType(c, {input, grid});
   checkGridSize(c, grid, input);
@@ -108,7 +107,6 @@ std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
             grid{ grid_t.contiguous(), "grid", 2 },
             grad_output{ contiguousIfZeroInStrides(grad_output_t), "grad_output", 3 };
   CheckedFrom c = "cudnn_grid_sampler_backward";
-  setCuDNNStreamToCurrent();
   checkAllSameGPU(c, {input, grad_output, grid});
   checkGridSize(c, grid, input);
   checkDim(c, input, 4);

aten/src/ATen/native/cudnn/LossCTC.cpp

Lines changed: 0 additions & 1 deletion
@@ -86,7 +86,6 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs_t, const Tens
   std::vector<int> input_lengths(input_lengths_.begin(), input_lengths_.end());
   std::vector<int> target_lengths(target_lengths_.begin(), target_lengths_.end());

-  setCuDNNStreamToCurrent();
   TORCH_CHECK(BLANK == 0, "blank must be label 0 for cudnn_ctc_loss");
   // checked in dispatch:
   // assert other conditions for cudnnCTCLoss: all label lengths <= 256

aten/src/ATen/native/cudnn/RNN.cpp

Lines changed: 0 additions & 4 deletions
@@ -778,7 +778,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
         &reserve_size
         ));
     reserve = at::empty(reserve_size, input.options().dtype(kByte));
-    setCuDNNStreamToCurrent();
     AT_CUDNN_CHECK(cudnnRNNForwardTraining(
         handle,
         descs.rnn_desc.desc(),
@@ -795,7 +794,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
         ));
   } else { // inference
     reserve = at::empty({0}, input.options().dtype(kByte));
-    setCuDNNStreamToCurrent();
     AT_CUDNN_CHECK(cudnnRNNForwardInference(
         handle,
         descs.rnn_desc.desc(),
@@ -914,7 +912,6 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
       ));
   // TODO: put this in the correct device???
   Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
-  setCuDNNStreamToCurrent();
   AT_CUDNN_CHECK(cudnnRNNBackwardData(
       handle,
       descs.rnn_desc.desc(),
@@ -1018,7 +1015,6 @@ std::vector<Tensor> _cudnn_rnn_backward_weight(
       &workspace_size
       ));
   Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
-  setCuDNNStreamToCurrent();
   AT_CUDNN_CHECK(cudnnRNNBackwardWeights(
       handle,
       descs.rnn_desc.desc(),
