Commit 331eb6a

aobo-y authored and meta-codesync[bot] committed
fix issue of long tensor in torch gather in interpretable input (#1678)
Summary: Pull Request resolved: #1678

`torch.gather` requires `index` to be a `long` tensor in the released versions of PyTorch; `int32` is already supported in the source code (which is why the internal tests pass). This diff fixes [the CI failures](https://github.com/meta-pytorch/captum/actions/runs/19831857827/job/56826961913), which are caused by test cases using an `int32` `mask`.

Reviewed By: jimshao1999

Differential Revision: D88101395

fbshipit-source-id: fc37972c0437eff03abc5250e53a59ad7fc02317
1 parent dd04b82 · commit 331eb6a

File tree

1 file changed: +5 −0 lines


captum/attr/_utils/interpretable_input.py

Lines changed: 5 additions & 0 deletions
@@ -55,6 +55,11 @@ def _scatter_itp_attr_by_mask(
     else:
         expanded_itp_attr = itp_attr
 
+    # gather index must be long
+    # may support int32 soon https://github.com/pytorch/pytorch/pull/151822
+    if expanded_feature_indices.dtype != torch.long:
+        expanded_feature_indices = expanded_feature_indices.long()
+
     # gather from (*output_dims, *inp.shape[1:-1], n_itp_features)
     attr = torch.gather(expanded_itp_attr, -1, expanded_feature_indices)
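As a standalone illustration (with hypothetical tensor names, not taken from the Captum source), here is a minimal sketch of the failure mode this commit guards against and the cast that fixes it:

import torch

# Stand-ins for the attribution tensor and the int32 feature-index
# tensor produced by an int32 mask in the failing test cases.
itp_attr = torch.rand(2, 4)
feature_indices = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32)

# On released PyTorch builds, torch.gather requires a long index;
# passing int32 raises a RuntimeError. Cast first, as the fix does.
if feature_indices.dtype != torch.long:
    feature_indices = feature_indices.long()

attr = torch.gather(itp_attr, -1, feature_indices)
print(attr.shape)  # torch.Size([2, 2])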