Skip to content

Commit 7d6096f

Browse files
authored
【Pten】Auto-Generate InterMeta register (#39436)
* fix code conflict * generate inter_meta register * clear cache * just try * add sign c++ api * polish some code
1 parent 1252f4b commit 7d6096f

File tree

10 files changed

+97
-38
lines changed

10 files changed

+97
-38
lines changed

paddle/pten/core/infermeta_utils.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,19 @@ const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; }
6666
const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
6767
return *inputs_.at(idx);
6868
}
69+
70+
std::vector<MetaTensor> InferMetaContext::InputsBetween(size_t start,
71+
size_t end) const {
72+
std::vector<MetaTensor> result;
73+
result.reserve(end - start);
74+
75+
for (size_t i = start; i < end; ++i) {
76+
result.emplace_back(*inputs_.at(i));
77+
}
78+
79+
return result;
80+
}
81+
6982
MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
7083
return outputs_.at(idx).get();
7184
}

paddle/pten/core/infermeta_utils.h

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ limitations under the License. */
1717
#include <string>
1818
#include <utility>
1919

20+
#include "paddle/pten/common/scalar.h"
21+
#include "paddle/pten/common/scalar_array.h"
2022
#include "paddle/pten/core/enforce.h"
2123
#include "paddle/pten/core/macros.h"
2224
#include "paddle/pten/core/meta_tensor.h"
@@ -46,6 +48,7 @@ class InferMetaContext {
4648

4749
const MetaConfig& GetMetaConfig() const;
4850
const MetaTensor& InputAt(size_t idx) const;
51+
std::vector<MetaTensor> InputsBetween(size_t start, size_t end) const;
4952
MetaTensor* MutableOutputAt(size_t idx);
5053

5154
template <typename AttrType>
@@ -85,7 +88,8 @@ class InferMetaContext {
8588
"InferMeta's Attributes should appear before Outputs."); \
8689
attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \
8790
InferMetaFnCallHelper< \
88-
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(pargs..., \
91+
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
92+
pargs..., \
8993
arg); \
9094
} \
9195
}
@@ -124,6 +128,35 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
124128
}
125129
};
126130

131+
template <typename... Tail>
132+
struct InferMetaFnCallHelper<const std::vector<MetaTensor>&, Tail...> {
133+
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
134+
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
135+
static_assert(attr_idx == 0,
136+
"InferMeta's Input should appear before Attributes.");
137+
static_assert(out_idx == 0,
138+
"InferMeta's Input should appear before Outputs.");
139+
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
140+
std::vector<MetaTensor> arg =
141+
ctx->InputsBetween(range.first, range.second);
142+
InferMetaFnCallHelper<
143+
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
144+
pargs...,
145+
arg);
146+
}
147+
};
148+
149+
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
150+
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
151+
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
152+
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
153+
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
154+
const std::vector<int64_t>&);
155+
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
156+
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
157+
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
158+
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
159+
127160
// TODO(chenweihang): support vector<MetaTensor> input later
128161

129162
template <typename... Tail>
@@ -227,7 +260,6 @@ struct InferMetaFnRegistrar {
227260
"PT_REGISTER_INFER_META_FN must be called in global namespace."); \
228261
static const ::pten::InferMetaFnRegistrar \
229262
__registrar_arg_map_fn_for_##kernel_name_prefix( \
230-
#kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)); \
231-
int TouchInferMetaFnSymbol_##op_type() { return 0; }
263+
#kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn))
232264

233265
} // namespace pten

paddle/pten/infermeta/nullary.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ limitations under the License. */
1616

