@@ -95,12 +95,12 @@ class OperatorBase {
9595 const VariableNameMap& Inputs () const { return inputs_; }
9696 const VariableNameMap& Outputs () const { return outputs_; }
9797 // ! Get a input with argument's name described in `op_proto`
98- const std::string& Input (const std::string& name) const ;
98+ std::string Input (const std::string& name) const ;
9999 // ! Get a input which has multiple variables.
100100 const std::vector<std::string>& Inputs (const std::string& name) const ;
101101
102102 // ! Get a output with argument's name described in `op_proto`
103- const std::string& Output (const std::string& name) const ;
103+ std::string Output (const std::string& name) const ;
104104 // ! Get an output which has multiple variables.
105105 // ! TODO add a vector_view to prevent memory copy.
106106 const std::vector<std::string>& Outputs (const std::string& name) const ;
@@ -127,6 +127,10 @@ class OperatorBase {
127127 // IG (Inputs Gradients)
128128 VariableNameMap outputs_;
129129 AttributeMap attrs_;
130+
131+ private:
132+ void GenerateTemporaryNames ();
133+ void CheckAllInputOutputSet () const ;
130134};
131135
132136// Macro for define a clone method.
@@ -238,46 +242,50 @@ class InferShapeContext {
238242 }
239243
240244 const Variable* InputVar (const std::string& name) const {
241- return scope_.FindVar (op_.Input (name));
245+ auto ipt = op_.Input (name);
246+ return ipt == kEmptyVarName ? nullptr : scope_.FindVar (ipt);
242247 }
243248
244249 Variable* OutputVar (const std::string& name) const {
245- return scope_.FindVar (op_.Output (name));
250+ auto opt = op_.Output (name);
251+ return opt == kEmptyVarName ? nullptr : scope_.FindVar (opt);
246252 }
247253
248254 const std::vector<const Variable*> MultiInputVar (
249255 const std::string& name) const {
250256 auto names = op_.Inputs (name);
251257 std::vector<const Variable*> res;
252258 res.reserve (names.size ());
253- std::transform (
254- names.begin (), names.end (), std::back_inserter (res),
255- [this ](const std::string& name) { return scope_.FindVar (name); });
259+ std::transform (names.begin (), names.end (), std::back_inserter (res),
260+ [this ](const std::string& name) {
261+ return name == kEmptyVarName ? nullptr
262+ : scope_.FindVar (name);
263+ });
256264 return res;
257265 }
258266
259267 std::vector<const Variable*> MultiOutputVar (const std::string& name) const {
260268 auto names = op_.Outputs (name);
261269 std::vector<const Variable*> res;
262270 res.reserve (names.size ());
263- std::transform (
264- names.begin (), names.end (), std::back_inserter (res),
265- [this ](const std::string& name) { return scope_.FindVar (name); });
271+ std::transform (names.begin (), names.end (), std::back_inserter (res),
272+ [this ](const std::string& name) {
273+ return name == kEmptyVarName ? nullptr
274+ : scope_.FindVar (name);
275+ });
266276 return res;
267277 }
268278
269279 template <typename T>
270280 const T* Input (const std::string& name) const {
271281 auto * var = InputVar (name);
272- PADDLE_ENFORCE_NOT_NULL (var, " Input(%s) should not be nullptr" , name);
273- return &var->Get <T>();
282+ return var == nullptr ? nullptr : &var->Get <T>();
274283 }
275284
276285 template <typename T>
277286 T* Output (const std::string& name) const {
278287 auto var = OutputVar (name);
279- PADDLE_ENFORCE_NOT_NULL (var, " Output(%s) should not be nullptr" , name);
280- return var->GetMutable <T>();
288+ return var == nullptr ? nullptr : var->GetMutable <T>();
281289 }
282290
283291 template <typename T>
@@ -288,10 +296,7 @@ class InferShapeContext {
288296 std::transform (names.begin (), names.end (), std::back_inserter (res),
289297 [&](const std::string& sub_name) {
290298 auto var = scope_.FindVar (sub_name);
291- PADDLE_ENFORCE_NOT_NULL (
292- var, " MultiInput(%s:%s) should not be nullptr" , name,
293- sub_name);
294- return &var->Get <T>();
299+ return var == nullptr ? nullptr : &var->Get <T>();
295300 });
296301 return res;
297302 }
@@ -304,10 +309,7 @@ class InferShapeContext {
304309 std::transform (names.begin (), names.end (), std::back_inserter (res),
305310 [&](const std::string& sub_name) {
306311 auto var = scope_.FindVar (sub_name);
307- PADDLE_ENFORCE_NOT_NULL (
308- var, " MultiOutput(%s:%s) should not be nullptr." , name,
309- sub_name);
310- return var->GetMutable <T>();
312+ return var == nullptr ? nullptr : var->GetMutable <T>();
311313 });
312314 return res;
313315 }
0 commit comments