@@ -106,18 +106,63 @@ template<typename algo_t>
106106struct algorithm_search {
107107};
108108
109+ cudnnStatus_t getWorkspaceSize (cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionFwdAlgo_t algo, size_t * sz){
110+ return cudnnGetConvolutionForwardWorkspaceSize (handle, conv.idesc .desc , conv.wdesc .desc , conv.cdesc .desc , conv.odesc .desc , algo, sz);
111+ }
112+ cudnnStatus_t getWorkspaceSize (cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdDataAlgo_t algo, size_t * sz){
113+ return cudnnGetConvolutionBackwardDataWorkspaceSize (handle, conv.wdesc .desc , conv.odesc .desc , conv.cdesc .desc , conv.idesc .desc , algo, sz);
114+ }
115+ cudnnStatus_t getWorkspaceSize (cudnnHandle_t handle, const Convolution& conv, cudnnConvolutionBwdFilterAlgo_t algo, size_t * sz){
116+ return cudnnGetConvolutionBackwardFilterWorkspaceSize (handle, conv.idesc .desc , conv.odesc .desc , conv.cdesc .desc , conv.wdesc .desc , algo, sz);
117+ }
118+
119+ template <typename algo_t >
120+ size_t getMaxWorkspaceSize (cudnnHandle_t handle, const Convolution& conv, algo_t *algo, int n_algo, THCState* state){
121+ size_t max_ws_size = 0 ;
122+ size_t max_block_size = 0 ;
123+ size_t total_gpu_mem = 0 ;
124+ size_t free_gpu_mem = 0 ;
125+
126+ THCudaCheck (THCudaMemGetInfoCached (state,&free_gpu_mem,&total_gpu_mem,&max_block_size));
127+
128+ for (int i=0 ; i<n_algo; i++) {
129+ cudnnStatus_t err;
130+ size_t sz;
131+ err = getWorkspaceSize (handle, conv, algo[i], &sz);
132+ if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size || sz > max_block_size) continue ;
133+ max_ws_size = sz;
134+ }
135+ return max_ws_size;
136+ }
137+
109138template <>
110139struct algorithm_search <cudnnConvolutionFwdAlgo_t> {
111140 static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
112141 static BenchmarkCache<cudnnConvolutionFwdAlgo_t>& cache () {
113142 return fwd_algos;
114143 }
115144
116- static cudnnConvolutionFwdAlgoPerf_t findAlgorithm (cudnnHandle_t handle, const Convolution& conv) {
145+ static cudnnConvolutionFwdAlgoPerf_t findAlgorithm (THCState* state, cudnnHandle_t handle, const Convolution& conv,
146+ void * in, void * out, void * wght) {
117147 int algoCount;
118148 cudnnConvolutionFwdAlgoPerf_t perfResults;
119- CHECK (cudnnFindConvolutionForwardAlgorithm (handle, conv.idesc .desc ,
120- conv.wdesc .desc , conv.cdesc .desc , conv.odesc .desc , 1 , &algoCount, &perfResults));
149+ cudnnConvolutionFwdAlgo_t algo[] = {
150+ CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
151+ CUDNN_CONVOLUTION_FWD_ALGO_FFT,
152+ CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
153+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
154+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
155+ CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
156+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
157+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED
158+ };
159+ size_t max_ws_size = getMaxWorkspaceSize<cudnnConvolutionFwdAlgo_t>(handle,conv,algo,sizeof (algo)/sizeof (algo[0 ]),state);
160+ Workspace ws (state, max_ws_size);
161+
162+ CHECK (cudnnFindConvolutionForwardAlgorithmEx (handle, conv.idesc .desc , in,
163+ conv.wdesc .desc , wght, conv.cdesc .desc , conv.odesc .desc , out, 1 , &algoCount,
164+ &perfResults, ws.data , ws.size ));
165+
121166 return perfResults;
122167 }
123168
@@ -140,11 +185,25 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgo_t> {
140185 return bwd_data_algos;
141186 }
142187
143- static cudnnConvolutionBwdDataAlgoPerf_t findAlgorithm (cudnnHandle_t handle, const Convolution& conv) {
188+ static cudnnConvolutionBwdDataAlgoPerf_t findAlgorithm (THCState* state,cudnnHandle_t handle, const Convolution& conv,
189+ void * in, void * out, void * wght) {
144190 int algoCount;
145191 cudnnConvolutionBwdDataAlgoPerf_t perfResults;
146- CHECK (cudnnFindConvolutionBackwardDataAlgorithm (handle, conv.wdesc .desc ,
147- conv.odesc .desc , conv.cdesc .desc , conv.idesc .desc , 1 , &algoCount, &perfResults));
192+ cudnnConvolutionBwdDataAlgo_t algo[] = {
193+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
194+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
195+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
196+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
197+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
198+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
199+ };
200+ size_t max_ws_size = getMaxWorkspaceSize<cudnnConvolutionBwdDataAlgo_t>(handle,conv,algo,sizeof (algo)/sizeof (algo[0 ]),state);
201+ Workspace ws (state, max_ws_size);
202+
203+ CHECK (cudnnFindConvolutionBackwardDataAlgorithmEx (handle, conv.wdesc .desc , wght,
204+ conv.odesc .desc , out, conv.cdesc .desc , conv.idesc .desc , in, 1 , &algoCount,
205+ &perfResults, ws.data , ws.size ));
206+
148207 return perfResults;
149208 }
150209
@@ -168,11 +227,25 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
168227 return bwd_filter_algos;
169228 }
170229
171- static cudnnConvolutionBwdFilterAlgoPerf_t findAlgorithm (cudnnHandle_t handle, const Convolution& conv) {
230+ static cudnnConvolutionBwdFilterAlgoPerf_t findAlgorithm (THCState* state, cudnnHandle_t handle, const Convolution& conv,
231+ void * in, void * out, void * wght) {
172232 int algoCount;
173233 cudnnConvolutionBwdFilterAlgoPerf_t perfResults;
174- CHECK (cudnnFindConvolutionBackwardFilterAlgorithm (handle, conv.idesc .desc ,
175- conv.odesc .desc , conv.cdesc .desc , conv.wdesc .desc , 1 , &algoCount, &perfResults));
234+ cudnnConvolutionBwdFilterAlgo_t algo[] = {
235+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
236+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
237+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
238+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
239+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
240+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED
241+ };
242+ size_t max_ws_size = getMaxWorkspaceSize<cudnnConvolutionBwdFilterAlgo_t>(handle,conv,algo,sizeof (algo)/sizeof (algo[0 ]),state);
243+ Workspace ws (state, max_ws_size);
244+
245+ CHECK (cudnnFindConvolutionBackwardFilterAlgorithmEx (handle, conv.idesc .desc , in,
246+ conv.odesc .desc , out, conv.cdesc .desc , conv.wdesc .desc , wght, 1 , &algoCount,
247+ &perfResults, ws.data , ws.size ));
248+
176249 return perfResults;
177250 }
178251
@@ -191,7 +264,7 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
191264template <typename algo_t >
192265void findAlgorithm (
193266 THCState* state, cudnnHandle_t handle, const Convolution& conv,
194- bool benchmark, algo_t * algo)
267+ bool benchmark, void * in, void * out, void * wght, algo_t * algo)
195268{
196269 using search = algorithm_search<algo_t >;
197270 auto & cache = search::cache ();
@@ -205,27 +278,28 @@ void findAlgorithm(
205278 return ;
206279 }
207280
208- // findAlgorithm may call cudaFree()
209- std::lock_guard<std::mutex> lock (*THCCachingAllocator_getCudaFreeMutex ());
210281 if (cache.find (conv.params , algo)) {
211282 // re-check cache since another thread may have benchmarked the algorithm
212283 return ;
213284 }
214- auto perfResults = search::findAlgorithm (handle, conv);
285+ auto perfResults = search::findAlgorithm (state, handle, conv, in, out, wght );
215286 if (perfResults.status == CUDNN_STATUS_SUCCESS) {
216287 *algo = perfResults.algo ;
217288 } else {
218289 *algo = search::DEFAULT_ALGO;
219290 }
220291 cache.insert (conv.params , *algo);
292+
293+ THCDeviceAllocator* allocator = THCCachingAllocator_get ();
294+ CUDA_CHECK (allocator->emptyCache (allocator->state ));
221295}
222296
223297template <typename algo_t >
224298Workspace chooseAlgorithm (
225299 THCState* state, cudnnHandle_t handle, const Convolution& conv,
226- bool benchmark, algo_t * algo)
300+ bool benchmark, void * in, void * out, void * wght, algo_t * algo)
227301{
228- findAlgorithm (state, handle, conv, benchmark, algo);
302+ findAlgorithm (state, handle, conv, benchmark, in, out, wght, algo);
229303
230304 using search = algorithm_search<algo_t >;
231305 size_t workspace_size;
@@ -307,7 +381,11 @@ void cudnn_convolution_forward(
307381 int groups = info->groups ;
308382
309383 cudnnConvolutionFwdAlgo_t fwdAlg;
310- Workspace workspace = chooseAlgorithm (state, handle, *info, benchmark, &fwdAlg);
384+ void * in = tensorPointer (dataType, input, 0 , groups, 1 );
385+ void * out = tensorPointer (dataType, output, 0 , groups, 1 );
386+ void * wght = tensorPointer (dataType, weight, 0 , groups, 0 );
387+
388+ Workspace workspace = chooseAlgorithm (state, handle, *info, benchmark, in, out, wght, &fwdAlg);
311389
312390 Constant one (dataType, 1 );
313391 Constant zero (dataType, 0 );
@@ -353,7 +431,10 @@ void cudnn_convolution_backward_data(
353431 int groups = info->params .groups ;
354432
355433 cudnnConvolutionBwdDataAlgo_t bwdDataAlg;
356- Workspace workspace = chooseAlgorithm (state, handle, *info, benchmark, &bwdDataAlg);
434+ void * in = tensorPointer (dataType, gradInput, 0 , groups, 1 );
435+ void * out = tensorPointer (dataType, gradOutput, 0 , groups, 1 );
436+ void * wght = tensorPointer (dataType, weight, 0 , groups, 0 );
437+ Workspace workspace = chooseAlgorithm (state, handle, *info, benchmark, in, out, wght, &bwdDataAlg);
357438
358439 Constant one (dataType, 1 );
359440 Constant zero (dataType, 0 );
@@ -378,7 +459,10 @@ void cudnn_convolution_backward_filter(
378459 int groups = info->params .groups ;
379460
380461 cudnnConvolutionBwdFilterAlgo_t bwdFilterAlg;
381- Workspace workspace = chooseAlgorithm (state, handle, *info, benchmark, &bwdFilterAlg);
462+ void * in = tensorPointer (dataType, input, 0 , groups, 1 );
463+ void * out = tensorPointer (dataType, gradOutput, 0 , groups, 1 );
464+ void * wght = tensorPointer (dataType, gradWeight, 0 , groups, 0 );
465+ Workspace workspace = chooseAlgorithm (state, handle, *info, benchmark, in, out, wght, &bwdFilterAlg);
382466
383467 Constant one (dataType, 1 );
384468 Constant zero (dataType, 0 );
0 commit comments