Closed
Description
Our shape inference is currently very convoluted. For example, to get the dim of a non-duplicatable input at compile time, the program has to switch repeatedly between CompileTimeInferShapeContext and InferShapeContext. The call stack is as follows:
(Called from top to bottom)
DDim CompileTimeInferShapeContext::GetInputDim(string param_name);
|
vector<DDim> InferShapeContext::GetInputsDim(string param_name);
|
vector<DDim> InferShapeContext::GetDims(vector<string> arg_name);
|
DDim CompileTimeInferShapeContext::GetDim(string arg_name);
|
DDim VarDesc::Shape();

What happens at runtime is similar; just replace CompileTimeInferShapeContext with RuntimeInferShapeContext.
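The call chain above can be modeled with a minimal, self-contained C++ sketch. Everything in it is a hypothetical stand-in rather than the project's actual code: DDim is modeled as std::vector<int64_t>, VarDesc only stores a shape, Inputs() is an assumed hook that maps a parameter name to its argument names, and the uniqueness check in GetInputDim is an assumption. What it shows is the path a single-input query takes: GetInputDim calls GetInputsDim, which builds a vector via GetDims and the virtual GetDim, only for the caller to take element 0.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: DDim is modeled as a plain vector of extents,
// and VarDesc only stores a shape.
using DDim = std::vector<int64_t>;

struct VarDesc {
  DDim shape;
  DDim Shape() const { return shape; }
};

class InferShapeContext {
 public:
  virtual ~InferShapeContext() = default;

  // Entry point meant for duplicatable inputs (one parameter, many arguments).
  std::vector<DDim> GetInputsDim(const std::string& param_name) const {
    ++inputs_dim_calls_;  // records each trip through the vector path
    return GetDims(Inputs(param_name));
  }

  int InputsDimCalls() const { return inputs_dim_calls_; }

 protected:
  // Assumed hook: maps a parameter name to its argument (variable) names.
  virtual std::vector<std::string> Inputs(const std::string& param_name) const = 0;
  virtual DDim GetDim(const std::string& arg_name) const = 0;

  std::vector<DDim> GetDims(const std::vector<std::string>& arg_names) const {
    std::vector<DDim> dims;
    dims.reserve(arg_names.size());
    for (const auto& name : arg_names) dims.push_back(GetDim(name));
    return dims;
  }

 private:
  mutable int inputs_dim_calls_ = 0;
};

class CompileTimeInferShapeContext : public InferShapeContext {
 public:
  CompileTimeInferShapeContext(std::map<std::string, std::vector<std::string>> inputs,
                               std::map<std::string, VarDesc> vars)
      : inputs_(std::move(inputs)), vars_(std::move(vars)) {}

  // Current detour: a single input is fetched by building a vector through
  // GetInputsDim and then taking its only element.
  DDim GetInputDim(const std::string& param_name) const {
    std::vector<DDim> dims = GetInputsDim(param_name);
    if (dims.size() != 1) throw std::runtime_error("input is not unique");
    return dims[0];
  }

 protected:
  std::vector<std::string> Inputs(const std::string& param_name) const override {
    return inputs_.at(param_name);
  }
  // Compile time: the dim is read from the variable's description.
  DDim GetDim(const std::string& arg_name) const override {
    return vars_.at(arg_name).Shape();
  }

 private:
  std::map<std::string, std::vector<std::string>> inputs_;
  std::map<std::string, VarDesc> vars_;
};
```

The counter makes the detour visible: fetching one non-duplicatable input with GetInputDim still goes through GetInputsDim once and allocates a one-element vector along the way.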
There are at least two issues here:
- For a non-duplicatable input, CompileTimeInferShapeContext::GetInputDim is invoked, while for a duplicatable input, InferShapeContext::GetInputsDim is invoked directly. This means the entry points for the same functionality are spread across two classes.
- vector<DDim> InferShapeContext::GetInputsDim(string param_name); is intended for duplicatable inputs. However, even when the input is non-duplicatable, it is still invoked indirectly, so a vector holding a single DDim is built only to extract its one element. That is inefficient.
Solution:
Move GetInputDim from CompileTimeInferShapeContext (and RuntimeInferShapeContext) to InferShapeContext, and have GetInputDim invoke GetDim directly instead of detouring via GetInputsDim.
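A minimal sketch of the proposed layout, using the same hypothetical stand-ins (DDim as std::vector<int64_t>, a shape-only VarDesc, and an assumed Inputs() hook that returns a parameter's argument names). This is an illustration of the idea, not the final implementation: GetInputDim now lives in InferShapeContext, checks that the parameter has exactly one argument (the error handling is an assumption), and calls GetDim directly. Subclasses only supply the two hooks.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins, as in the problem description.
using DDim = std::vector<int64_t>;

struct VarDesc {
  DDim shape;
  DDim Shape() const { return shape; }
};

class InferShapeContext {
 public:
  virtual ~InferShapeContext() = default;

  // Single-input entry point in the base class: resolve the one argument
  // name and call GetDim directly, with no intermediate vector<DDim>.
  DDim GetInputDim(const std::string& param_name) const {
    const std::vector<std::string> args = Inputs(param_name);
    if (args.size() != 1) {
      throw std::runtime_error("input '" + param_name + "' is not unique");
    }
    return GetDim(args[0]);
  }

  // Duplicatable inputs keep the vector-returning path.
  std::vector<DDim> GetInputsDim(const std::string& param_name) const {
    return GetDims(Inputs(param_name));
  }

 protected:
  // Hooks each context (compile time or runtime) must provide.
  virtual std::vector<std::string> Inputs(const std::string& param_name) const = 0;
  virtual DDim GetDim(const std::string& arg_name) const = 0;

 private:
  std::vector<DDim> GetDims(const std::vector<std::string>& arg_names) const {
    std::vector<DDim> dims;
    dims.reserve(arg_names.size());
    for (const auto& name : arg_names) dims.push_back(GetDim(name));
    return dims;
  }
};

// Compile-time backend: no GetInputDim of its own, only the hooks.
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
  CompileTimeInferShapeContext(std::map<std::string, std::vector<std::string>> inputs,
                               std::map<std::string, VarDesc> vars)
      : inputs_(std::move(inputs)), vars_(std::move(vars)) {}

 protected:
  std::vector<std::string> Inputs(const std::string& param_name) const override {
    return inputs_.at(param_name);
  }
  DDim GetDim(const std::string& arg_name) const override {
    return vars_.at(arg_name).Shape();
  }

 private:
  std::map<std::string, std::vector<std::string>> inputs_;
  std::map<std::string, VarDesc> vars_;
};
```

Under this layout, a RuntimeInferShapeContext would override the same Inputs() and GetDim() hooks against runtime variables instead of VarDesc, so neither subclass needs its own copy of GetInputDim.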