Skip to content

Commit d7a1e40

Browse files
committed
Simple Implementation
1 parent fd8df08 commit d7a1e40

File tree

2 files changed

+14
-34
lines changed

2 files changed

+14
-34
lines changed

paddle/framework/operator.cc

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,10 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
3535

3636
std::string OperatorBase::Input(const std::string& name) const {
3737
auto& ins = Inputs(name);
38-
switch (ins.size()) {
39-
case 0:
40-
return kEmptyVarName;
41-
case 1:
42-
return ins[0];
43-
default:
44-
PADDLE_THROW("Op %s input %s should contain only one variable", type_,
45-
name);
46-
return "";
47-
}
38+
PADDLE_ENFORCE_LE(ins.size(), 1UL,
39+
"Op %s input %s should contain only one variable", type_,
40+
name);
41+
return ins.empty() ? kEmptyVarName : ins[0];
4842
}
4943

5044
const std::vector<std::string>& OperatorBase::Inputs(
@@ -57,16 +51,10 @@ const std::vector<std::string>& OperatorBase::Inputs(
5751

5852
std::string OperatorBase::Output(const std::string& name) const {
5953
auto& outs = Outputs(name);
60-
switch (outs.size()) {
61-
case 0:
62-
return kEmptyVarName;
63-
case 1:
64-
return outs[0];
65-
default:
66-
PADDLE_THROW("Op %s output %s should contain only one variable", type_,
67-
name);
68-
return "";
69-
}
54+
PADDLE_ENFORCE_LE(outs.size(), 1UL,
55+
"Op %s output %s should contain only one variable", type_,
56+
name);
57+
return outs.empty() ? kEmptyVarName : outs[0];
7058
}
7159

7260
const std::vector<std::string>& OperatorBase::Outputs(

paddle/framework/operator.h

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -239,20 +239,12 @@ class InferShapeContext {
239239

240240
const Variable* InputVar(const std::string& name) const {
241241
auto ipt = op_.Input(name);
242-
if (ipt == kEmptyVarName) {
243-
return nullptr;
244-
} else {
245-
return scope_.FindVar(ipt);
246-
}
242+
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
247243
}
248244

249245
Variable* OutputVar(const std::string& name) const {
250246
auto opt = op_.Output(name);
251-
if (opt == kEmptyVarName) {
252-
return nullptr;
253-
} else {
254-
return scope_.FindVar(opt);
255-
}
247+
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
256248
}
257249

258250
const std::vector<const Variable*> MultiInputVar(
@@ -262,8 +254,8 @@ class InferShapeContext {
262254
res.reserve(names.size());
263255
std::transform(names.begin(), names.end(), std::back_inserter(res),
264256
[this](const std::string& name) {
265-
return name != kEmptyVarName ? scope_.FindVar(name)
266-
: nullptr;
257+
return name == kEmptyVarName ? nullptr
258+
: scope_.FindVar(name);
267259
});
268260
return res;
269261
}
@@ -274,8 +266,8 @@ class InferShapeContext {
274266
res.reserve(names.size());
275267
std::transform(names.begin(), names.end(), std::back_inserter(res),
276268
[this](const std::string& name) {
277-
return name != kEmptyVarName ? scope_.FindVar(name)
278-
: nullptr;
269+
return name == kEmptyVarName ? nullptr
270+
: scope_.FindVar(name);
279271
});
280272
return res;
281273
}

0 commit comments

Comments
 (0)