Skip to content

Commit ad037ca

Browse files
authored
[PHI] Migrate shard_index op (#40254)
1 parent 8cabb9f commit ad037ca

File tree

9 files changed

+268
-207
lines changed

9 files changed

+268
-207
lines changed

paddle/fluid/operators/shard_index_op.cc

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,17 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/operators/shard_index_op.h"
15+
#include "paddle/fluid/framework/infershape_utils.h"
16+
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/phi/core/infermeta_utils.h"
18+
#include "paddle/phi/infermeta/unary.h"
1619

1720
namespace paddle {
1821
namespace operators {
1922

2023
class ShardIndexOp : public framework::OperatorWithKernel {
2124
public:
2225
using framework::OperatorWithKernel::OperatorWithKernel;
23-
void InferShape(framework::InferShapeContext* ctx) const override {
24-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShardIndex");
25-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShardIndex");
26-
27-
auto x_dims = ctx->GetInputDim("X");
28-
PADDLE_ENFORCE_GE(x_dims.size(), 2,
29-
platform::errors::InvalidArgument(
30-
"Rank of Input(X) should be at least 2, "
31-
"but the value given is %d.",
32-
x_dims.size()));
33-
if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) {
34-
PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], 1U,
35-
platform::errors::InvalidArgument(
36-
"The last dimension of Input(X) should be 1, "
37-
"but the value given is %d.",
38-
x_dims[x_dims.size() - 1]));
39-
}
40-
41-
ctx->SetOutputDim("Out", x_dims);
42-
ctx->ShareLoD("X", /* --> */ "Out");
43-
}
4426

4527
protected:
4628
framework::OpKernelType GetExpectedKernelType(
@@ -114,7 +96,10 @@ the original index should be recalculated (i.e. sharded) before.
11496
} // namespace paddle
11597

11698
namespace ops = paddle::operators;
117-
REGISTER_OP_WITHOUT_GRADIENT(shard_index, ops::ShardIndexOp,
118-
ops::ShardIndexOpMaker);
119-
REGISTER_OP_CPU_KERNEL(shard_index, ops::ShardIndexCPUKernel<int>,
120-
ops::ShardIndexCPUKernel<int64_t>);
99+
DECLARE_INFER_SHAPE_FUNCTOR(shard_index, ShardIndexInferShapeFunctor,
100+
PD_INFER_META(phi::ShardIndexInferMeta));
101+
REGISTER_OPERATOR(
102+
shard_index, ops::ShardIndexOp, ops::ShardIndexOpMaker,
103+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
104+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
105+
ShardIndexInferShapeFunctor);

paddle/fluid/operators/shard_index_op.cu

Lines changed: 0 additions & 96 deletions
This file was deleted.

paddle/fluid/operators/shard_index_op.h

Lines changed: 0 additions & 84 deletions
This file was deleted.

paddle/fluid/operators/shard_index_op_npu.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/operators/shard_index_op.h"
15+
#include "paddle/fluid/framework/op_registry.h"
1616
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
1717

1818
namespace paddle {

paddle/phi/infermeta/unary.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,34 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
13121312
out->set_dtype(DataType::INT64);
13131313
}
13141314

1315+
void ShardIndexInferMeta(const MetaTensor& in,
1316+
int index_num,
1317+
int nshards,
1318+
int shard_id,
1319+
int ignore_value,
1320+
MetaTensor* out,
1321+
MetaConfig config) {
1322+
auto x_dims = in.dims();
1323+
PADDLE_ENFORCE_GE(
1324+
x_dims.size(),
1325+
2,
1326+
phi::errors::InvalidArgument("Rank of Input(X) should be at least 2, "
1327+
"but the value given is %d.",
1328+
x_dims.size()));
1329+
if (config.is_runtime || x_dims[x_dims.size() - 1] > 0) {
1330+
PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1],
1331+
1U,
1332+
phi::errors::InvalidArgument(
1333+
"The last dimension of Input(X) should be 1, "
1334+
"but the value given is %d.",
1335+
x_dims[x_dims.size() - 1]));
1336+
}
1337+
1338+
out->set_dims(x_dims);
1339+
out->share_lod(in);
1340+
out->set_dtype(in.dtype());
1341+
}
1342+
13151343
} // namespace phi
13161344

13171345
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);

paddle/phi/infermeta/unary.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,12 @@ void EighInferMeta(const MetaTensor& x,
190190

191191
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out);
192192

193+
void ShardIndexInferMeta(const MetaTensor& in,
194+
int index_num,
195+
int nshards,
196+
int shard_id,
197+
int ignore_value,
198+
MetaTensor* out,
199+
MetaConfig config = MetaConfig());
200+
193201
} // namespace phi
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/shard_index_kernel.h"
16+
17+
#include "paddle/phi/backends/cpu/cpu_context.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
20+
namespace phi {
21+
22+
template <typename T, typename Context>
23+
void ShardIndexKernel(const Context& dev_ctx,
24+
const DenseTensor& in,
25+
int index_num,
26+
int nshards,
27+
int shard_id,
28+
int ignore_value,
29+
DenseTensor* out) {
30+
PADDLE_ENFORCE_GT(
31+
index_num,
32+
0,
33+
errors::InvalidArgument(
34+
"The value 'index_num' for Op(shard_index) must be greater than 0, "
35+
"but the value given is %d.",
36+
index_num));
37+
PADDLE_ENFORCE_GT(
38+
nshards,
39+
0,
40+
errors::InvalidArgument("The value 'nshard' for Op(shard_index) must be "
41+
"greater than 0, but the value given is %d.",
42+
nshards));
43+
PADDLE_ENFORCE_GE(
44+
shard_id,
45+
0,
46+
errors::InvalidArgument(
47+
"The value 'shard_id' for Op(shard_index) must be greater or "
48+
"equal to 0, but the value given is %d.",
49+
shard_id));
50+
PADDLE_ENFORCE_LT(
51+
shard_id,
52+
nshards,
53+
errors::InvalidArgument(
54+
"The value 'shard_id' for Op(shard_index) must be less than "
55+
"nshards (%d), but the value given is %d.",
56+
nshards,
57+
shard_id));
58+
59+
int shard_size = (index_num + nshards - 1) / nshards;
60+
61+
out->Resize(in.dims());
62+
out->set_lod(in.lod());
63+
auto* in_data = in.data<T>();
64+
auto* out_data = dev_ctx.template Alloc<T>(out);
65+
int64_t numel = in.numel();
66+
for (int64_t i = 0; i < numel; ++i) {
67+
PADDLE_ENFORCE_GE(in_data[i],
68+
0,
69+
errors::InvalidArgument(
70+
"The input_index for Op(shard_index) must be "
71+
"greater or equal to 0, but the value given is %d.",
72+
in_data[i]));
73+
PADDLE_ENFORCE_LT(in_data[i],
74+
index_num,
75+
errors::InvalidArgument(
76+
"The input_index for Op(shard_index) must be less "
77+
"than index_num (%d), but the value given is %d.",
78+
index_num,
79+
in_data[i]));
80+
if (in_data[i] / shard_size == shard_id) {
81+
out_data[i] = in_data[i] % shard_size;
82+
} else {
83+
out_data[i] = ignore_value;
84+
}
85+
}
86+
}
87+
88+
} // namespace phi
89+
90+
PD_REGISTER_KERNEL(
91+
shard_index, CPU, ALL_LAYOUT, phi::ShardIndexKernel, int, int64_t) {}

0 commit comments

Comments
 (0)