Skip to content

Commit 6474949

Browse files
committed
refine compute flattened_index
1 parent 7f08293 commit 6474949

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

python/paddle/tensor/linalg.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5681,15 +5681,17 @@ def __check_ranges(D, ranges):
56815681
), f"The length of ranges list must be {D * 2}\n"
56825682

56835683
def __compute_flattened_index(index_list, hist_shape):
5684-
strides = []
5685-
acc = 1
5686-
for size in reversed(hist_shape):
5687-
strides.insert(0, acc)
5688-
acc *= size
5689-
5690-
flattened_index = paddle.zeros_like(index_list[0])
5691-
for idx, stride in zip(index_list, strides):
5692-
flattened_index += idx * stride
5684+
strides = (
5685+
paddle.to_tensor(hist_shape[::-1])
5686+
.cumprod(dim=0)
5687+
.flip(0)[1:]
5688+
.tolist()
5689+
)
5690+
strides.append(1)
5691+
strides_tensor = paddle.to_tensor(strides)
5692+
5693+
stacked_indices = paddle.stack(index_list, axis=-1)
5694+
flattened_index = (stacked_indices * strides_tensor).sum(axis=-1)
56935695

56945696
return flattened_index
56955697

0 commit comments

Comments
 (0)