
Commit 9aadbbe

merge from develop
2 parents 965caf1 + 8b87d5e commit 9aadbbe


41 files changed: +602, -62 lines

paddle/fluid/framework/new_executor/data_transfer.cc

Lines changed: 112 additions & 4 deletions
@@ -62,6 +62,24 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
   return is_transferred;
 }
 
+void DataTranferHelper::RunAndConstructShareNode(
+    const std::string& src_var_name, const std::string& dst_var_name,
+    std::vector<OpFuncNode>* op_func_nodes) {
+  VariableNameMap in_name_map = {{"X", {src_var_name}}};
+  VariableNameMap out_name_map = {{"Out", {dst_var_name}}};
+  AttributeMap attr_map;
+
+  std::string op_type("share_data");
+  auto& op_info = OpInfoMap::Instance().Get(op_type);
+  auto op = std::shared_ptr<OperatorBase>(
+      op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));
+
+  VLOG(3) << string::Sprintf("Insert %s with %s -> %s.", op_type, src_var_name,
+                             dst_var_name);
+
+  RunAndConstructOpFuncNode(op, src_var_name, dst_var_name, op_func_nodes);
+}
+
 void DataTranferHelper::RunAndConstructOpFuncNode(
     const std::shared_ptr<OperatorBase>& op, const std::string& var_name,
     const std::string& new_var_name,
@@ -133,7 +151,7 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
   VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
   AttributeMap attr_map = {{"dst_layout", static_cast<int>(out_layout)}};
 
-  // 3. Create transfer_op
+  // 3. Create transfer_layout_op
   std::string op_type("transfer_layout");
   auto& op_info = OpInfoMap::Instance().Get(op_type);
   auto op = std::shared_ptr<OperatorBase>(
@@ -154,9 +172,10 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
   *new_var_name =
       var_name + "_dtype_" + std::to_string(var_scope->VarSize() + 1);
   auto* ptr = local_scope->Var(new_var_name);
-
+  var_scope->SetVarDesc(var_name, nullptr);
   auto var_type = var_scope->Var(var_name)->Type();
   InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
+
   VLOG(3) << "Create Variable " << *new_var_name
           << " locally, which pointer is " << ptr << "Variable Type "
           << var_type;
@@ -171,7 +190,7 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
   // NOTE(Aurelius84): In whice case use_mkldnn = true?
   attr_map["use_mkldnn"] = false;
 
-  // 3. Create transfer_op
+  // 3. Create transfer_dtype_op
   std::string op_type("transfer_dtype");
   auto& op_info = OpInfoMap::Instance().Get(op_type);
   auto op = std::shared_ptr<OperatorBase>(
@@ -209,7 +228,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
           : platform::is_gpu_place(dst_place) ? 1 : -1;
   AttributeMap attr_map = {{"dst_place_type", dst_place_type}};
 
-  // 3. Create transfer_op
+  // 3. Create memcpy_d2h_op or memcpy_h2d_op
   std::string op_type = get_memcpy_type(src_place, dst_place);
   auto& op_info = OpInfoMap::Instance().Get(op_type);
   auto op = std::shared_ptr<OperatorBase>(
@@ -303,6 +322,95 @@ std::string get_memcpy_type(const platform::Place& src_place,
   }
 }
 
+void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
+                                 const platform::Place& place,
+                                 const VariableNameMap& out_names,
+                                 VariableValueMap* out_vars,
+                                 VariableScope* var_scope,
+                                 std::vector<OpFuncNode>* op_func_nodes,
+                                 framework::Scope* local_scope) {
+  DataTranferHelper data_transfer_helper(place, var_scope);
+  for (auto& var_name_item : out_names) {
+    std::vector<Variable*>& vars = out_vars->at(var_name_item.first);
+    for (size_t i = 0; i < var_name_item.second.size(); ++i) {
+      // 1. find grad_var & check whether is complex tensor
+      auto var_name = var_name_item.second[i];
+      auto orig_var_name = framework::GradOriginalVarName(var_name);
+      // only focus on gradient var
+      if (var_name == orig_var_name) {
+        VLOG(3) << "skip " << var_name << " with same name as "
+                << orig_var_name;
+        continue;
+      }
+      auto* grad_var = vars[i];
+      // skip nullptr var
+      if (grad_var == nullptr) {
+        VLOG(3) << "skip grad_var with nullptr";
+        continue;
+      }
+      // don't process LoDTensorArray temporarily,
+      // add support if necessary for complex number calculations in the future
+      if (!framework::VarIsTensor(*grad_var)) {
+        VLOG(3) << "skip grad_var with LoDTensorArray type";
+        continue;
+      }
+      auto* grad_tensor =
+          framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(grad_var);
+      // skip nullptr tensor
+      if (grad_tensor == nullptr || !grad_tensor->IsInitialized()) {
+        VLOG(3) << "skip with grad_tensor not IsInitialized";
+        continue;
+      }
+      // only focus on complex dtype now
+      auto src_type = grad_tensor->type();
+      if (!framework::IsComplexType(src_type)) {
+        VLOG(3) << "skip grad_tensor with not complexType";
+        continue;
+      }
+
+      // 2. find forward var & check whether need to cast
+      auto* var = var_scope->FindVar(orig_var_name);
+      // if forward var not exists, do nothing
+      if (var == nullptr) {
+        VLOG(3) << "skip " << orig_var_name << " with not found in var_scope";
+        continue;
+      }
+      if (!framework::VarIsTensor(*var)) {
+        VLOG(3) << "skip " << orig_var_name << " with LoDTensorArray.";
+        continue;
+      }
+      const auto* tensor =
+          framework::GetLoDTensorOrSelectedRowsValueFromVar(*var);
+      PADDLE_ENFORCE_NOT_NULL(
+          tensor,
+          platform::errors::Unavailable(
+              "Forward tensor is nullptr when handle complex data to real."));
+      // only need record type, the allocation may have been released
+      auto dst_type = tensor->saved_type();
+      // only focus on real dtype and need casting
+      if (framework::IsComplexType(dst_type)) {
+        continue;
+      }
+
+      // 3. cast complex grad to real grad inplacely
+      VLOG(3) << "Transform " << framework::DataTypeToString(src_type)
+              << " var `" << var_name << "` to "
+              << framework::DataTypeToString(dst_type)
+              << " real var in static graph.";
+
+      // NOTE(Aurelius84): Consider to define a complex2real op to deal this
+      // case.
+      std::string new_var_name;
+      auto op = TransferDtype(var_name, &new_var_name, src_type, dst_type,
+                              var_scope, local_scope);
+      data_transfer_helper.RunAndConstructOpFuncNode(op, var_name, new_var_name,
+                                                     op_func_nodes);
+      data_transfer_helper.RunAndConstructShareNode(new_var_name, var_name,
+                                                    op_func_nodes);
+    }
+  }
+}
+
 }  // namespace interpreter
 }  // namespace framework
 }  // namespace paddle
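A note on what the new HandleComplexGradToRealGrad above does: when a forward variable is real-valued but its gradient came out complex, the gradient is cast back to the forward dtype in place (a transfer_dtype op followed by a share_data op). Conceptually such a cast discards the imaginary part. The standalone C++ sketch below illustrates that idea with std::complex; the function and types are illustrative stand-ins, not Paddle's tensor API.

#include <complex>
#include <iostream>
#include <vector>

// Illustrative stand-in: cast a complex-valued gradient back to the real
// dtype of its forward variable by keeping the real part of each element.
std::vector<float> CastComplexGradToReal(
    const std::vector<std::complex<float>>& complex_grad) {
  std::vector<float> real_grad;
  real_grad.reserve(complex_grad.size());
  for (const auto& value : complex_grad) {
    real_grad.push_back(value.real());
  }
  return real_grad;
}

int main() {
  std::vector<std::complex<float>> grad = {{1.5f, 2.0f}, {-0.5f, 3.0f}};
  for (float v : CastComplexGradToReal(grad)) {
    std::cout << v << " ";  // prints: 1.5 -0.5
  }
  std::cout << "\n";
  return 0;
}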

paddle/fluid/framework/new_executor/data_transfer.h

Lines changed: 15 additions & 3 deletions
@@ -37,14 +37,18 @@ class DataTranferHelper {
       const std::string& var_name, std::string* new_var_name,
       std::vector<OpFuncNode>* new_op_func_nodes, bool use_local_scope);
 
- private:
-  platform::Place place_;
-  VariableScope* var_scope_;
+  void RunAndConstructShareNode(const std::string& src_var_name,
+                                const std::string& dst_var_name,
+                                std::vector<OpFuncNode>* op_func_nodes);
 
   void RunAndConstructOpFuncNode(const std::shared_ptr<OperatorBase>& op,
                                  const std::string& var_name,
                                  const std::string& new_var_name,
                                  std::vector<OpFuncNode>* op_func_nodes);
+
+ private:
+  platform::Place place_;
+  VariableScope* var_scope_;
 };
 
 void ApplyDataTransform(const OpKernelType& expected_kernel_key,
@@ -54,6 +58,14 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
                         std::vector<OpFuncNode>* op_func_nodes,
                         bool use_local_scope = true);
 
+void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
+                                 const platform::Place& place,
+                                 const VariableNameMap& out_names,
+                                 VariableValueMap* out_vars,
+                                 VariableScope* var_scope,
+                                 std::vector<OpFuncNode>* op_func_nodes,
+                                 framework::Scope* local_scope);
+
 std::string get_memcpy_type(const platform::Place& src_place,
                             const platform::Place& dst_place);

paddle/fluid/framework/new_executor/interpretercore.cc

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ paddle::framework::FetchList InterpreterCore::Run(
 
   // return Fetch Tensors
   auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
-  return *(fetch_var->GetMutable<framework::FetchList>());
+  return std::move(*fetch_var->GetMutable<framework::FetchList>());
 }
 
 paddle::framework::FetchList InterpreterCore::Run(
@@ -124,7 +124,7 @@ paddle::framework::FetchList InterpreterCore::Run(
 
   // return Fetch Tensors
   auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
-  return *(fetch_var->GetMutable<framework::FetchList>());
+  return std::move(*fetch_var->GetMutable<framework::FetchList>());
 }
 
 void InterpreterCore::BuildOperatorDependences() {
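The two one-line changes above return std::move(*fetch_var->GetMutable<framework::FetchList>()) instead of copying the fetch list out of the scope. A minimal standalone sketch of the same idea, using std::vector<std::string> as a stand-in for the fetch list:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// The "scope" keeps owning the variable, but the caller only needs the
// fetched values once, so moving them out avoids copying every element.
std::vector<std::string> TakeFetchResults(std::vector<std::string>* fetch_var) {
  // return *fetch_var;          // would copy the whole container
  return std::move(*fetch_var);  // steals the storage; *fetch_var is left empty but valid
}

int main() {
  std::vector<std::string> scope_owned = {"loss", "accuracy"};
  auto fetched = TakeFetchResults(&scope_owned);
  std::cout << fetched.size() << " fetched, " << scope_owned.size()
            << " left in the scope variable\n";  // 2 fetched, 0 left
  return 0;
}

The trade-off is that the scope-owned fetch variable is emptied after Run(); presumably this is acceptable here because the fetch variable is repopulated on every run.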

paddle/fluid/framework/new_executor/interpretercore_util.cc

Lines changed: 11 additions & 9 deletions
@@ -328,20 +328,14 @@ void build_op_func_list(const platform::Place& place,
               ->GetExpectedKernelType(
                   ExecutionContext(*op, scope, *dev_ctx, runtime_context));
 
-      // consider device_guard()
-      apply_device_guard(
-          op, place,
-          &expected_kernel_key);  // change device by the device_guard()
+      // change device by the device_guard()
+      apply_device_guard(op, place, &expected_kernel_key);
       VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
 
       // step 3. apply data transforms and insert data transfer ops
       VariableValueMap& ins_map_temp = runtime_context.inputs;
-      std::vector<OpFuncNode> new_op_func_nodes;
       ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
-                         &op_func_node, &new_op_func_nodes, use_local_scope);
-      for (auto& item : new_op_func_nodes) {
-        vec_func_list->emplace_back(std::move(item));
-      }
+                         &op_func_node, vec_func_list, use_local_scope);
       // step 4. Run op kernel
       VLOG(3) << op->Type()
               << " : expected_kernel_key : " << expected_kernel_key;
@@ -370,6 +364,14 @@ void build_op_func_list(const platform::Place& place,
 
       op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
       op_func_node.kernel_func_(exec_ctx);
+
+      // post-process grad_op.outputs if need cast complex grad into real grad.
+      // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
+      if (framework::IsComplexType(expected_kernel_key.data_type_)) {
+        interpreter::HandleComplexGradToRealGrad(
+            op_func_node, place, outputs_names, &runtime_context.outputs,
+            var_scope, vec_func_list, local_scope);
+      }
     }
 
     vec_func_list->emplace_back(op_func_node);
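The first interpretercore_util.cc hunk above drops the temporary new_op_func_nodes vector and lets ApplyDataTransform append the transfer nodes directly to vec_func_list. A rough before/after sketch of that refactor, with simplified stand-in names and types:

#include <string>
#include <utility>
#include <vector>

struct OpFuncNode {
  std::string op_type;
};

// Callee appends whatever nodes it creates straight into the caller's list.
void ApplyTransforms(std::vector<OpFuncNode>* op_func_nodes) {
  op_func_nodes->push_back({"transfer_dtype"});
  op_func_nodes->push_back({"share_data"});
}

int main() {
  std::vector<OpFuncNode> vec_func_list;

  // Before: collect into a temporary, then move the nodes over one by one.
  std::vector<OpFuncNode> new_op_func_nodes;
  ApplyTransforms(&new_op_func_nodes);
  for (auto& item : new_op_func_nodes) {
    vec_func_list.emplace_back(std::move(item));
  }

  // After: pass the destination list down and skip the extra move loop.
  ApplyTransforms(&vec_func_list);
  return 0;
}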

paddle/fluid/framework/new_executor/interpretercore_util.h

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@ namespace framework {
 namespace interpreter {
 
 using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
-static constexpr char kFetchVarName[] = "fetch";
 
 class AsyncWorkQueue {
  public:

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 1 addition & 0 deletions
@@ -374,6 +374,7 @@ class Instruction {
 namespace interpreter {
 static constexpr char kMemcpyH2D[] = "memcpy_h2d";
 static constexpr char kMemcpyD2H[] = "memcpy_d2h";
+static constexpr char kFetchVarName[] = "fetch";
 
 static bool IsMemcpyH2D(const Instruction& instr) {
   return instr.OpBase()->Type() == kMemcpyH2D;

paddle/fluid/framework/operator.cc

Lines changed: 0 additions & 4 deletions
@@ -479,10 +479,6 @@ void OperatorBase::GenerateTemporaryNames() {
   }
 }
 
-static bool VarIsTensor(const Variable& var) {
-  return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
-}
-
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
   if (var.IsType<LoDTensor>()) {
     return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));

paddle/fluid/framework/operator.h

Lines changed: 4 additions & 0 deletions
@@ -114,6 +114,10 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
   }
 }
 
+inline bool VarIsTensor(const Variable& var) {
+  return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
+}
+
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
 Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
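VarIsTensor used to be a file-local static helper in operator.cc; the two hunks above delete that copy and re-declare it inline in operator.h so other translation units (here the new executor's data_transfer.cc) can call it. A short sketch of why a function defined in a header needs inline:

// util.h -- a helper defined in a header and included from several .cc files.
// Without `inline`, every including translation unit would emit its own
// external definition and the linker would report duplicate symbols (an ODR
// violation); `inline` permits the repeated identical definitions.
#pragma once

inline bool IsPositive(int x) { return x > 0; }

// A `static` definition would also link, but each .cc file would get its own
// private copy -- which is what the removed operator.cc helper amounted to,
// and why it was invisible outside that file.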

paddle/fluid/imperative/amp_auto_cast.cc

Lines changed: 7 additions & 0 deletions
@@ -261,6 +261,13 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
     dst_type = framework::proto::VarType::FP32;
   }
   for (auto& pair : new_ins) {
+    // NOTE: The run_program OP only has FP32 kernel. In dy2stat pure fp16
+    // training, we have correctly cast the inputs of run_program OP before,
+    // so here should avoid casting for run_program OP.
+    if (op_type == "run_program") {
+      continue;
+    }
+
     if ((op_type == "batch_norm" || op_type == "layer_norm" ||
          op_type == "sync_batch_norm") &&
         pair.first != "X") {
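The amp_auto_cast.cc hunk adds an early continue so the per-input casting loop leaves run_program inputs untouched (per the NOTE, they were already cast upstream in dy2stat pure-fp16 training and run_program itself only has an FP32 kernel). A toy sketch of the same skip-before-cast pattern, with illustrative names only:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Illustrative only: walk an op's named inputs and "cast" them unless the op
// is exempt, mirroring the early `continue` added for run_program.
void CastInputsToFp16(const std::string& op_type,
                      std::map<std::string, std::vector<std::string>>* inputs) {
  for (auto& pair : *inputs) {
    if (op_type == "run_program") {
      continue;  // inputs were already cast upstream; leave them as-is
    }
    for (auto& var : pair.second) {
      var += "_fp16";  // stand-in for the real dtype cast
    }
  }
}

int main() {
  std::map<std::string, std::vector<std::string>> ins = {{"X", {"x0"}}};
  CastInputsToFp16("run_program", &ins);
  std::cout << ins["X"][0] << "\n";  // x0 (unchanged)
  return 0;
}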

paddle/pten/api/lib/tensor.cc

Lines changed: 1 addition & 1 deletion
@@ -283,7 +283,7 @@ template <typename T>
 Tensor Tensor::copy_to(const PlaceType &target_place) const {
   LOG(WARNING) << "The Tensor's `copy_to` method is deprecated since version "
                   "2.3, and will be removed in version 2.4, please use "
-                  "`copy_to` method without template argumentinstead. "
+                  "`copy_to` method without template argument instead. "
                   "reason: copying a Tensor to another device does not need "
                   "to specify the data type template argument.";
   return copy_to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false);
