@@ -238,7 +238,8 @@ def parse_kernel(self, kernel_config):
238238 'param' : None ,
239239 'backend' : None ,
240240 'layout' : None ,
241- 'data_type' : None
241+ 'data_type' : None ,
242+ 'use_cudnn' : 'false'
242243 }
243244 if 'backend' in kernel_config and len (kernel_config ['backend' ]) > 0 :
244245 kernel ['backend' ] = kernel_config ['backend' ]
@@ -248,6 +249,10 @@ def parse_kernel(self, kernel_config):
248249 kernel ['data_type' ] = kernel_config ['data_type' ]
249250 if 'param' in kernel_config :
250251 kernel ['param' ] = kernel_config ['param' ]
252+ if 'use_cudnn' in kernel_config :
253+ kernel ['use_cudnn' ] = kernel_config ['use_cudnn' ]
254+ if isinstance (kernel ['use_cudnn' ], bool ):
255+ kernel ['use_cudnn' ] = str (kernel ['use_cudnn' ]).lower ()
251256 kernel ['func' ] = [
252257 kernel_fn .strip () for kernel_fn in kernel_config ['func' ].split (',' )
253258 ]
@@ -709,10 +714,12 @@ def gen_dense_tensor_kernel_code(self, code_indent, inplace_flag=False):
709714 outputs_args , kernel_output_names , output_create = self .gene_output (
710715 self .outputs ['types' ], 'SetKernelOutput' , code_indent , inplace_flag )
711716 api_func_name = self .get_api_func_name () + ('_' if inplace_flag else '' )
717+ cudnn_args = '' if self .kernel [
718+ 'use_cudnn' ] == 'false' else ', ' + self .kernel ['use_cudnn' ]
712719 return f"""
713720{ code_indent } VLOG(6) << "{ self .api } API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
714721{ code_indent } const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
715- { code_indent } "{ self .kernel ['func' ][0 ]} ", {{kernel_backend, kernel_layout, kernel_data_type}});
722+ { code_indent } "{ self .kernel ['func' ][0 ]} ", {{kernel_backend, kernel_layout, kernel_data_type}}{ cudnn_args } );
716723{ code_indent } VLOG(6) << "{ self .api } API kernel: " << kernel;
717724
718725{ code_indent } auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
0 commit comments