@@ -29,7 +29,7 @@ def segment_sum(data, segment_ids, name=None):
2929 where sum is over j such that `segment_ids[j] == i`.
3030
3131 Args:
32- data (Tensor): A tensor, available data type float32, float64.
32+ data (Tensor): A tensor, available data type float32, float64, int32, int64 .
3333 segment_ids (Tensor): A 1-D tensor, which have the same size
3434 with the first dimension of input data.
3535 Available data type is int32, int64.
@@ -54,7 +54,8 @@ def segment_sum(data, segment_ids, name=None):
5454 out , tmp = _C_ops .segment_pool (data , segment_ids , 'pooltype' , "SUM" )
5555 return out
5656
57- check_variable_and_dtype (data , "X" , ("float32" , "float64" ), "segment_pool" )
57+ check_variable_and_dtype (data , "X" , ("float32" , "float64" , "int32" ,
58+ "int64" ), "segment_pool" )
5859 check_variable_and_dtype (segment_ids , "SegmentIds" , ("int32" , "int64" ),
5960 "segment_pool" )
6061
@@ -82,7 +83,7 @@ def segment_mean(data, segment_ids, name=None):
8283 of all index 'segment_ids[j] == i'.
8384
8485 Args:
85- data (tensor): a tensor, available data type float32, float64.
86+ data (tensor): a tensor, available data type float32, float64, int32, int64 .
8687 segment_ids (tensor): a 1-d tensor, which have the same size
8788 with the first dimension of input data.
8889 available data type is int32, int64.
@@ -107,7 +108,8 @@ def segment_mean(data, segment_ids, name=None):
107108 out , tmp = _C_ops .segment_pool (data , segment_ids , 'pooltype' , "MEAN" )
108109 return out
109110
110- check_variable_and_dtype (data , "X" , ("float32" , "float64" ), "segment_pool" )
111+ check_variable_and_dtype (data , "X" , ("float32" , "float64" , "int32" ,
112+ "int64" ), "segment_pool" )
111113 check_variable_and_dtype (segment_ids , "SegmentIds" , ("int32" , "int64" ),
112114 "segment_pool" )
113115
@@ -134,7 +136,7 @@ def segment_min(data, segment_ids, name=None):
134136 where min is over j such that `segment_ids[j] == i`.
135137
136138 Args:
137- data (tensor): a tensor, available data type float32, float64.
139+ data (tensor): a tensor, available data type float32, float64, int32, int64 .
138140 segment_ids (tensor): a 1-d tensor, which have the same size
139141 with the first dimension of input data.
140142 available data type is int32, int64.
@@ -159,7 +161,8 @@ def segment_min(data, segment_ids, name=None):
159161 out , tmp = _C_ops .segment_pool (data , segment_ids , 'pooltype' , "MIN" )
160162 return out
161163
162- check_variable_and_dtype (data , "X" , ("float32" , "float64" ), "segment_pool" )
164+ check_variable_and_dtype (data , "X" , ("float32" , "float64" , "int32" ,
165+ "int64" ), "segment_pool" )
163166 check_variable_and_dtype (segment_ids , "SegmentIds" , ("int32" , "int64" ),
164167 "segment_pool" )
165168
@@ -186,7 +189,7 @@ def segment_max(data, segment_ids, name=None):
186189 where max is over j such that `segment_ids[j] == i`.
187190
188191 Args:
189- data (tensor): a tensor, available data type float32, float64.
192+ data (tensor): a tensor, available data type float32, float64, int32, int64 .
190193 segment_ids (tensor): a 1-d tensor, which have the same size
191194 with the first dimension of input data.
192195 available data type is int32, int64.
@@ -211,7 +214,8 @@ def segment_max(data, segment_ids, name=None):
211214 out , tmp = _C_ops .segment_pool (data , segment_ids , 'pooltype' , "MAX" )
212215 return out
213216
214- check_variable_and_dtype (data , "X" , ("float32" , "float64" ), "segment_pool" )
217+ check_variable_and_dtype (data , "X" , ("float32" , "float64" , "int32" ,
218+ "int64" ), "segment_pool" )
215219 check_variable_and_dtype (segment_ids , "SegmentIds" , ("int32" , "int64" ),
216220 "segment_pool" )
217221
0 commit comments