Skip to content

Commit 9bb6a44

Browse files
jsawrukfacebook-github-bot
authored andcommitted
Fix Tensor/float Pyre type issues (meta-pytorch#1556)
Summary: Pull Request resolved: meta-pytorch#1556 Fix the following: ``` # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`. # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. ``` This is due to applying Python's math operators to floats and Tensors, resulting in type confusion. Use the equivalent torch functions instead, ex: `torch.pow` instead of `**` Reviewed By: cyrjano Differential Revision: D74025088 fbshipit-source-id: 326762bd83e3c8447c4eff83c292b448e46c051c
1 parent 54ccfcb commit 9bb6a44

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

captum/influence/_core/arnoldi_influence_function.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,9 @@ def _parameter_arnoldi(
127127
H = torch.zeros(n + 1, n, dtype=next(iter(b)).dtype).to(device=projection_device)
128128
qs = [
129129
_parameter_to(
130-
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`.
131-
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
132-
# `float`.
133-
_parameter_multiply(b, 1.0 / _parameter_dot(b, b) ** 0.5),
130+
_parameter_multiply(
131+
b, torch.div(1.0, torch.pow(_parameter_dot(b, b), 0.5))
132+
),
134133
device=projection_device,
135134
)
136135
]
@@ -148,14 +147,11 @@ def _parameter_arnoldi(
148147
for i in range(k):
149148
H[i, k - 1] = _parameter_dot(qs[i], v)
150149
v = _parameter_add(v, _parameter_multiply(qs[i], -H[i, k - 1]))
151-
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `float`.
152-
H[k, k - 1] = _parameter_dot(v, v) ** 0.5
150+
H[k, k - 1] = torch.pow(_parameter_dot(v, v), 0.5)
153151

154152
if H[k, k - 1] < tol:
155153
break
156-
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`.
157-
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
158-
qs.append(_parameter_multiply(v, 1.0 / H[k, k - 1]))
154+
qs.append(_parameter_multiply(v, torch.div(1.0, H[k, k - 1])))
159155

160156
# pyre-fixme[61]: `k` is undefined, or not always defined.
161157
return qs[:k], H[:k, : k - 1]
@@ -657,8 +653,7 @@ def HVP(v):
657653
# however, since `vs` is instead a list of tuple of tensors, `R` should be
658654
# a list of tuple of tensors, where each entry in the list is scaled by the
659655
# corresponding entry in `ls ** 0.5`, which we first compute.
660-
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
661-
ls = (1.0 / ls) ** 0.5
656+
ls = torch.pow(torch.div(1.0, ls), 0.5)
662657

663658
# then, scale each entry in `vs` by the corresponding entry in `ls ** 0.5`
664659
# since each entry in `vs` is a tuple of tensors, we use a helper function

captum/influence/_utils/common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,9 @@ def _parameter_add(
887887
return tuple(param_1 + param_2 for (param_1, param_2) in zip(params_1, params_2))
888888

889889

890-
def _parameter_multiply(params: Tuple[Tensor, ...], c: Tensor) -> Tuple[Tensor, ...]:
890+
def _parameter_multiply(
891+
params: Tuple[Tensor, ...], c: Union[Tensor, float]
892+
) -> Tuple[Tensor, ...]:
891893
"""
892894
multiplies all tensors in a tuple of tensors by a given scalar
893895
"""

0 commit comments

Comments
 (0)