@@ -17,6 +17,8 @@ limitations under the License. */
1717#include < string>
1818#include < utility>
1919
20+ #include " paddle/pten/common/scalar.h"
21+ #include " paddle/pten/common/scalar_array.h"
2022#include " paddle/pten/core/enforce.h"
2123#include " paddle/pten/core/macros.h"
2224#include " paddle/pten/core/meta_tensor.h"
@@ -46,6 +48,7 @@ class InferMetaContext {
4648
4749 const MetaConfig& GetMetaConfig () const ;
4850 const MetaTensor& InputAt (size_t idx) const ;
51+ std::vector<MetaTensor> InputsBetween (size_t start, size_t end) const ;
4952 MetaTensor* MutableOutputAt (size_t idx);
5053
5154 template <typename AttrType>
@@ -85,7 +88,8 @@ class InferMetaContext {
8588 " InferMeta's Attributes should appear before Outputs." ); \
8689 attr_type arg = ctx->AttrAt <attr_type>(attr_idx); \
8790 InferMetaFnCallHelper< \
88- Tail...>::template Call<in_idx, attr_idx + 1 , out_idx>(pargs..., \
91+ Tail...>::template Call<in_idx, attr_idx + 1 , out_idx>(ctx, \
92+ pargs..., \
8993 arg); \
9094 } \
9195 }
@@ -124,6 +128,35 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
124128 }
125129 };
126130
131+ template <typename ... Tail>
132+ struct InferMetaFnCallHelper <const std::vector<MetaTensor>&, Tail...> {
133+ template <int in_idx, int attr_idx, int out_idx, typename ... PreviousArgs>
134+ static void Call (InferMetaContext* ctx, PreviousArgs&... pargs) {
135+ static_assert (attr_idx == 0 ,
136+ " InferMeta's Input should appear before Attributes." );
137+ static_assert (out_idx == 0 ,
138+ " InferMeta's Input should appear before Outputs." );
139+ const std::pair<int , int > range = ctx->InputRangeAt (in_idx);
140+ std::vector<MetaTensor> arg =
141+ ctx->InputsBetween (range.first , range.second );
142+ InferMetaFnCallHelper<
143+ Tail...>::template Call<in_idx + 1 , attr_idx, out_idx>(ctx,
144+ pargs...,
145+ arg);
146+ }
147+ };
148+
149+ PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE (bool );
150+ PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE (int );
151+ PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE (int64_t );
152+ PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE (const std::vector<int >&);
153+ PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE (
154+ const std::vector<int64_t >&);
155+ PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE (DataType);
156+ PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE (DataLayout);
157+ PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE (const Scalar&);
158+ PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE (const ScalarArray&);
159+
127160 // TODO(chenweihang): support vector<MetaTensor> input later
128161
129162 template <typename ... Tail>
@@ -227,7 +260,6 @@ struct InferMetaFnRegistrar {
227260 " PT_REGISTER_INFER_META_FN must be called in global namespace." ); \
228261 static const ::pten::InferMetaFnRegistrar \
229262 __registrar_arg_map_fn_for_##kernel_name_prefix( \
230- #kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)); \
231- int TouchInferMetaFnSymbol_##op_type() { return 0 ; }
263+ #kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn))
232264
233265} // namespace pten
0 commit comments