@@ -62,6 +62,24 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
6262 return is_transferred;
6363}
6464
65+ void DataTranferHelper::RunAndConstructShareNode(
66+ const std::string& src_var_name, const std::string& dst_var_name,
67+ std::vector<OpFuncNode>* op_func_nodes) {
68+ VariableNameMap in_name_map = {{"X", {src_var_name}}};
69+ VariableNameMap out_name_map = {{"Out", {dst_var_name}}};
70+ AttributeMap attr_map;
71+
72+ std::string op_type("share_data");
73+ auto& op_info = OpInfoMap::Instance().Get(op_type);
74+ auto op = std::shared_ptr<OperatorBase>(
75+ op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));
76+
77+ VLOG(3) << string::Sprintf("Insert %s with %s -> %s.", op_type, src_var_name,
78+ dst_var_name);
79+
80+ RunAndConstructOpFuncNode(op, src_var_name, dst_var_name, op_func_nodes);
81+ }
82+
82+
6583void DataTranferHelper::RunAndConstructOpFuncNode(
6684 const std::shared_ptr<OperatorBase>& op, const std::string& var_name,
6785 const std::string& new_var_name,
@@ -133,7 +151,7 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
133151 VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
134152 AttributeMap attr_map = {{"dst_layout", static_cast<int>(out_layout)}};
135153
136- // 3. Create transfer_op
154+ // 3. Create transfer_layout_op
137155 std::string op_type("transfer_layout");
138156 auto& op_info = OpInfoMap::Instance().Get(op_type);
139157 auto op = std::shared_ptr<OperatorBase>(
@@ -154,9 +172,10 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
154172 *new_var_name =
155173 var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1);
156174 auto* ptr = local_scope->Var(new_var_name);
157-
175+ var_scope->SetVarDesc(var_name, nullptr);
158176 auto var_type = var_scope->Var(var_name)->Type();
159177 InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
178+
160179 VLOG(3) << "Create Variable " << *new_var_name
161180 << " locally, which pointer is " << ptr << " Variable Type "
162181 << var_type;
@@ -171,7 +190,7 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
171190 // NOTE(Aurelius84): In which case use_mkldnn = true?
172191 attr_map["use_mkldnn"] = false;
173192
174- // 3. Create transfer_op
193+ // 3. Create transfer_dtype_op
175194 std::string op_type("transfer_dtype");
176195 auto& op_info = OpInfoMap::Instance().Get(op_type);
177196 auto op = std::shared_ptr<OperatorBase>(
@@ -209,7 +228,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
209228 : platform::is_gpu_place(dst_place) ? 1 : -1;
210229 AttributeMap attr_map = {{"dst_place_type", dst_place_type}};
211230
212- // 3. Create transfer_op
231+ // 3. Create memcpy_d2h_op or memcpy_h2d_op
213232 std::string op_type = get_memcpy_type(src_place, dst_place);
214233 auto& op_info = OpInfoMap::Instance().Get(op_type);
215234 auto op = std::shared_ptr<OperatorBase>(
@@ -303,6 +322,95 @@ std::string get_memcpy_type(const platform::Place& src_place,
303322 }
304323}
305324
325+ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
326+ const platform::Place& place,
327+ const VariableNameMap& out_names,
328+ VariableValueMap* out_vars,
329+ VariableScope* var_scope,
330+ std::vector<OpFuncNode>* op_func_nodes,
331+ framework::Scope* local_scope) {
332+ DataTranferHelper data_transfer_helper(place, var_scope);
333+ for (auto& var_name_item : out_names) {
334+ std::vector<Variable*>& vars = out_vars->at(var_name_item.first);
335+ for (size_t i = 0; i < var_name_item.second.size(); ++i) {
336+ // 1. find grad_var & check whether is complex tensor
337+ auto var_name = var_name_item.second[i];
338+ auto orig_var_name = framework::GradOriginalVarName(var_name);
339+ // only focus on gradient var
340+ if (var_name == orig_var_name) {
341+ VLOG(3) << "skip " << var_name << " with same name as "
342+ << orig_var_name;
343+ continue;
344+ }
345+ auto* grad_var = vars[i];
346+ // skip nullptr var
347+ if (grad_var == nullptr) {
348+ VLOG(3) << "skip grad_var with nullptr";
349+ continue;
350+ }
351+ // don't process LoDTensorArray temporarily,
352+ // add support if necessary for complex number calculations in the future
353+ if (!framework::VarIsTensor(*grad_var)) {
354+ VLOG(3) << "skip grad_var with LoDTensorArray type";
355+ continue;
356+ }
357+ auto* grad_tensor =
358+ framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(grad_var);
359+ // skip nullptr tensor
360+ if (grad_tensor == nullptr || !grad_tensor->IsInitialized()) {
361+ VLOG(3) << "skip with grad_tensor not IsInitialized";
362+ continue;
363+ }
364+ // only focus on complex dtype now
365+ auto src_type = grad_tensor->type();
366+ if (!framework::IsComplexType(src_type)) {
367+ VLOG(3) << "skip grad_tensor with not complexType";
368+ continue;
369+ }
370+
371+ // 2. find forward var & check whether need to cast
372+ auto* var = var_scope->FindVar(orig_var_name);
373+ // if forward var not exists, do nothing
374+ if (var == nullptr) {
375+ VLOG(3) << "skip " << orig_var_name << " with not found in var_scope";
376+ continue;
377+ }
378+ if (!framework::VarIsTensor(*var)) {
379+ VLOG(3) << "skip " << orig_var_name << " with LoDTensorArray.";
380+ continue;
381+ }
382+ const auto* tensor =
383+ framework::GetLoDTensorOrSelectedRowsValueFromVar(*var);
384+ PADDLE_ENFORCE_NOT_NULL(
385+ tensor,
386+ platform::errors::Unavailable(
387+ "Forward tensor is nullptr when handle complex data to real."));
388+ // only need record type, the allocation may have been released
389+ auto dst_type = tensor->saved_type();
390+ // only focus on real dtype and need casting
391+ if (framework::IsComplexType(dst_type)) {
392+ continue;
393+ }
394+
395+ // 3. cast complex grad to real grad in place
396+ VLOG(3) << "Transform " << framework::DataTypeToString(src_type)
397+ << " var `" << var_name << "` to "
398+ << framework::DataTypeToString(dst_type)
399+ << " real var in static graph.";
400+
401+ // NOTE(Aurelius84): Consider to define a complex2real op to deal this
402+ // case.
403+ std::string new_var_name;
404+ auto op = TransferDtype(var_name, &new_var_name, src_type, dst_type,
405+ var_scope, local_scope);
406+ data_transfer_helper.RunAndConstructOpFuncNode(op, var_name, new_var_name,
407+ op_func_nodes);
408+ data_transfer_helper.RunAndConstructShareNode(new_var_name, var_name,
409+ op_func_nodes);
410+ }
411+ }
412+ }
413+
413+
306414} // namespace interpreter
307415} // namespace framework
308416} // namespace paddle
0 commit comments