1717
namespace pten {
1818

19-
void CreateInferMeta(const std::vector<int64_t>& shape,
20-
DataType dtype,
21-
DataLayout layout,
22-
MetaTensor* out) {
19+
void CreateInferMetaBase(const std::vector<int64_t>& shape,
20+
DataType dtype,
21+
DataLayout layout,
22+
MetaTensor* out) {
2323
auto out_dims = pten::framework::make_ddim(shape);
2424
out->set_dims(out_dims);
2525
out->set_dtype(dtype);
@@ -30,7 +30,7 @@ void CreateInferMeta(const ScalarArray& shape,
3030
DataType dtype,
3131
DataLayout layout,
3232
MetaTensor* out) {
33-
CreateInferMeta(shape.GetData(), dtype, layout, out);
33+
CreateInferMetaBase(shape.GetData(), dtype, layout, out);
3434
}
3535

3636
} // namespace pten

paddle/pten/infermeta/nullary.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ namespace pten {
2828
// Because functions in this file not only can infer shape, but also need
2929
// infer lod or other useful data.
3030

31-
void CreateInferMeta(const std::vector<int64_t>& shape,
32-
DataType dtype,
33-
DataLayout layout,
34-
MetaTensor* out);
31+
void CreateInferMetaBase(const std::vector<int64_t>& shape,
32+
DataType dtype,
33+
DataLayout layout,
34+
MetaTensor* out);
3535

3636
void CreateInferMeta(const ScalarArray& shape,
3737
DataType dtype,

paddle/pten/infermeta/unary.cc

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -242,14 +242,14 @@ void SumInferMeta(const MetaTensor& x,
242242
DataType dtype,
243243
bool keep_dim,
244244
MetaTensor* out) {
245-
ReduceInferMeta(x, axis, keep_dim, dtype, std::move(out));
245+
ReduceInferMetaBase(x, axis, keep_dim, dtype, out);
246246
}
247247

248-
void ReduceInferMeta(const MetaTensor& x,
249-
const std::vector<int64_t>& axis,
250-
bool keep_dim,
251-
DataType dtype,
252-
MetaTensor* out) {
248+
void ReduceInferMetaBase(const MetaTensor& x,
249+
const std::vector<int64_t>& axis,
250+
bool keep_dim,
251+
DataType dtype,
252+
MetaTensor* out) {
253253
bool reduce_all = true;
254254
std::set<int64_t> dims_set(axis.begin(), axis.end());
255255
for (int64_t i = 0; i < x.dims().size(); ++i) {
@@ -304,7 +304,7 @@ void ReduceInferMeta(const MetaTensor& x,
304304
const std::vector<int64_t>& axis,
305305
bool keep_dim,
306306
MetaTensor* out) {
307-
ReduceInferMeta(x, axis, keep_dim, DataType::UNDEFINED, out);
307+
ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out);
308308
}
309309

310310
void TransferLayoutInferMeta(const MetaTensor& x,
@@ -316,5 +316,3 @@ void TransferLayoutInferMeta(const MetaTensor& x,
316316
}
317317

318318
} // namespace pten
319-
320-
PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);

paddle/pten/infermeta/unary.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ void ReshapeInferMeta(const MetaTensor& x,
5353
const ScalarArray& shape,
5454
MetaTensor* out);
5555

56-
void ReduceInferMeta(const MetaTensor& x,
57-
const std::vector<int64_t>& axis,
58-
bool keep_dim,
59-
DataType dtype,
60-
MetaTensor* out);
56+
void ReduceInferMetaBase(const MetaTensor& x,
57+
const std::vector<int64_t>& axis,
58+
bool keep_dim,
59+
DataType dtype,
60+
MetaTensor* out);
6161

6262
void ReduceInferMeta(const MetaTensor& x,
6363
const std::vector<int64_t>& axis,

paddle/pten/kernels/math_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
156156
bool keep_dim) {
157157
auto dense_out = pten::Empty<T, Context>(dev_ctx);
158158
MetaTensor meta_out(&dense_out);
159-
ReduceInferMeta(x, axis, keep_dim, x.dtype(), &meta_out);
159+
ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out);
160160
MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
161161
return dense_out;
162162
}

python/paddle/utils/code_gen/api.yaml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,14 @@
161161
kernel :
162162
func : scale
163163

164+
- api : sign
165+
args : (const Tensor& x)
166+
output : Tensor
167+
infer_meta :
168+
func : UnchangedInferMeta
169+
kernel :
170+
func : sign
171+
164172
- api : subtract
165173
args : (const Tensor& x, const Tensor& y)
166174
output : Tensor
@@ -173,10 +181,10 @@
173181
- api : sum
174182
args : (const Tensor& x, const std::vector<int64_t>& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
175183
output : Tensor
176-
infer_meta :
184+
infer_meta :
177185
func : SumInferMeta
178186
param: [x, axis, dtype, keep_dim]
179-
kernel :
187+
kernel :
180188
func : sum
181189
param : [x, axis, dtype, keep_dim]
182190
data_type : x

python/paddle/utils/code_gen/api_base.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -379,14 +379,7 @@ def get_kernel_args(self):
379379
input_infos = self.inputs['input_info']
380380
kernel_args_type_list = ['const platform::DeviceContext&']
381381

382-
input_tensor_code = ""
383-
for input_name in input_names:
384-
# set input code
385-
input_tensor_code = input_tensor_code + f"""
386-
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
387-
388382
attr_names = self.attrs['names']
389-
390383
kernel_param = self.kernel['param']
391384
if kernel_param is None:
392385
kernel_param = input_names + attr_names
@@ -401,11 +394,11 @@ def get_kernel_args(self):
401394
elif input_name in self.data_transform['support_trans_dtype']:
402395
trans_flag = "{false, true}"
403396
input_tensor_code = input_tensor_code + f"""
404-
auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
397+
auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
405398

406399
else:
407400
input_tensor_code = input_tensor_code + f"""
408-
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
401+
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
409402

410403
kernel_args = "*dev_ctx, "
411404
for param in kernel_param:

python/paddle/utils/code_gen/api_gen.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ def gene_output(self, output_type_list):
6060

6161
return kernel_output, output_names, output_create
6262

63+
def gene_infer_meta_register(self):
64+
if self.is_base_api:
65+
return f"""
66+
PT_REGISTER_INFER_META_FN({self.kernel['func']}, pten::{self.infer_meta['func']});"""
67+
68+
else:
69+
return ''
70+
6371

6472
def header_include():
6573
return """
@@ -83,6 +91,7 @@ def source_include(header_file_path):
8391
#include "paddle/pten/api/lib/data_transform.h"
8492
#include "paddle/pten/api/lib/kernel_dispatch.h"
8593
#include "paddle/pten/api/lib/utils/storage.h"
94+
#include "paddle/pten/core/infermeta_utils.h"
8695
#include "paddle/pten/core/kernel_registry.h"
8796
#include "paddle/pten/infermeta/binary.h"
8897
#include "paddle/pten/infermeta/multiary.h"
@@ -127,15 +136,21 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
127136
source_file.write(source_include(include_header_file))
128137
source_file.write(namespace[0])
129138

139+
infer_meta_register_code = ''
140+
130141
for api in apis:
131142
api_code = ForwardAPI(api)
132143
print(api_code.gene_api_declaration())
133144
header_file.write(api_code.gene_api_declaration())
134145
source_file.write(api_code.gene_api_code())
146+
infer_meta_register_code = infer_meta_register_code + api_code.gene_infer_meta_register(
147+
)
135148

136149
header_file.write(namespace[1])
137150
source_file.write(namespace[1])
151+
138152
source_file.write(api_register())
153+
source_file.write(infer_meta_register_code)
139154

140155
header_file.close()
141156
source_file.close()

0 commit comments

Comments
 (0)