@@ -16,15 +16,51 @@ limitations under the License. */
1616#include  < functional> 
1717#include  < mutex> 
1818#include  < unordered_map> 
19+ #include  " glog/logging.h" 
1920#include  " paddle/framework/block_desc.h" 
2021#include  " paddle/framework/operator.h" 
2122#include  " paddle/framework/program_desc.h" 
22- 
23- #include  " glog/logging.h" 
23+ #include  " paddle/framework/shape_inference.h" 
2424
2525namespace  paddle  {
2626namespace  framework  {
2727
28+ class  OpDescBind ;
29+ class  BlockDescBind ;
30+ class  CompileTimeInferShapeContext  : public  InferShapeContext  {
31+  public: 
32+  CompileTimeInferShapeContext (const  OpDescBind &op,
33+  const  BlockDescBind &block);
34+ 
35+  bool  HasInput (const  std::string &name) const  override ;
36+ 
37+  bool  HasOutput (const  std::string &name) const  override ;
38+ 
39+  bool  HasInputs (const  std::string &name) const  override ;
40+ 
41+  bool  HasOutputs (const  std::string &name) const  override ;
42+ 
43+  DDim GetInputDim (const  std::string &name) const  override ;
44+ 
45+  void  SetOutputDim (const  std::string &name, const  DDim &dim) override ;
46+ 
47+  AttrReader Attrs () const  override ;
48+ 
49+  const  std::vector<std::string> &Inputs (
50+  const  std::string &name) const  override ;
51+ 
52+  const  std::vector<std::string> &Outputs (
53+  const  std::string &name) const  override ;
54+ 
55+  private: 
56+  DDim GetDim (const  std::string &name) const  override ;
57+ 
58+  void  SetDim (const  std::string &name, const  DDim &dim) override ;
59+ 
60+  const  OpDescBind &op_;
61+  const  BlockDescBind &block_;
62+ };
63+ 
2864OpDescBind::OpDescBind (const  std::string &type, const  VariableNameMap &inputs,
2965 const  VariableNameMap &outputs,
3066 const  AttributeMap &attrs) {
@@ -288,5 +324,97 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
288324 }
289325}
290326
327+ CompileTimeInferShapeContext::CompileTimeInferShapeContext (
328+  const  OpDescBind &op, const  BlockDescBind &block)
329+  : op_(op), block_(block) {}
330+ 
331+ bool  CompileTimeInferShapeContext::HasInput (const  std::string &name) const  {
332+  const  std::vector<std::string> &input_names = op_.Input (name);
333+  auto  length = input_names.size ();
334+  if  (length == 0 ) {
335+  return  false ;
336+  }
337+  PADDLE_ENFORCE_EQ (length, 1UL ,
338+  " Input(%s) should have only one value, " 
339+  " but it have %d now" 
340+  name, length);
341+  return  block_.HasVarRecursive (input_names[0 ]);
342+ }
343+ 
344+ bool  CompileTimeInferShapeContext::HasOutput (const  std::string &name) const  {
345+  const  std::vector<std::string> &output_names = op_.Output (name);
346+  auto  length = output_names.size ();
347+  if  (length == 0 ) {
348+  return  false ;
349+  }
350+  PADDLE_ENFORCE_EQ (length, 1UL ,
351+  " Output(%s) should have only one value, " 
352+  " but it have %d now" 
353+  name, length);
354+  return  block_.HasVarRecursive (output_names[0 ]);
355+ }
356+ 
357+ bool  CompileTimeInferShapeContext::HasInputs (const  std::string &name) const  {
358+  const  std::vector<std::string> &input_names = op_.Input (name);
359+  if  (input_names.empty ()) {
360+  return  false ;
361+  }
362+  for  (auto  &input : input_names) {
363+  if  (!block_.HasVarRecursive (input)) return  false ;
364+  }
365+  return  true ;
366+ }
367+ 
368+ bool  CompileTimeInferShapeContext::HasOutputs (const  std::string &name) const  {
369+  const  std::vector<std::string> &output_names = op_.Output (name);
370+  if  (output_names.empty ()) {
371+  return  false ;
372+  }
373+  for  (auto  &output : output_names) {
374+  if  (!block_.HasVarRecursive (output)) return  false ;
375+  }
376+  return  true ;
377+ }
378+ 
379+ DDim CompileTimeInferShapeContext::GetInputDim (const  std::string &name) const  {
380+  std::vector<DDim> ddims = GetInputsDim (name);
381+  auto  length = ddims.size ();
382+  PADDLE_ENFORCE_EQ (length, 1UL ,
383+  " Input(%s) should have 1 value, " 
384+  " but it has %d now" 
385+  name, length);
386+  return  ddims[0 ];
387+ }
388+ 
389+ void  CompileTimeInferShapeContext::SetOutputDim (const  std::string &name,
390+  const  DDim &dim) {
391+  SetOutputsDim (name, {dim});
392+ }
393+ 
394+ AttrReader CompileTimeInferShapeContext::Attrs () const  {
395+  return  AttrReader (op_.GetAttrMap ());
396+ }
397+ 
398+ const  std::vector<std::string> &CompileTimeInferShapeContext::Inputs (
399+  const  std::string &name) const  {
400+  return  op_.Input (name);
401+ }
402+ 
403+ const  std::vector<std::string> &CompileTimeInferShapeContext::Outputs (
404+  const  std::string &name) const  {
405+  return  op_.Output (name);
406+ }
407+ 
408+ DDim CompileTimeInferShapeContext::GetDim (const  std::string &name) const  {
409+  auto  var = block_.FindVarRecursive (name);
410+  PADDLE_ENFORCE (var != nullptr , " Cannot find variable %s" 
411+  return  framework::make_ddim (var->Shape ());
412+ }
413+ 
414+ void  CompileTimeInferShapeContext::SetDim (const  std::string &name,
415+  const  DDim &dim) {
416+  block_.FindVarRecursive (name)->SetShape (framework::vectorize (dim));
417+ }
418+ 
291419} //  namespace framework
292420} //  namespace paddle
0 commit comments