
Commit 568c5c9

aromnvidia authored and soumith committed
substitute cudnnFind* functions with cudnnFind*Ex
1 parent 501467d commit 568c5c9
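
Background on the change: the `*Ex` variants of the cuDNN algorithm-search functions benchmark against caller-supplied device buffers and an explicit, pre-allocated workspace instead of allocating scratch memory internally. That is why the diff below adds a `getMaxWorkspaceSize` helper (to size the largest workspace that still fits in the caching allocator's largest free block) and threads raw tensor pointers (`in`, `out`, `wght`) through `chooseAlgorithm`/`findAlgorithm`. A minimal sketch of the two call shapes for the forward case, using illustrative variable names rather than the exact ones in Conv.cpp:

#include <cstddef>
#include <cudnn.h>

// Old shape: cuDNN allocates whatever scratch memory it needs during the search.
cudnnConvolutionFwdAlgoPerf_t findFwdAlgo(cudnnHandle_t handle,
                                          cudnnTensorDescriptor_t xDesc,
                                          cudnnFilterDescriptor_t wDesc,
                                          cudnnConvolutionDescriptor_t convDesc,
                                          cudnnTensorDescriptor_t yDesc) {
  cudnnConvolutionFwdAlgoPerf_t perf;
  int returned;
  cudnnFindConvolutionForwardAlgorithm(handle, xDesc, wDesc, convDesc, yDesc,
                                       /*requestedAlgoCount=*/1, &returned, &perf);
  return perf;
}

// New shape: the caller passes real device buffers plus an explicit workspace,
// so the benchmarking runs inside memory the framework already controls.
cudnnConvolutionFwdAlgoPerf_t findFwdAlgoEx(cudnnHandle_t handle,
                                            cudnnTensorDescriptor_t xDesc, void* x,
                                            cudnnFilterDescriptor_t wDesc, void* w,
                                            cudnnConvolutionDescriptor_t convDesc,
                                            cudnnTensorDescriptor_t yDesc, void* y,
                                            void* workspace, size_t workspaceSize) {
  cudnnConvolutionFwdAlgoPerf_t perf;
  int returned;
  cudnnFindConvolutionForwardAlgorithmEx(handle, xDesc, x, wDesc, w, convDesc,
                                         yDesc, y, /*requestedAlgoCount=*/1,
                                         &returned, &perf, workspace, workspaceSize);
  return perf;
}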

1 file changed: +102 −18 lines

torch/csrc/cudnn/Conv.cpp

Lines changed: 102 additions & 18 deletions
@@ -106,18 +106,63 @@ template<typename algo_t>
 struct algorithm_search {
 };
 
+cudnnStatus_t getWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionFwdAlgo_t algo, size_t* sz){
+  return cudnnGetConvolutionForwardWorkspaceSize(handle, conv.idesc.desc, conv.wdesc.desc, conv.cdesc.desc, conv.odesc.desc, algo, sz);
+}
+cudnnStatus_t getWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdDataAlgo_t algo, size_t* sz){
+  return cudnnGetConvolutionBackwardDataWorkspaceSize(handle, conv.wdesc.desc, conv.odesc.desc, conv.cdesc.desc, conv.idesc.desc, algo, sz);
+}
+cudnnStatus_t getWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz){
+  return cudnnGetConvolutionBackwardFilterWorkspaceSize(handle, conv.idesc.desc, conv.odesc.desc, conv.cdesc.desc, conv.wdesc.desc, algo, sz);
+}
+
+template<typename algo_t>
+size_t getMaxWorkspaceSize(cudnnHandle_t handle, const Convolution& conv, algo_t *algo, int n_algo, THCState* state){
+  size_t max_ws_size = 0;
+  size_t max_block_size = 0;
+  size_t total_gpu_mem = 0;
+  size_t free_gpu_mem = 0;
+
+  THCudaCheck(THCudaMemGetInfoCached(state,&free_gpu_mem,&total_gpu_mem,&max_block_size));
+
+  for(int i=0; i<n_algo; i++) {
+    cudnnStatus_t err;
+    size_t sz;
+    err = getWorkspaceSize(handle, conv, algo[i], &sz);
+    if(CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > max_block_size) continue;
+    max_ws_size = sz;
+  }
+  return max_ws_size;
+}
+
 template<>
 struct algorithm_search<cudnnConvolutionFwdAlgo_t> {
   static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
   static BenchmarkCache<cudnnConvolutionFwdAlgo_t>& cache() {
     return fwd_algos;
   }
 
-  static cudnnConvolutionFwdAlgoPerf_t findAlgorithm(cudnnHandle_t handle, const Convolution& conv) {
+  static cudnnConvolutionFwdAlgoPerf_t findAlgorithm(THCState* state, cudnnHandle_t handle, const Convolution& conv,
+      void* in, void* out, void* wght) {
     int algoCount;
     cudnnConvolutionFwdAlgoPerf_t perfResults;
-    CHECK(cudnnFindConvolutionForwardAlgorithm(handle, conv.idesc.desc,
-        conv.wdesc.desc, conv.cdesc.desc, conv.odesc.desc, 1, &algoCount, &perfResults));
+    cudnnConvolutionFwdAlgo_t algo[] = {
+      CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
+      CUDNN_CONVOLUTION_FWD_ALGO_FFT,
+      CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
+      CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
+      CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
+      CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
+      CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
+      CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED
+    };
+    size_t max_ws_size = getMaxWorkspaceSize<cudnnConvolutionFwdAlgo_t>(handle,conv,algo,sizeof(algo)/sizeof(algo[0]),state);
+    Workspace ws(state, max_ws_size);
+
+    CHECK(cudnnFindConvolutionForwardAlgorithmEx(handle, conv.idesc.desc, in,
+        conv.wdesc.desc, wght, conv.cdesc.desc, conv.odesc.desc, out, 1, &algoCount,
+        &perfResults, ws.data, ws.size));
+
     return perfResults;
   }
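
Note on the helper added above: `getMaxWorkspaceSize` queries the workspace requirement of every candidate algorithm and keeps the largest one that is non-zero and still fits inside the largest free block reported by `THCudaMemGetInfoCached`, so the subsequent `cudnnFind*Ex` call benchmarks with the biggest workspace the caching allocator can actually hand out. A self-contained sketch of just that selection rule, with hypothetical byte counts in place of the cuDNN queries:

#include <cstddef>
#include <cstdio>

// Keep the largest candidate workspace that is usable (non-zero) and that
// still fits in the largest free block the allocator can provide.
static size_t pickWorkspaceSize(const size_t* candidates, int n, size_t max_block_size) {
  size_t best = 0;
  for (int i = 0; i < n; ++i) {
    size_t sz = candidates[i];
    if (sz == 0 || sz < best || sz > max_block_size) continue;
    best = sz;
  }
  return best;
}

int main() {
  // Hypothetical per-algorithm workspace requirements, in bytes.
  size_t sizes[] = {0, 64u << 20, 512u << 20, 8u << 20};
  size_t cap = 256u << 20;  // pretend the largest free block is 256 MiB
  printf("picked %zu bytes\n", pickWorkspaceSize(sizes, 4, cap));  // -> 64 MiB
  return 0;
}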

@@ -140,11 +185,25 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgo_t> {
     return bwd_data_algos;
   }
 
-  static cudnnConvolutionBwdDataAlgoPerf_t findAlgorithm(cudnnHandle_t handle, const Convolution& conv) {
+  static cudnnConvolutionBwdDataAlgoPerf_t findAlgorithm(THCState* state, cudnnHandle_t handle, const Convolution& conv,
+      void* in, void* out, void* wght) {
     int algoCount;
     cudnnConvolutionBwdDataAlgoPerf_t perfResults;
-    CHECK(cudnnFindConvolutionBackwardDataAlgorithm(handle, conv.wdesc.desc,
-        conv.odesc.desc, conv.cdesc.desc, conv.idesc.desc, 1, &algoCount, &perfResults));
+    cudnnConvolutionBwdDataAlgo_t algo[] = {
+      CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
+      CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
+      CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
+      CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
+      CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
+      CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
+    };
+    size_t max_ws_size = getMaxWorkspaceSize<cudnnConvolutionBwdDataAlgo_t>(handle,conv,algo,sizeof(algo)/sizeof(algo[0]),state);
+    Workspace ws(state, max_ws_size);
+
+    CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(handle, conv.wdesc.desc, wght,
+        conv.odesc.desc, out, conv.cdesc.desc, conv.idesc.desc, in, 1, &algoCount,
+        &perfResults, ws.data, ws.size));
+
     return perfResults;
   }

@@ -168,11 +227,25 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
     return bwd_filter_algos;
   }
 
-  static cudnnConvolutionBwdFilterAlgoPerf_t findAlgorithm(cudnnHandle_t handle, const Convolution& conv) {
+  static cudnnConvolutionBwdFilterAlgoPerf_t findAlgorithm(THCState* state, cudnnHandle_t handle, const Convolution& conv,
+      void* in, void* out, void* wght) {
     int algoCount;
     cudnnConvolutionBwdFilterAlgoPerf_t perfResults;
-    CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(handle, conv.idesc.desc,
-        conv.odesc.desc, conv.cdesc.desc, conv.wdesc.desc, 1, &algoCount, &perfResults));
+    cudnnConvolutionBwdFilterAlgo_t algo[] = {
+      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
+      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
+      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
+      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
+      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
+      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED
+    };
+    size_t max_ws_size = getMaxWorkspaceSize<cudnnConvolutionBwdFilterAlgo_t>(handle,conv,algo,sizeof(algo)/sizeof(algo[0]),state);
+    Workspace ws(state, max_ws_size);
+
+    CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(handle, conv.idesc.desc, in,
+        conv.odesc.desc, out, conv.cdesc.desc, conv.wdesc.desc, wght, 1, &algoCount,
+        &perfResults, ws.data, ws.size));
+
     return perfResults;
   }

@@ -191,7 +264,7 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
 template<typename algo_t>
 void findAlgorithm(
     THCState* state, cudnnHandle_t handle, const Convolution& conv,
-    bool benchmark, algo_t* algo)
+    bool benchmark, void* in, void* out, void* wght, algo_t* algo)
 {
   using search = algorithm_search<algo_t>;
   auto& cache = search::cache();
@@ -205,27 +278,28 @@ void findAlgorithm(
     return;
   }
 
-  // findAlgorithm may call cudaFree()
-  std::lock_guard<std::mutex> lock(*THCCachingAllocator_getCudaFreeMutex());
   if (cache.find(conv.params, algo)) {
     // re-check cache since another thread may have benchmarked the algorithm
     return;
   }
-  auto perfResults = search::findAlgorithm(handle, conv);
+  auto perfResults = search::findAlgorithm(state, handle, conv, in, out, wght);
   if (perfResults.status == CUDNN_STATUS_SUCCESS) {
     *algo = perfResults.algo;
   } else {
     *algo = search::DEFAULT_ALGO;
   }
   cache.insert(conv.params, *algo);
+
+  THCDeviceAllocator* allocator = THCCachingAllocator_get();
+  CUDA_CHECK(allocator->emptyCache(allocator->state));
 }
 
 template<typename algo_t>
 Workspace chooseAlgorithm(
     THCState* state, cudnnHandle_t handle, const Convolution& conv,
-    bool benchmark, algo_t* algo)
+    bool benchmark, void* in, void* out, void* wght, algo_t* algo)
 {
-  findAlgorithm(state, handle, conv, benchmark, algo);
+  findAlgorithm(state, handle, conv, benchmark, in, out, wght, algo);
 
   using search = algorithm_search<algo_t>;
   size_t workspace_size;
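
A side effect visible in the hunk above: the old code serialized benchmarking behind the caching allocator's cudaFree mutex because the non-Ex find calls allocate and free device memory internally. With the Ex variants the workspace comes from the framework's own `Workspace`, so the lock is dropped, and `findAlgorithm` instead calls `emptyCache` on the caching allocator after benchmarking, presumably to return the large blocks grabbed during the search to the driver.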
@@ -307,7 +381,11 @@ void cudnn_convolution_forward(
   int groups = info->groups;
 
   cudnnConvolutionFwdAlgo_t fwdAlg;
-  Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, &fwdAlg);
+  void* in = tensorPointer(dataType, input, 0, groups, 1);
+  void* out = tensorPointer(dataType, output, 0, groups, 1);
+  void* wght = tensorPointer(dataType, weight, 0, groups, 0);
+
+  Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, in, out, wght, &fwdAlg);
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
@@ -353,7 +431,10 @@ void cudnn_convolution_backward_data(
   int groups = info->params.groups;
 
   cudnnConvolutionBwdDataAlgo_t bwdDataAlg;
-  Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, &bwdDataAlg);
+  void* in = tensorPointer(dataType, gradInput, 0, groups, 1);
+  void* out = tensorPointer(dataType, gradOutput, 0, groups, 1);
+  void* wght = tensorPointer(dataType, weight, 0, groups, 0);
+  Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, in, out, wght, &bwdDataAlg);
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
@@ -378,7 +459,10 @@ void cudnn_convolution_backward_filter(
   int groups = info->params.groups;
 
   cudnnConvolutionBwdFilterAlgo_t bwdFilterAlg;
-  Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, &bwdFilterAlg);
+  void* in = tensorPointer(dataType, input, 0, groups, 1);
+  void* out = tensorPointer(dataType, gradOutput, 0, groups, 1);
+  void* wght = tensorPointer(dataType, gradWeight, 0, groups, 0);
+  Workspace workspace = chooseAlgorithm(state, handle, *info, benchmark, in, out, wght, &bwdFilterAlg);
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
