File tree Expand file tree Collapse file tree 1 file changed +11
-9
lines changed Expand file tree Collapse file tree 1 file changed +11
-9
lines changed Original file line number Diff line number Diff line change @@ -5681,15 +5681,17 @@ def __check_ranges(D, ranges):
56815681 ), f"The length of ranges list must be { D * 2 } \n "
56825682
56835683 def __compute_flattened_index (index_list , hist_shape ):
5684- strides = []
5685- acc = 1
5686- for size in reversed (hist_shape ):
5687- strides .insert (0 , acc )
5688- acc *= size
5689-
5690- flattened_index = paddle .zeros_like (index_list [0 ])
5691- for idx , stride in zip (index_list , strides ):
5692- flattened_index += idx * stride
5684+ strides = (
5685+ paddle .to_tensor (hist_shape [::- 1 ])
5686+ .cumprod (dim = 0 )
5687+ .flip (0 )[1 :]
5688+ .tolist ()
5689+ )
5690+ strides .append (1 )
5691+ strides_tensor = paddle .to_tensor (strides )
5692+
5693+ stacked_indices = paddle .stack (index_list , axis = - 1 )
5694+ flattened_index = (stacked_indices * strides_tensor ).sum (axis = - 1 )
56935695
56945696 return flattened_index
56955697
You can’t perform that action at this time.
0 commit comments