Skip to content

Commit 6fc7d44

Browse files
authored
[Auto Parallel] Add spmd rule No.1 for topk and topk_grad ops (#72499)
1 parent 387259f commit 6fc7d44

File tree

9 files changed

+375
-0
lines changed

9 files changed

+375
-0
lines changed

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,11 @@ PD_REGISTER_SPMD_RULE(
719719
PD_INFER_SPMD(phi::distributed::ArgMaxInferSpmdBase),
720720
PD_INFER_SPMD(phi::distributed::ArgMaxInferSpmdReverseBase));
721721

722+
// topk
723+
PD_REGISTER_SPMD_RULE(topk,
724+
PD_INFER_SPMD(phi::distributed::TopkInferSpmd),
725+
PD_INFER_SPMD(phi::distributed::TopkGradInferSpmd));
726+
722727
// unbind
723728
PD_REGISTER_SPMD_RULE(unbind,
724729
PD_INFER_SPMD(phi::distributed::UnbindInferSpmd),

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ limitations under the License. */
6868
#include "paddle/phi/infermeta/spmd_rules/squeeze.h"
6969
#include "paddle/phi/infermeta/spmd_rules/stack.h"
7070
#include "paddle/phi/infermeta/spmd_rules/tile.h"
71+
#include "paddle/phi/infermeta/spmd_rules/topk.h"
7172
#include "paddle/phi/infermeta/spmd_rules/transpose.h"
7273
#include "paddle/phi/infermeta/spmd_rules/triu.h"
7374
#include "paddle/phi/infermeta/spmd_rules/unbind.h"
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/infermeta/spmd_rules/topk.h"
16+
#include "glog/logging.h"
17+
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
18+
#include "paddle/phi/infermeta/spmd_rules/utils.h"
19+
20+
namespace phi {
21+
namespace distributed {
22+
23+
SpmdInfo TopkInferSpmd(
    const DistMetaTensor& x, int k, int axis, bool largest, bool sorted) {
  // Forward SPMD rule for topk.
  //
  // The topk axis cannot remain sharded: every rank needs the full extent
  // of that axis to select the global top-k elements. So the axis is forced
  // to replicated (-1) on the input and on both outputs, while all other
  // dims keep the input's sharding.
  //
  // Returns: {{x_dist_attr_dst}, {out_dist_attr_dst, indices_dist_attr_dst}}.

  // Verify input args.
  EXTRACT_SHAPE_AND_DIST_ATTR(x);
  // Normalize a negative axis into [0, x_ndim).
  axis = axis < 0 ? x_ndim + axis : axis;
  PADDLE_ENFORCE_EQ(
      0 <= axis && axis < x_ndim,
      true,
      // NOTE: common::errors for consistency with TopkGradInferSpmd below
      // (was phi::errors).
      common::errors::InvalidArgument(
          "The axis of topk should be in range [0, %d), but got %d.",
          x_ndim,
          axis));

  // Create destination dist attrs.
  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  TensorDistAttr out_dist_attr_dst =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  TensorDistAttr indices_dist_attr_dst =
      CopyTensorDistAttrForOutput(x_dist_attr_src);

  // Infer dims_mapping: replicate the topk axis, keep the other dims.
  std::vector<int64_t> x_dims_mapping_dst = x_dims_mapping_src;
  x_dims_mapping_dst[axis] = -1;
  std::vector<int64_t> out_dims_mapping_dst = x_dims_mapping_dst;
  std::vector<int64_t> indices_dims_mapping_dst = x_dims_mapping_dst;

  // Set the dims mapping for outputs.
  out_dist_attr_dst.set_dims_mapping(out_dims_mapping_dst);
  indices_dist_attr_dst.set_dims_mapping(indices_dims_mapping_dst);

  // Update the dims mapping for inputs.
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
  VLOG(4) << "TopkInferSpmd: Done.";
  LOG_SPMD_INPUT(x);
  LOG_SPMD_OUTPUT(out_dist_attr_dst);
  LOG_SPMD_OUTPUT(indices_dist_attr_dst);

  return {{x_dist_attr_dst}, {out_dist_attr_dst, indices_dist_attr_dst}};
}
62+
63+
SpmdInfo TopkGradInferSpmd(const DistMetaTensor& x,
                           const DistMetaTensor& indices,
                           const DistMetaTensor& out_grad,
                           int k,
                           int axis,
                           bool largest,
                           bool sorted) {
  // Backward SPMD rule for topk_grad.
  //
  // x_grad has the same rank as x. The topk axis must be replicated on all
  // tensors; the remaining dims are merged across x, indices and out_grad so
  // that every input agrees on a single sharding.
  //
  // Returns: {{x_dst, indices_dst, out_grad_dst}, {x_grad_dst}}.

  // Verify input args.
  EXTRACT_SHAPE_AND_DIST_ATTR(x);
  EXTRACT_SHAPE_AND_DIST_ATTR(indices);
  EXTRACT_SHAPE_AND_DIST_ATTR(out_grad);
  PADDLE_ENFORCE_EQ(indices_ndim,
                    out_grad_ndim,
                    common::errors::InvalidArgument(
                        "TopKGrad: The rank of Indices [%d] and OutGrad [%d] "
                        "must be the same.",
                        indices_ndim,
                        out_grad_ndim));
  PADDLE_ENFORCE_EQ(x_ndim,
                    indices_ndim,
                    common::errors::InvalidArgument(
                        "TopKGrad: The rank of Input [%d] and Indices [%d] "
                        "must be the same.",
                        x_ndim,
                        indices_ndim));
  // Normalize a negative axis into [0, x_ndim).
  axis = axis < 0 ? x_ndim + axis : axis;
  PADDLE_ENFORCE_EQ(
      0 <= axis && axis < x_ndim,
      true,
      // NOTE: common::errors for consistency with the checks above
      // (was phi::errors).
      common::errors::InvalidArgument(
          "The axis of topk_grad should be in range [0, %d), but got %d.",
          x_ndim,
          axis));

  // Build einsum notation: one letter per non-topk dim (the topk axis is
  // excluded from the notation and re-inserted as replicated below).
  // FIX: the alphabet previously skipped 'k', 'm' and 'n', which would yield
  // wrong axis letters for high-rank tensors.
  const std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::string x_axes = alphabet.substr(0, x_ndim - 1);
  std::string indices_axes = x_axes;
  std::string out_grad_axes = x_axes;

  // FIX: erase the topk-axis entry from every source dims mapping so each
  // mapping has the same length as the (x_ndim - 1)-letter notation.
  // Previously the full-rank source mappings were paired with the shorter
  // axes strings, misaligning every dim after `axis` during sharding merge.
  std::vector<int64_t> x_dims_mapping = x_dims_mapping_src;
  x_dims_mapping.erase(x_dims_mapping.begin() + axis);
  std::vector<int64_t> indices_dims_mapping = indices_dims_mapping_src;
  indices_dims_mapping.erase(indices_dims_mapping.begin() + axis);
  std::vector<int64_t> out_grad_dims_mapping = out_grad_dims_mapping_src;
  out_grad_dims_mapping.erase(out_grad_dims_mapping.begin() + axis);

  // Merge sharding across all inputs.
  std::pair<std::string, std::vector<int64_t>> indices_pair(
      indices_axes, indices_dims_mapping);
  std::pair<std::string, std::vector<int64_t>> out_grad_pair(
      out_grad_axes, out_grad_dims_mapping);
  std::pair<std::string, std::vector<int64_t>> x_pair(x_axes, x_dims_mapping);
  auto axis_to_dim_map =
      ShardingMergeForTensors({x_pair, indices_pair, out_grad_pair});

  // Infer dims mapping: merged mapping for the non-topk dims, then re-insert
  // a replicated (-1) entry at the topk axis.
  std::vector<int64_t> x_grad_dims_mapping_dst =
      GetDimsMappingForAxes(x_axes, axis_to_dim_map);
  x_grad_dims_mapping_dst.insert(x_grad_dims_mapping_dst.begin() + axis, -1);
  std::vector<int64_t> x_dims_mapping_dst = x_grad_dims_mapping_dst;
  std::vector<int64_t> indices_dims_mapping_dst = x_grad_dims_mapping_dst;
  std::vector<int64_t> out_grad_dims_mapping_dst = x_grad_dims_mapping_dst;

  // Set the dims mapping.
  TensorDistAttr x_grad_dist_attr_dst =
      CopyTensorDistAttrForOutput(out_grad_dist_attr_src);
  TensorDistAttr x_dist_attr_dst =
      CopyTensorDistAttrForOutput(out_grad_dist_attr_src);
  TensorDistAttr indices_dist_attr_dst =
      CopyTensorDistAttrForOutput(out_grad_dist_attr_src);
  TensorDistAttr out_grad_dist_attr_dst =
      CopyTensorDistAttrForOutput(out_grad_dist_attr_src);

  x_grad_dist_attr_dst.set_dims_mapping(x_grad_dims_mapping_dst);
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
  indices_dist_attr_dst.set_dims_mapping(indices_dims_mapping_dst);
  out_grad_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping_dst);

  VLOG(4) << "TopkGradInferSpmd: Done.";
  LOG_SPMD_INPUT(x);
  LOG_SPMD_INPUT(indices);
  LOG_SPMD_INPUT(out_grad);
  LOG_SPMD_OUTPUT(x_grad_dist_attr_dst);

  return {{x_dist_attr_dst, indices_dist_attr_dst, out_grad_dist_attr_dst},
          {x_grad_dist_attr_dst}};
}
144+
SpmdInfo TopkInferSpmdDynamic(const DistMetaTensor& x,
                              const Scalar& k,
                              int axis,
                              bool largest,
                              bool sorted) {
  // Dynamic-graph entry point: `k` arrives wrapped in a Scalar.
  // Unwrap it and delegate to the static-form rule.
  const int k_value = k.to<int>();
  return TopkInferSpmd(x, k_value, axis, largest, sorted);
}
151+
152+
SpmdInfo TopkGradInferSpmdDynamic(const DistMetaTensor& x,
                                  const DistMetaTensor& indices,
                                  const DistMetaTensor& out_grad,
                                  const Scalar& k,
                                  int axis,
                                  bool largest,
                                  bool sorted) {
  // Dynamic-graph entry point for the backward rule: `k` arrives wrapped in
  // a Scalar. Unwrap it and delegate to the static-form rule.
  const int k_value = k.to<int>();
  return TopkGradInferSpmd(x, indices, out_grad, k_value, axis, largest, sorted);
}
162+
163+
} // namespace distributed
164+
} // namespace phi
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/phi/common/scalar.h"
18+
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
19+
#include "paddle/phi/core/distributed/type_defs.h"
20+
21+
namespace phi {
22+
namespace distributed {
23+
24+
SpmdInfo TopkInferSpmd(
25+
const DistMetaTensor& x, int k, int axis, bool largest, bool sorted);
26+
27+
SpmdInfo TopkGradInferSpmd(const DistMetaTensor& x,
28+
const DistMetaTensor& indices,
29+
const DistMetaTensor& out_grad,
30+
int k,
31+
int axis,
32+
bool largest,
33+
bool sorted);
34+
35+
SpmdInfo TopkInferSpmdDynamic(const DistMetaTensor& x,
36+
const Scalar& k,
37+
int axis,
38+
bool largest,
39+
bool sorted);
40+
41+
SpmdInfo TopkGradInferSpmdDynamic(const DistMetaTensor& x,
42+
const DistMetaTensor& indices,
43+
const DistMetaTensor& out_grad,
44+
const Scalar& k,
45+
int axis,
46+
bool largest,
47+
bool sorted);
48+
49+
} // namespace distributed
50+
} // namespace phi

paddle/phi/ops/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3519,6 +3519,7 @@
35193519
infer_meta :
35203520
func : UnchangedInferMeta
35213521
param : [x]
3522+
spmd_rule: TopkGradInferSpmdDynamic
35223523
kernel :
35233524
func : topk_grad
35243525
data_type : out_grad

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5144,6 +5144,7 @@
51445144
output : Tensor(out), Tensor(indices)
51455145
infer_meta :
51465146
func : TopKInferMeta
5147+
spmd_rule: TopkInferSpmdDynamic
51475148
kernel :
51485149
func : topk
51495150
data_type : x

test/auto_parallel/spmd_rules/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ if(WITH_DISTRIBUTE)
4444
py_test_modules(test_logsumexp_rule MODULES test_logsumexp_rule)
4545
py_test_modules(test_nonzero_rule MODULES test_nonzero_rule)
4646
if(NOT WITH_ROCM)
47+
py_test_modules(test_topk_rule MODULES test_topk_rule)
4748
py_test_modules(test_add_n_rule MODULES test_add_n_rule)
4849
py_test_modules(test_mean_all_rule MODULES test_mean_all_rule)
4950
py_test_modules(test_argmin_rule MODULES test_argmin_rule)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from collections import OrderedDict
17+
18+
from paddle.distributed.auto_parallel.static.dist_attribute import (
19+
DistTensorSpec,
20+
TensorDistAttr,
21+
)
22+
from paddle.distributed.fleet import auto
23+
from paddle.framework import core
24+
25+
26+
class TestTopkSPMDRule(unittest.TestCase):
    """Unit tests for the topk / topk_grad SPMD (auto-parallel) rules."""

    def setUp(self):
        x_shape = [16, 16, 16]
        out_shape = [16, 2, 16]  # k = 2 along axis 1
        process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])

        x_tensor_dist_attr = TensorDistAttr()
        x_tensor_dist_attr.dims_mapping = [-1, -1, -1]
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
        out_tensor_dist_attr = TensorDistAttr()
        out_tensor_dist_attr.dims_mapping = [-1, -1, -1]
        out_tensor_dist_attr.process_mesh = process_mesh
        # BUGFIX: use out_tensor_dist_attr for the output spec; the original
        # built out_tensor_dist_attr but then passed x_tensor_dist_attr.
        self.out_dist_tensor_spec = DistTensorSpec(
            out_shape, out_tensor_dist_attr
        )

        self.rule = core.get_phi_spmd_rule("topk")
        self.attrs = OrderedDict()
        self.attrs['k'] = 2
        self.attrs['axis'] = 1
        self.attrs['largest'] = True
        self.attrs['sorted'] = True

    def test_topk_forward(self):
        # axis = 1: the sharding on the topk axis must be cleared.
        # x: [0, 1, -1] --> x: [0, -1, -1]; out/indices: [0, -1, -1]
        self.attrs['axis'] = 1
        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
        result_dist_attrs = self.rule.infer_forward(
            self.x_dist_tensor_spec,
            self.attrs['k'],
            self.attrs['axis'],
            self.attrs['largest'],
            self.attrs['sorted'],
        )

        self.assertEqual(len(result_dist_attrs), 2)
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(inferred_input_dist_attrs), 1)
        self.assertEqual(len(inferred_output_dist_attrs), 2)

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [0, -1, -1])
        # BUGFIX: check BOTH outputs (out and indices); the original asserted
        # the input twice and only checked the first output.
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, [0, -1, -1]
        )
        self.assertEqual(
            inferred_output_dist_attrs[1].dims_mapping, [0, -1, -1]
        )

    def test_topk_backward(self):
        # axis = 1
        # x: [0, -1, 1], indices/out_grad: [-1, 1, -1]
        #   --> all inputs and x_grad: [0, -1, 1]
        self.attrs['axis'] = 1
        self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1])
        self.out_dist_tensor_spec.shape = [16, 2, 16]
        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, -1])
        result_dist_attrs = self.rule.infer_backward(
            self.x_dist_tensor_spec,
            self.out_dist_tensor_spec,  # indices (same shape/sharding)
            self.out_dist_tensor_spec,  # out_grad
            self.attrs['k'],
            self.attrs['axis'],
            self.attrs['largest'],
            self.attrs['sorted'],
        )

        self.assertEqual(len(result_dist_attrs), 2)
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(len(inferred_input_dist_attrs), 3)
        self.assertEqual(len(inferred_output_dist_attrs), 1)
        # BUGFIX: check all three inputs; the original asserted index [0]
        # three times.
        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [0, -1, 1])
        self.assertEqual(inferred_input_dist_attrs[1].dims_mapping, [0, -1, 1])
        self.assertEqual(inferred_input_dist_attrs[2].dims_mapping, [0, -1, 1])
        self.assertEqual(inferred_output_dist_attrs[0].dims_mapping, [0, -1, 1])
102+
103+
104+
# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)