Skip to content

Commit 880a6f2

Browse files
committed
add support in python api
1 parent c2771ec commit 880a6f2

File tree

1 file changed

+12
-8
lines changed
  • python/paddle/incubate/tensor

1 file changed

+12
-8
lines changed

python/paddle/incubate/tensor/math.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def segment_sum(data, segment_ids, name=None):
2929
where sum is over j such that `segment_ids[j] == i`.
3030
3131
Args:
32-
data (Tensor): A tensor, available data type float32, float64.
32+
data (Tensor): A tensor, available data type float32, float64, int32, int64.
3333
segment_ids (Tensor): A 1-D tensor, which have the same size
3434
with the first dimension of input data.
3535
Available data type is int32, int64.
@@ -54,7 +54,8 @@ def segment_sum(data, segment_ids, name=None):
5454
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
5555
return out
5656

57-
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
57+
check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
58+
"int64"), "segment_pool")
5859
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
5960
"segment_pool")
6061

@@ -82,7 +83,7 @@ def segment_mean(data, segment_ids, name=None):
8283
of all index 'segment_ids[j] == i'.
8384
8485
Args:
85-
data (tensor): a tensor, available data type float32, float64.
86+
data (tensor): a tensor, available data type float32, float64, int32, int64.
8687
segment_ids (tensor): a 1-d tensor, which have the same size
8788
with the first dimension of input data.
8889
available data type is int32, int64.
@@ -107,7 +108,8 @@ def segment_mean(data, segment_ids, name=None):
107108
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
108109
return out
109110

110-
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
111+
check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
112+
"int64"), "segment_pool")
111113
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
112114
"segment_pool")
113115

@@ -134,7 +136,7 @@ def segment_min(data, segment_ids, name=None):
134136
where min is over j such that `segment_ids[j] == i`.
135137
136138
Args:
137-
data (tensor): a tensor, available data type float32, float64.
139+
data (tensor): a tensor, available data type float32, float64, int32, int64.
138140
segment_ids (tensor): a 1-d tensor, which have the same size
139141
with the first dimension of input data.
140142
available data type is int32, int64.
@@ -159,7 +161,8 @@ def segment_min(data, segment_ids, name=None):
159161
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
160162
return out
161163

162-
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
164+
check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
165+
"int64"), "segment_pool")
163166
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
164167
"segment_pool")
165168

@@ -186,7 +189,7 @@ def segment_max(data, segment_ids, name=None):
186189
where max is over j such that `segment_ids[j] == i`.
187190
188191
Args:
189-
data (tensor): a tensor, available data type float32, float64.
192+
data (tensor): a tensor, available data type float32, float64, int32, int64.
190193
segment_ids (tensor): a 1-d tensor, which have the same size
191194
with the first dimension of input data.
192195
available data type is int32, int64.
@@ -211,7 +214,8 @@ def segment_max(data, segment_ids, name=None):
211214
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
212215
return out
213216

214-
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
217+
check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
218+
"int64"), "segment_pool")
215219
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
216220
"segment_pool")
217221

0 commit comments

Comments
 (0)