Skip to content

Commit 6af2729

Browse files
authored
【phi】migrate gather_tree,reduce_prod to phi (#39844)
* move to phi * migrate gather_tree_op into phi * move reduce_prod tp phi * optimize code
1 parent 1db188f commit 6af2729

File tree

13 files changed

+285
-178
lines changed

13 files changed

+285
-178
lines changed

paddle/fluid/operators/gather_tree_op.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/gather_tree_op.h"
15+
#include "paddle/fluid/framework/op_registry.h"
1616

1717
namespace paddle {
1818
namespace operators {
@@ -73,5 +73,3 @@ selected ids.
7373

7474
namespace ops = paddle::operators;
7575
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker);
76-
REGISTER_OP_CPU_KERNEL(gather_tree, ops::GatherTreeOpKernel<int32_t>,
77-
ops::GatherTreeOpKernel<int64_t>);

paddle/fluid/operators/gather_tree_op.cu

Lines changed: 0 additions & 84 deletions
This file was deleted.

paddle/fluid/operators/gather_tree_op.h

Lines changed: 0 additions & 66 deletions
This file was deleted.

paddle/fluid/operators/reduce_ops/reduce_prod_op.cc

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,7 @@ class CPUDeviceContext;
2727
} // namespace paddle
2828

2929
REGISTER_REDUCE_OP(reduce_prod);
30-
REGISTER_OP_CPU_KERNEL(reduce_prod,
31-
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
32-
float, ops::ProdFunctor>,
33-
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
34-
double, ops::ProdFunctor>,
35-
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
36-
int, ops::ProdFunctor>,
37-
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
38-
int64_t, ops::ProdFunctor>);
30+
3931
REGISTER_OP_CPU_KERNEL(reduce_prod_grad,
4032
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
4133
float, ops::ProdGradFunctor>,

paddle/fluid/operators/reduce_ops/reduce_prod_op.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@
1919
namespace paddle {
2020
namespace operators {
2121

22-
struct ProdFunctor {
23-
template <typename DeviceContext, typename X, typename Y, typename Dim>
24-
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
25-
y->device(place) = x->prod(dim);
26-
}
27-
};
28-
2922
struct ProdGradFunctor {
3023
template <typename DeviceContext, typename X, typename Y, typename DX,
3124
typename DY, typename Dim>
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/gather_tree_kernel.h"
16+
#include "paddle/phi/core/kernel_registry.h"
17+
18+
namespace phi {
19+
20+
template <typename T, typename Context>
21+
void GatherTreeKernel(const Context &dev_ctx,
22+
const DenseTensor &ids,
23+
const DenseTensor &parents,
24+
DenseTensor *out) {
25+
const auto *ids_data = ids.data<T>();
26+
const auto *parents_data = parents.data<T>();
27+
28+
T *out_data = dev_ctx.template Alloc<T>(out);
29+
30+
auto &ids_dims = ids.dims();
31+
auto max_length = ids_dims[0];
32+
auto batch_size = ids_dims[1];
33+
auto beam_size = ids_dims[2];
34+
35+
PADDLE_ENFORCE_NOT_NULL(ids_data,
36+
phi::errors::InvalidArgument(
37+
"Input(Ids) of gather_tree should not be null."));
38+
39+
PADDLE_ENFORCE_NOT_NULL(
40+
parents_data,
41+
phi::errors::InvalidArgument(
42+
"Input(Parents) of gather_tree should not be null."));
43+
44+
for (int batch = 0; batch < batch_size; batch++) {
45+
for (int beam = 0; beam < beam_size; beam++) {
46+
auto idx =
47+
(max_length - 1) * batch_size * beam_size + batch * beam_size + beam;
48+
out_data[idx] = ids_data[idx];
49+
auto parent = parents_data[idx];
50+
for (int step = max_length - 2; step >= 0; step--) {
51+
idx = step * batch_size * beam_size + batch * beam_size;
52+
out_data[idx + beam] = ids_data[idx + parent];
53+
parent = parents_data[idx + parent];
54+
}
55+
}
56+
}
57+
}
58+
59+
} // namespace phi
60+
61+
PD_REGISTER_KERNEL(
62+
gather_tree, CPU, ALL_LAYOUT, phi::GatherTreeKernel, int, int64_t) {}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/reduce_prod_kernel.h"
16+
#include "paddle/phi/backends/cpu/cpu_context.h"
17+
#include "paddle/phi/core/kernel_registry.h"
18+
#include "paddle/phi/kernels/cpu/reduce.h"
19+
#include "paddle/phi/kernels/funcs/reduce_functor.h"
20+
21+
namespace phi {
22+
23+
template <typename T, typename Context>
24+
void ReduceProdKernel(const Context& dev_ctx,
25+
const DenseTensor& x,
26+
const std::vector<int64_t>& dims,
27+
bool keep_dim,
28+
bool reduce_all,
29+
DenseTensor* out) {
30+
auto out_dtype = x.dtype();
31+
phi::Reduce<CPUContext, T, phi::funcs::ProdFunctor>(
32+
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
33+
}
34+
35+
} // namespace phi
36+
37+
PD_REGISTER_KERNEL(reduce_prod,
38+
CPU,
39+
ALL_LAYOUT,
40+
phi::ReduceProdKernel,
41+
float,
42+
double,
43+
int,
44+
int64_t) {}

paddle/phi/kernels/funcs/reduce_functor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,13 @@ struct MeanFunctor {
3333
}
3434
};
3535

36+
//////// Prod Functor ///////
37+
struct ProdFunctor {
38+
template <typename DeviceContext, typename X, typename Y, typename Dim>
39+
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
40+
y->device(place) = x->prod(dim);
41+
}
42+
};
43+
3644
} // namespace funcs
3745
} // namespace phi
Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -12,12 +12,15 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
16-
#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
15+
#pragma once
1716

18-
REGISTER_OP_CUDA_KERNEL(
19-
reduce_prod,
20-
ops::ReduceCudaKernel<float, kps::MulFunctor, kps::IdentityFunctor>,
21-
ops::ReduceCudaKernel<int, kps::MulFunctor, kps::IdentityFunctor>,
22-
ops::ReduceCudaKernel<double, kps::MulFunctor, kps::IdentityFunctor>,
23-
ops::ReduceCudaKernel<int64_t, kps::MulFunctor, kps::IdentityFunctor>);
17+
#include "paddle/phi/core/dense_tensor.h"
18+
namespace phi {
19+
20+
template <typename T, typename Context>
21+
void GatherTreeKernel(const Context &dev_ctx,
22+
const DenseTensor &ids,
23+
const DenseTensor &parents,
24+
DenseTensor *out);
25+
26+
} // namespace phi

0 commit comments

Comments
 (0)