Commit 64b6dc6

Refactored trace_op logic for eager mode

1 parent 7112486

5 files changed: 93 additions, 22 deletions

paddle/fluid/eager/auto_code_generator/eager_generator.cc
Lines changed: 20 additions & 1 deletion

@@ -34,6 +34,8 @@ std::unordered_map<std::string, std::vector<std::string>>
     core_ops_returns_info = {};
 std::unordered_map<std::string, std::vector<std::string>> core_ops_args_info =
     {};
+std::unordered_map<std::string, std::vector<std::string>>
+    core_ops_args_type_info = {};
 
 /* --- Static maps to handle corner cases --- */
 static std::unordered_map<std::string, paddle::framework::AttributeMap>
@@ -1120,7 +1122,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   std::string generated_function_body = "";
   std::string dygraph_function_args_str = "";
   core_ops_args_info[op_type] = {};
+  core_ops_args_type_info[op_type] = {};
   core_ops_args_info[op_type].resize(in_vars.size());
+  core_ops_args_type_info[op_type].resize(in_vars.size());
 
   /* ------ Dygraph forward function generation ------ */
   generated_function_body += "  // Dygraph Forward Pass\n";
@@ -1138,10 +1142,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
           "const std::vector<egr::EagerTensor>& %s";
       input_args_str_list[input_position] =
           paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
+
+      core_ops_args_type_info[op_type][input_position] = "list";
     } else {
       const char* FWD_INS_ARG_TEMPLATE = "const egr::EagerTensor& %s";
       input_args_str_list[input_position] =
           paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
+
+      core_ops_args_type_info[op_type][input_position] = "tensor";
     }
     core_ops_args_info[op_type][input_position] = input_name;
 
@@ -1210,11 +1218,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
           paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
       dygraph_function_args_str += arg_str;
 
+      core_ops_args_type_info[op_type].push_back("list");
     } else {
       const char* FWD_NUM_ARG_TEMPLATE = ", egr::EagerTensor* %s";
       std::string arg_str =
           paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
       dygraph_function_args_str += arg_str;
+
+      core_ops_args_type_info[op_type].push_back("tensor");
     }
     const char* FWD_OUTS_CONTENT_TEMPLATE =
         "{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },";
@@ -1236,6 +1247,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
       outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE,
                                                    output_name, outnum);
       core_ops_args_info[op_type].push_back(outnum);
+      core_ops_args_type_info[op_type].push_back("int");
     } else {
       const char* FWD_OUTS_CONTENT_TEMPLATE =
           "{ \"%s\", "
@@ -1840,6 +1852,9 @@ static std::string GenerateDygraphHFileIncludes() {
   dygraph_forward_api_includes_str +=
       "extern std::unordered_map<std::string, std::vector<std::string>> "
      "core_ops_args_info;\n";
+  dygraph_forward_api_includes_str +=
+      "extern std::unordered_map<std::string, std::vector<std::string>> "
+      "core_ops_args_type_info;\n";
   dygraph_forward_api_includes_str +=
       "extern std::unordered_map<std::string, std::vector<std::string>> "
       "core_ops_returns_info;\n\n";
@@ -1936,16 +1951,20 @@ static std::string GenerateCoreOpsReturnsInfo() {
       "std::unordered_map<std::string, std::vector<std::string>> "
       "core_ops_args_info = { %s };\n"
       "std::unordered_map<std::string, std::vector<std::string>> "
+      "core_ops_args_type_info = { %s };\n"
+      "std::unordered_map<std::string, std::vector<std::string>> "
       "core_ops_returns_info = { %s };\n";
 
   std::string core_ops_args_info_init_str =
       ConvertCoreOpsInfosToString(core_ops_args_info);
+  std::string core_ops_args_type_info_init_str =
+      ConvertCoreOpsInfosToString(core_ops_args_type_info);
   std::string core_ops_returns_info_init_str =
       ConvertCoreOpsInfosToString(core_ops_returns_info);
 
   std::string core_ops_info_str = paddle::string::Sprintf(
       Core_Ops_Returns_MAP_TEMPLATE, core_ops_args_info_init_str,
-      core_ops_returns_info_init_str);
+      core_ops_args_type_info_init_str, core_ops_returns_info_init_str);
 
   return core_ops_info_str;
 }
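
Taken together, these hunks make the generator emit a third map, core_ops_args_type_info, that parallels core_ops_args_info entry for entry but records the kind of each generated forward argument ("tensor", "list", or "int" for the Num count of a duplicable output) instead of its name. A minimal sketch of what the three maps might contain for a hypothetical op — the op name and slot names are illustrative, not taken from this commit:

    # Hypothetical entries for an op "my_sum" with a duplicable input "X",
    # a duplicable output "Out", and a generated "OutNum" count argument.
    core_ops_args_info = {"my_sum": ["X", "OutNum"]}
    core_ops_args_type_info = {"my_sum": ["list", "int"]}  # index-aligned per-slot kinds
    core_ops_returns_info = {"my_sum": ["Out"]}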

paddle/fluid/framework/variable.h
Lines changed: 2 additions & 0 deletions

@@ -64,8 +64,10 @@ class Variable {
   void Clear() { holder_.reset(); }
 
   int Type() const {
+    VLOG(1) << 11111;
     PADDLE_ENFORCE_NOT_NULL(
         holder_, platform::errors::NotFound("Variable is not initialized."));
+    VLOG(1) << (holder_ == nullptr);
     return holder_->Type();
   }
 

paddle/fluid/pybind/eager_op_function_generator.cc
Lines changed: 19 additions & 0 deletions

@@ -313,6 +313,21 @@ static std::string GenerateCoreOpsInfoMap() {
       "  }\n"
       "}\n"
       "\n"
+      "static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {\n"
+      "  PyThreadState *tstate = nullptr;\n"
+      "  try\n"
+      "  {\n"
+      "    return ToPyObject(core_ops_args_type_info);\n"
+      "  }\n"
+      "  catch(...) {\n"
+      "    if (tstate) {\n"
+      "      PyEval_RestoreThread(tstate);\n"
+      "    }\n"
+      "    ThrowExceptionToPython(std::current_exception());\n"
+      "    return nullptr;\n"
+      "  }\n"
+      "}\n"
+      "\n"
       "static PyObject * eager_get_core_ops_returns_info(PyObject *self) {\n"
       "  PyThreadState *tstate = nullptr;\n"
       "  try\n"
@@ -399,6 +414,10 @@ int main(int argc, char* argv[]) {
       "{\"get_core_ops_args_info\", "
       "(PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, "
       "\"C++ interface function for eager_get_core_ops_args_info.\"},\n"
+      "{\"get_core_ops_args_type_info\", "
+      "(PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, "
+      "METH_NOARGS, "
+      "\"C++ interface function for eager_get_core_ops_args_type_info.\"},\n"
       " {\"get_core_ops_returns_info\", "
       "(PyCFunction)(void(*)(void))eager_get_core_ops_returns_info, "
       "METH_NOARGS, \"C++ interface function for "

python/paddle/fluid/dygraph/tracer.py
Lines changed: 48 additions & 16 deletions

@@ -48,36 +48,68 @@ def trace_op(self,
                  stop_gradient=False,
                  inplace_map=None):
         if framework._in_eager_mode():
+            # inputs : {"sum": [tensor], ...}
+            # outputs : {"sum": [tensor], ...}
+
             function_ptr = _C_ops.__dict__[type]
 
             core_ops_args_info = _C_ops.get_core_ops_args_info()
+            core_ops_args_type_info = _C_ops.get_core_ops_args_type_info()
             core_ops_returns_info = _C_ops.get_core_ops_returns_info()
 
             op_args = core_ops_args_info[type]
+            op_args_type = core_ops_args_type_info[type]
             op_returns = core_ops_returns_info[type]
 
             arg_list = []
-            for arg in op_args:
-                if arg in inputs.keys():
-                    arg_list.append(inputs[arg])
-                elif arg in outputs.keys():
-                    arg_list.append(outputs[arg])
+            for i in range(len(op_args)):
+                arg_name = op_args[i]
+                arg_type = op_args_type[i]
+                if arg_name in inputs.keys():
+                    arg_to_append = inputs[arg_name]
+                elif arg_name in outputs.keys():
+                    arg_to_append = outputs[arg_name]
                 else:
-                    if "Num" in arg:
+                    if "Num" in arg_name:
                         # Remove "Num" suffix to get out_name
-                        out_name = arg[:-3]
+                        out_name = arg_name[:-3]
                         assert out_name in outputs.keys()
                         num_outs = len(outputs[out_name])
-                        arg_list.append(num_outs)
+                        arg_to_append = num_outs
                     else:
-                        arg_list.append(None)
-            returns = function_ptr(*arg_list, **attrs)
-
-            for i in range(len(op_returns)):
-                retname = op_returns[i]
-                if retname in outputs.keys():
-                    # Replaced outputs by function returns
-                    outputs[retname] = returns[i]
+                        arg_to_append = None
+
+                if arg_to_append is None:
+                    arg_list.append(arg_to_append)
+                elif arg_type == "tensor":
+                    arg_list.append(arg_to_append[0])
+                elif arg_type == "list":
+                    arg_list.append(arg_to_append)
+                elif arg_type == "int":
+                    arg_list.append(arg_to_append)
+                else:
+                    assert False
+
+            attrs_list = []
+            for k, v in attrs.items():
+                attrs_list.append(k)
+                attrs_list.append(v)
+            returns = function_ptr(*arg_list, *attrs_list)
+
+            if isinstance(returns, tuple):
+                for i in range(len(op_returns)):
+                    retname = op_returns[i]
+                    if retname in outputs.keys():
+                        # Replaced outputs by function returns
+                        outputs[retname] = returns[i]
+            elif isinstance(returns, list):
+                assert len(outputs.keys()) == 1
+                key = list(outputs.keys())[0]
+                outputs[key] = returns
+            else:
+                assert len(outputs.keys()) == 1
+                key = list(outputs.keys())[0]
+                outputs[key] = [returns]
         else:
             self.trace(type, inputs, outputs, attrs,
                        framework._current_expected_place(), self._has_grad and
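
The refactored branch does three things: it uses the per-slot type strings to decide whether a one-element input list should be unwrapped to a bare tensor, it flattens attrs into alternating name/value positional arguments (which suggests the generated functions now take them positionally rather than as keywords), and it normalizes a tuple, list, or single return back into the outputs dict. A worked example of the assembly for a hypothetical call — op name, tensors, and attr are illustrative:

    # Suppose trace_op is called with
    #   type="scale", inputs={"X": [t]}, outputs={"Out": [out_t]},
    #   attrs={"scale": 2.0}
    # and the generated metadata yields
    #   op_args=["X"], op_args_type=["tensor"], op_returns=["Out"].
    # The loop unwraps the one-element list ("tensor" -> inputs["X"][0]),
    # attrs flatten to ["scale", 2.0], so the call is effectively
    #   returns = function_ptr(t, "scale", 2.0)
    # and a single-tensor return is re-wrapped as outputs["Out"] = [returns].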

python/paddle/fluid/tests/unittests/op_test.py
Lines changed: 4 additions & 5 deletions

@@ -25,6 +25,7 @@
 import itertools
 import collections
 from collections import defaultdict
+from copy import copy
 
 import paddle
 import paddle.fluid as fluid
@@ -491,7 +492,7 @@ def _append_ops(self, block):
             type=self.op_type,
             inputs=inputs,
             outputs=outputs,
-            attrs=self.attrs if hasattr(self, "attrs") else dict())
+            attrs=copy(self.attrs) if hasattr(self, "attrs") else dict())
         # infer variable type and infer shape in compile-time
         op.desc.infer_var_type(block.desc)
         op.desc.infer_shape(block.desc)
@@ -1172,8 +1173,7 @@ def find_actual(target_name, fetch_list):
             if check_dygraph:
                 imperative_actual = find_imperative_actual(
                     sub_out_name, dygraph_outs, place)
-                imperative_actual_t = np.array(imperative_actual.value()
-                                               .get_tensor())
+                imperative_actual_t = imperative_actual.numpy()
                 idx = find_actual(sub_out_name, fetch_list)
                 actual = outs[idx]
                 actual_t = np.array(actual)
@@ -1209,8 +1209,7 @@ def find_actual(target_name, fetch_list):
             if check_dygraph:
                 imperative_actual = find_imperative_actual(
                     out_name, dygraph_outs, place)
-                imperative_actual_t = np.array(imperative_actual.value()
-                                               .get_tensor())
+                imperative_actual_t = imperative_actual.numpy()
                 idx = find_actual(out_name, fetch_list)
                 actual = outs[idx]
                 actual_t = np.array(actual)
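
Two independent test-side fixes ride along here: imperative results are now read with .numpy() instead of going through value().get_tensor(), and _append_ops passes copy(self.attrs) so the op-append path cannot mutate the test case's own attrs dict in place (a shallow copy suffices if only top-level keys are touched — an assumption about the callee, not something this diff shows). A minimal sketch of the failure mode the copy avoids, with a hypothetical mutating key:

    from copy import copy

    attrs = {"axis": -1}
    passed = copy(attrs)               # shallow copy handed to the callee
    passed["use_mkldnn"] = True        # hypothetical in-place edit downstream
    assert "use_mkldnn" not in attrs   # the test's own attrs remain untouched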
