The Infer-shape process can be simplified #8088

@JiayiFeng

Description

Currently, our shape inference is quite convoluted. For example, to get the dim of a non-duplicable input at compile time, the program has to switch back and forth between CompileTimeInferShapeContext and InferShapeContext. The call stack looks like this:
(Calls proceed from top to bottom.)

    DDim CompileTimeInferShapeContext::GetInputDim(string param_name);
    vector<DDim> InferShapeContext::GetInputsDim(string param_name);
    vector<DDim> InferShapeContext::GetDims(vector<string> arg_name);
    DDim CompileTimeInferShapeContext::GetDim(string arg_name);
    DDim VarDesc::Shape();

The runtime path is similar: just replace CompileTimeInferShapeContext with RuntimeInferShapeContext.
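To make the detour concrete, here is a minimal C++ sketch of the current compile-time path. It is not Paddle's actual code: DDim is modeled as std::vector<int64_t>, the VarDesc::Shape() lookup as a std::map, and the Inputs() hook and the constructor are hypothetical simplifications. Note how GetInputDim builds a one-element vector through GetInputsDim and GetDims before returning its only entry.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in: a DDim is modeled as a plain vector of extents.
using DDim = std::vector<int64_t>;

class InferShapeContext {
 public:
  virtual ~InferShapeContext() = default;

  // Intended for duplicable inputs: one DDim per argument name.
  std::vector<DDim> GetInputsDim(const std::string& param_name) const {
    return GetDims(Inputs(param_name));
  }

  std::vector<DDim> GetDims(const std::vector<std::string>& arg_names) const {
    std::vector<DDim> dims;
    dims.reserve(arg_names.size());
    for (const auto& name : arg_names) dims.push_back(GetDim(name));
    return dims;
  }

 protected:
  // Hypothetical hook: maps a parameter name to its argument (variable) names.
  virtual std::vector<std::string> Inputs(const std::string& param_name) const = 0;
  virtual DDim GetDim(const std::string& arg_name) const = 0;
};

class CompileTimeInferShapeContext : public InferShapeContext {
 public:
  CompileTimeInferShapeContext(std::map<std::string, std::vector<std::string>> inputs,
                               std::map<std::string, DDim> var_shapes)
      : inputs_(std::move(inputs)), var_shapes_(std::move(var_shapes)) {}

  // Current flow: even a non-duplicable input detours through
  // GetInputsDim -> GetDims, which builds a temporary one-element vector.
  DDim GetInputDim(const std::string& param_name) const {
    std::vector<DDim> dims = GetInputsDim(param_name);
    assert(dims.size() == 1 && "input is expected to be non-duplicable");
    return dims[0];
  }

 protected:
  std::vector<std::string> Inputs(const std::string& param_name) const override {
    return inputs_.at(param_name);
  }
  DDim GetDim(const std::string& arg_name) const override {
    return var_shapes_.at(arg_name);  // stands in for VarDesc::Shape()
  }

 private:
  std::map<std::string, std::vector<std::string>> inputs_;
  std::map<std::string, DDim> var_shapes_;
};
```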

There are at least two issues here:

  1. For a non-duplicable input, CompileTimeInferShapeContext::GetInputDim is invoked, whereas for a duplicable input InferShapeContext::GetInputsDim is invoked directly. This means the entry points for the same functionality are scattered across two classes.
  2. vector<DDim> InferShapeContext::GetInputsDim(string param_name); is intended for duplicable inputs. However, it is still invoked indirectly even when the input is non-duplicable. That is inefficient: a temporary vector<DDim> is built only to extract its single element.

Solution:

Move GetInputDim from CompileTimeInferShapeContext (and RuntimeInferShapeContext) to InferShapeContext, and have GetInputDim invoke GetDim directly instead of detouring through GetInputsDim.
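A minimal sketch of the proposed layout, under the same hypothetical simplifications (DDim as std::vector<int64_t>, a std::map in place of VarDesc, and a hypothetical Inputs() hook). GetInputDim is defined once in InferShapeContext and calls GetDim directly, so the derived contexts only supply the storage-specific hooks.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in: a DDim is modeled as a plain vector of extents.
using DDim = std::vector<int64_t>;

class InferShapeContext {
 public:
  virtual ~InferShapeContext() = default;

  // Proposed: a single entry point for non-duplicable inputs, shared by the
  // compile-time and runtime contexts, calling GetDim directly.
  DDim GetInputDim(const std::string& param_name) const {
    const std::vector<std::string> arg_names = Inputs(param_name);
    assert(arg_names.size() == 1 && "input is expected to be non-duplicable");
    return GetDim(arg_names[0]);
  }

  // Duplicable inputs keep the vector-returning path.
  std::vector<DDim> GetInputsDim(const std::string& param_name) const {
    return GetDims(Inputs(param_name));
  }

  std::vector<DDim> GetDims(const std::vector<std::string>& arg_names) const {
    std::vector<DDim> dims;
    dims.reserve(arg_names.size());
    for (const auto& name : arg_names) dims.push_back(GetDim(name));
    return dims;
  }

 protected:
  virtual std::vector<std::string> Inputs(const std::string& param_name) const = 0;
  virtual DDim GetDim(const std::string& arg_name) const = 0;
};

// The derived context no longer defines GetInputDim; it only implements hooks.
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
  CompileTimeInferShapeContext(std::map<std::string, std::vector<std::string>> inputs,
                               std::map<std::string, DDim> var_shapes)
      : inputs_(std::move(inputs)), var_shapes_(std::move(var_shapes)) {}

 protected:
  std::vector<std::string> Inputs(const std::string& param_name) const override {
    return inputs_.at(param_name);
  }
  DDim GetDim(const std::string& arg_name) const override {
    return var_shapes_.at(arg_name);  // stands in for VarDesc::Shape()
  }

 private:
  std::map<std::string, std::vector<std::string>> inputs_;
  std::map<std::string, DDim> var_shapes_;
};
```

A RuntimeInferShapeContext would override the same two hooks and read shapes from runtime variables instead; that part is omitted from this sketch.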
