Skip to content

Commit a23ce42

Browse files
committed
[Auto Parallel] Add spmd rule for topk and topk_grad ops
1 parent 441816a commit a23ce42

File tree

8 files changed

+232
-0
lines changed

8 files changed

+232
-0
lines changed

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,11 @@ PD_REGISTER_SPMD_RULE(
705705
PD_INFER_SPMD(phi::distributed::ArgMaxInferSpmdBase),
706706
PD_INFER_SPMD(phi::distributed::ArgMaxInferSpmdReverseBase));
707707

708+
// argmax
709+
PD_REGISTER_SPMD_RULE(topk,
710+
PD_INFER_SPMD(phi::distributed::TopkInferSpmd),
711+
PD_INFER_SPMD(phi::distributed::TopkGradInferSpmd));
712+
708713
// unbind
709714
PD_REGISTER_SPMD_RULE(unbind,
710715
PD_INFER_SPMD(phi::distributed::UnbindInferSpmd),

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ limitations under the License. */
6666
#include "paddle/phi/infermeta/spmd_rules/squeeze.h"
6767
#include "paddle/phi/infermeta/spmd_rules/stack.h"
6868
#include "paddle/phi/infermeta/spmd_rules/tile.h"
69+
#include "paddle/phi/infermeta/spmd_rules/topk.h"
6970
#include "paddle/phi/infermeta/spmd_rules/transpose.h"
7071
#include "paddle/phi/infermeta/spmd_rules/triu.h"
7172
#include "paddle/phi/infermeta/spmd_rules/unbind.h"
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/infermeta/spmd_rules/topk.h"
16+
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
17+
#include "paddle/phi/infermeta/spmd_rules/utils.h"
18+
19+
namespace phi {
20+
namespace distributed {
21+
22+
SpmdInfo TopkInferSpmd(const DistMetaTensor& x,
23+
const Scalar& k,
24+
int axis,
25+
bool largest,
26+
bool sorted) {
27+
// Verify input args
28+
EXTRACT_SHAPE_AND_DIST_ATTR(x);
29+
axis = axis < 0 ? x_ndim + axis : axis;
30+
31+
// Infer output dims mapping from merged input dims mapping
32+
std::vector<int64_t> x_dims_mapping_dst(x_dims_mapping_src);
33+
std::vector<int64_t> out_dims_mapping;
34+
std::vector<int64_t> indices_dims_mapping;
35+
x_dims_mapping_dst[axis] = -1;
36+
out_dims_mapping.assign(x_dims_mapping_dst.begin(), x_dims_mapping_dst.end());
37+
indices_dims_mapping = out_dims_mapping;
38+
39+
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
40+
x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
41+
TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
42+
out_dist_attr.set_dims_mapping(out_dims_mapping);
43+
TensorDistAttr indices_dist_attr =
44+
CopyTensorDistAttrForOutput(x_dist_attr_src);
45+
indices_dist_attr.set_dims_mapping(indices_dims_mapping);
46+
47+
return {{x_dist_attr_dst}, {out_dist_attr, indices_dist_attr}};
48+
}
49+
50+
// Infers the SPMD (sharding) info for topk_grad.
// Inputs: dist attrs of x, indices, out_grad; attrs k/axis/largest/sorted.
// Output: {dst dist attrs for x, indices, out_grad}, {dist attr for x_grad}.
//
// Fixes vs. the original version:
//   * `out_grad` is taken as `const DistMetaTensor&` (the header declares it
//     that way, and EXTRACT_SHAPE_AND_DIST_ATTR requires a DistMetaTensor;
//     `Tensor out_grad` would not define the declared symbol).
//   * `axis` is normalized for negative values and the top-k axis is forced
//     to be replicated (-1), consistent with TopkInferSpmd.
SpmdInfo TopkGradInferSpmd(const DistMetaTensor& x,
                           const DistMetaTensor& indices,
                           const DistMetaTensor& out_grad,
                           Scalar k,
                           int axis,
                           bool largest,
                           bool sorted) {
  // Brings <name>_ndim, <name>_dims_mapping_src, <name>_dist_attr_src into
  // scope for each tensor.
  EXTRACT_SHAPE_AND_DIST_ATTR(x);
  EXTRACT_SHAPE_AND_DIST_ATTR(indices);
  EXTRACT_SHAPE_AND_DIST_ATTR(out_grad);
  // topk preserves rank, so out_grad has the same ndim as x.
  axis = axis < 0 ? out_grad_ndim + axis : axis;

  // Follow out_grad's sharding for every tensor, but keep the top-k axis
  // replicated so the rule stays consistent with the forward rule (which
  // always emits -1 on that axis).
  std::vector<int64_t> dims_mapping_dst(out_grad_dims_mapping_src);
  dims_mapping_dst[axis] = -1;

  TensorDistAttr out_grad_dist_attr_dst =
      CopyTensorDistAttrForOutput(out_grad_dist_attr_src);
  out_grad_dist_attr_dst.set_dims_mapping(dims_mapping_dst);

  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(dims_mapping_dst);

  TensorDistAttr indices_dist_attr_dst =
      CopyTensorDistAttrForOutput(indices_dist_attr_src);
  indices_dist_attr_dst.set_dims_mapping(dims_mapping_dst);

  // x_grad mirrors the (axis-replicated) out_grad sharding.
  TensorDistAttr x_grad_dist_attr_dst =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_grad_dist_attr_dst.set_dims_mapping(dims_mapping_dst);

  return {{x_dist_attr_dst, indices_dist_attr_dst, out_grad_dist_attr_dst},
          {x_grad_dist_attr_dst}};
}

}  // namespace distributed
}  // namespace phi
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <vector>
18+
#include "paddle/phi/common/scalar.h"
19+
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
20+
#include "paddle/phi/core/distributed/type_defs.h"
21+
22+
namespace phi {
23+
namespace distributed {
24+
25+
SpmdInfo TopkInferSpmd(const DistMetaTensor& x,
26+
const Scalar& k,
27+
int axis,
28+
bool largest,
29+
bool sorted);
30+
31+
SpmdInfo TopkGradInferSpmd(const DistMetaTensor& x,
32+
const DistMetaTensor& indices,
33+
const DistMetaTensor& out_grad,
34+
Scalar k,
35+
int axis,
36+
bool largest,
37+
bool sorted);
38+
39+
} // namespace distributed
40+
} // namespace phi

paddle/phi/ops/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3516,6 +3516,7 @@
35163516
infer_meta :
35173517
func : UnchangedInferMeta
35183518
param : [x]
3519+
spmd_rule: TopkGradInferSpmd
35193520
kernel :
35203521
func : topk_grad
35213522
data_type : out_grad

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5128,6 +5128,7 @@
51285128
output : Tensor(out), Tensor(indices)
51295129
infer_meta :
51305130
func : TopKInferMeta
5131+
spmd_rule: TopkInferSpmd
51315132
kernel :
51325133
func : topk
51335134
data_type : x

test/auto_parallel/spmd_rules/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ if(WITH_DISTRIBUTE)
3535
py_test_modules(test_gather_rule MODULES test_gather_rule)
3636
py_test_modules(test_cumsum_rule MODULES test_cumsum_rule)
3737
py_test_modules(test_argmax_rule MODULES test_argmax_rule)
38+
py_test_modules(test_topk_rule MODULES test_topk_rule)
3839
py_test_modules(test_unbind_rule MODULES test_unbind_rule)
3940
py_test_modules(test_stack_rule MODULES test_stack_rule)
4041
py_test_modules(test_gather_nd_rule MODULES test_gather_nd_rule)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from collections import OrderedDict
17+
18+
from paddle.distributed.auto_parallel.static.dist_attribute import (
19+
DistTensorSpec,
20+
TensorDistAttr,
21+
)
22+
from paddle.distributed.fleet import auto
23+
from paddle.framework import core
24+
25+
26+
class TestTopkSPMDRule(unittest.TestCase):
    """Tests for the auto-parallel SPMD (sharding) rule of the topk op."""

    def setUp(self):
        x_shape = [16, 16, 16]
        out_shape = [16, 2, 16]  # k = 2 along axis 1
        process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])

        x_tensor_dist_attr = TensorDistAttr()
        x_tensor_dist_attr.dims_mapping = [-1, -1, -1]
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)

        out_tensor_dist_attr = TensorDistAttr()
        out_tensor_dist_attr.dims_mapping = [-1, -1, -1]
        out_tensor_dist_attr.process_mesh = process_mesh
        # BUG FIX: was built from x_tensor_dist_attr, leaving
        # out_tensor_dist_attr unused.
        self.out_dist_tensor_spec = DistTensorSpec(
            out_shape, out_tensor_dist_attr
        )

        self.rule = core.get_phi_spmd_rule("topk")
        self.attrs = OrderedDict()
        self.attrs['k'] = 2
        self.attrs['axis'] = 1
        self.attrs['largest'] = True
        self.attrs['sorted'] = True

    def test_topk_forward(self):
        # axis = 1: the top-k axis must end up replicated everywhere.
        # x [0, 1, -1] --> x_dst [0, -1, -1]; out/indices [0, -1, -1]
        self.attrs['axis'] = 1
        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
        result_dist_attrs = self.rule.infer_forward(
            self.x_dist_tensor_spec,
            self.attrs['k'],
            self.attrs['axis'],
            self.attrs['largest'],
            self.attrs['sorted'],
        )
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(result_dist_attrs), 2)
        self.assertEqual(len(inferred_input_dist_attrs), 1)
        self.assertEqual(len(inferred_output_dist_attrs), 2)

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [0, -1, -1])
        # BUG FIX: check both outputs (out and indices); the original asserted
        # the input twice and never inspected indices.
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, [0, -1, -1]
        )
        self.assertEqual(
            inferred_output_dist_attrs[1].dims_mapping, [0, -1, -1]
        )

    def test_topk_backward(self):
        # axis = 1
        # out_grad [0, -1, 1] --> x/indices/out_grad dst [0, -1, 1],
        # x_grad [0, -1, 1]
        self.attrs['axis'] = 1
        self.out_dist_tensor_spec.shape = [16, 2, 16]
        self.out_dist_tensor_spec.set_dims_mapping([0, -1, 1])
        result_dist_attrs = self.rule.infer_backward(
            self.x_dist_tensor_spec,
            # out spec doubles as the indices spec (same shape/mapping).
            self.out_dist_tensor_spec,
            self.out_dist_tensor_spec,
            self.attrs['k'],
            self.attrs['axis'],
            self.attrs['largest'],
            self.attrs['sorted'],
        )
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(result_dist_attrs), 2)
        self.assertEqual(len(inferred_input_dist_attrs), 3)
        self.assertEqual(len(inferred_output_dist_attrs), 1)

        # BUG FIX: assert all three input attrs (x, indices, out_grad);
        # the original checked index 0 three times.
        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [0, -1, 1])
        self.assertEqual(inferred_input_dist_attrs[1].dims_mapping, [0, -1, 1])
        self.assertEqual(inferred_input_dist_attrs[2].dims_mapping, [0, -1, 1])
        self.assertEqual(inferred_output_dist_attrs[0].dims_mapping, [0, -1, 1])
101+
102+
# Run the suite when invoked directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)