
Commit 6849d33

[Ops] segment pool op support for int int64 kernel. (#40577)
* segment pool support for int and int64 kernels.
* add support in the Python API.
1 parent: 2dec25d
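
For orientation, a minimal usage sketch of what this change enables at the Python level (illustrative values, not part of the commit; assumes a Paddle build that includes it, since previously only float32/float64 data was accepted):

import paddle

# Integer data is now accepted by the segment pooling ops.
data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='int64')
segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')

out = paddle.incubate.segment_sum(data, segment_ids)
# out: [[4, 4, 4],
#       [4, 5, 6]]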

File tree: 7 files changed, +52 / -14 lines changed


paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc
Lines changed: 3 additions & 1 deletion

@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(segment_pool_grad,
                    ALL_LAYOUT,
                    phi::SegmentPoolGradKernel,
                    float,
-                   double) {}
+                   double,
+                   int,
+                   int64_t) {}

paddle/phi/kernels/cpu/segment_pool_kernel.cc
Lines changed: 8 additions & 2 deletions

@@ -18,5 +18,11 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"

-PD_REGISTER_KERNEL(
-    segment_pool, CPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {}
+PD_REGISTER_KERNEL(segment_pool,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SegmentPoolKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}

paddle/phi/kernels/funcs/segment_pooling.cc
Lines changed: 9 additions & 0 deletions

@@ -149,10 +149,19 @@ template class SegmentPoolFunctor<CPU, float, int>;
 template class SegmentPoolFunctor<CPU, float, int64_t>;
 template class SegmentPoolFunctor<CPU, double, int>;
 template class SegmentPoolFunctor<CPU, double, int64_t>;
+template class SegmentPoolFunctor<CPU, int, int>;
+template class SegmentPoolFunctor<CPU, int, int64_t>;
+template class SegmentPoolFunctor<CPU, int64_t, int>;
+template class SegmentPoolFunctor<CPU, int64_t, int64_t>;
+
 template class SegmentPoolGradFunctor<CPU, float, int>;
 template class SegmentPoolGradFunctor<CPU, float, int64_t>;
 template class SegmentPoolGradFunctor<CPU, double, int>;
 template class SegmentPoolGradFunctor<CPU, double, int64_t>;
+template class SegmentPoolGradFunctor<CPU, int, int>;
+template class SegmentPoolGradFunctor<CPU, int, int64_t>;
+template class SegmentPoolGradFunctor<CPU, int64_t, int>;
+template class SegmentPoolGradFunctor<CPU, int64_t, int64_t>;

 }  // namespace funcs
 }  // namespace phi
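
Each explicit instantiation above pairs a data type T with a segment-index type IndexT, so the added lines cover int32/int64 data crossed with int32/int64 segment_ids for both the forward and the gradient functor. A hedged sketch of how those (T, IndexT) combinations surface through the Python API (illustrative values, not part of this diff):

import paddle

data = paddle.to_tensor([[1, 2], [3, 4], [5, 6]])
ids = paddle.to_tensor([0, 0, 1])
for data_dtype in ('int32', 'int64'):
    for ids_dtype in ('int32', 'int64'):
        # Each (data dtype, segment_ids dtype) pair exercises one of the
        # new SegmentPoolFunctor<CPU, T, IndexT> instantiations.
        out = paddle.incubate.segment_max(data.astype(data_dtype),
                                          ids.astype(ids_dtype))
        # Expected for every combination: [[3, 4], [5, 6]]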

paddle/phi/kernels/funcs/segment_pooling.cu
Lines changed: 9 additions & 0 deletions

@@ -453,10 +453,19 @@ template class SegmentPoolFunctor<GPU, float, int>;
 template class SegmentPoolFunctor<GPU, float, int64_t>;
 template class SegmentPoolFunctor<GPU, double, int>;
 template class SegmentPoolFunctor<GPU, double, int64_t>;
+template class SegmentPoolFunctor<GPU, int, int>;
+template class SegmentPoolFunctor<GPU, int, int64_t>;
+template class SegmentPoolFunctor<GPU, int64_t, int>;
+template class SegmentPoolFunctor<GPU, int64_t, int64_t>;
+
 template class SegmentPoolGradFunctor<GPU, float, int>;
 template class SegmentPoolGradFunctor<GPU, float, int64_t>;
 template class SegmentPoolGradFunctor<GPU, double, int>;
 template class SegmentPoolGradFunctor<GPU, double, int64_t>;
+template class SegmentPoolGradFunctor<GPU, int, int>;
+template class SegmentPoolGradFunctor<GPU, int, int64_t>;
+template class SegmentPoolGradFunctor<GPU, int64_t, int>;
+template class SegmentPoolGradFunctor<GPU, int64_t, int64_t>;

 }  // namespace funcs
 }  // namespace phi

paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu
Lines changed: 3 additions & 1 deletion

@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(segment_pool_grad,
                    ALL_LAYOUT,
                    phi::SegmentPoolGradKernel,
                    float,
-                   double) {}
+                   double,
+                   int,
+                   int64_t) {}

paddle/phi/kernels/gpu/segment_pool_kernel.cu
Lines changed: 8 additions & 2 deletions

@@ -19,5 +19,11 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"

-PD_REGISTER_KERNEL(
-    segment_pool, GPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {}
+PD_REGISTER_KERNEL(segment_pool,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SegmentPoolKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}

python/paddle/incubate/tensor/math.py
Lines changed: 12 additions & 8 deletions

@@ -29,7 +29,7 @@ def segment_sum(data, segment_ids, name=None):
     where sum is over j such that `segment_ids[j] == i`.

     Args:
-        data (Tensor): A tensor, available data type float32, float64.
+        data (Tensor): A tensor, available data type float32, float64, int32, int64.
         segment_ids (Tensor): A 1-D tensor, which have the same size
             with the first dimension of input data.
             Available data type is int32, int64.
@@ -54,7 +54,8 @@ def segment_sum(data, segment_ids, name=None):
         out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
         return out

-    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
+    check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
+                                         "int64"), "segment_pool")
     check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
                              "segment_pool")

@@ -82,7 +83,7 @@ def segment_mean(data, segment_ids, name=None):
     of all index 'segment_ids[j] == i'.

     Args:
-        data (tensor): a tensor, available data type float32, float64.
+        data (tensor): a tensor, available data type float32, float64, int32, int64.
         segment_ids (tensor): a 1-d tensor, which have the same size
             with the first dimension of input data.
             available data type is int32, int64.
@@ -107,7 +108,8 @@ def segment_mean(data, segment_ids, name=None):
         out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
         return out

-    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
+    check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
+                                         "int64"), "segment_pool")
     check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
                              "segment_pool")

@@ -134,7 +136,7 @@ def segment_min(data, segment_ids, name=None):
     where min is over j such that `segment_ids[j] == i`.

     Args:
-        data (tensor): a tensor, available data type float32, float64.
+        data (tensor): a tensor, available data type float32, float64, int32, int64.
         segment_ids (tensor): a 1-d tensor, which have the same size
             with the first dimension of input data.
             available data type is int32, int64.
@@ -159,7 +161,8 @@ def segment_min(data, segment_ids, name=None):
         out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
         return out

-    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
+    check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
+                                         "int64"), "segment_pool")
     check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
                              "segment_pool")

@@ -186,7 +189,7 @@ def segment_max(data, segment_ids, name=None):
     where max is over j such that `segment_ids[j] == i`.

     Args:
-        data (tensor): a tensor, available data type float32, float64.
+        data (tensor): a tensor, available data type float32, float64, int32, int64.
         segment_ids (tensor): a 1-d tensor, which have the same size
             with the first dimension of input data.
             available data type is int32, int64.
@@ -211,7 +214,8 @@ def segment_max(data, segment_ids, name=None):
         out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
         return out

-    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
+    check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
+                                         "int64"), "segment_pool")
     check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
                              "segment_pool")
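
Because the _C_ops fast path in the hunks above returns before the dtype check runs, the widened check_variable_and_dtype whitelist mainly matters for static-graph programs, which can now declare integer inputs for these ops. A minimal static-mode sketch (assumes this commit; variable names are illustrative):

import paddle

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    data = paddle.static.data(name='data', shape=[3, 4], dtype='int64')
    ids = paddle.static.data(name='ids', shape=[3], dtype='int32')
    # Before this change, int64 data failed the ("float32", "float64")
    # whitelist and raised a TypeError at graph-construction time.
    out = paddle.incubate.segment_min(data, ids)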
