
Commit 3f6d97f

jsawruk authored and facebook-github-bot committed
Fix pyre issues in conductance_reference (meta-pytorch#1560)
Summary:
Pull Request resolved: meta-pytorch#1560

Fix all pyre issues in conductance_reference.py by adding type annotations.

Reviewed By: cyrjano

Differential Revision: D74099818

fbshipit-source-id: 5c3ec8ede5eb4986ea16471863fb7393cab0e44b
1 parent 8d600ff commit 3f6d97f
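
The fix follows one pattern throughout: each `# pyre-fixme[2]` / `# pyre-fixme[3]` suppression is deleted and the corresponding parameter or return type is annotated instead. A minimal before/after sketch of that pattern, mirroring the forward-hook change from this diff (the imports are the only lines added for illustration):

import torch.nn as nn
from torch import Tensor

# Before: pyre-strict flags the unannotated hook, so suppressions were needed.
# # pyre-fixme[3]: Return type must be annotated.
# # pyre-fixme[2]: Parameter must be annotated.
# def forward_hook(module, inp, out):
#     ...

# After: explicit annotations make the suppressions unnecessary.
def forward_hook(module: nn.Module, inp: Tensor, out: Tensor) -> None:
    # A hook with this shape can be passed to module.register_forward_hook(...),
    # which is how the helper uses it to capture the layer output.
    pass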

File tree

1 file changed: +19 -32 lines


captum/testing/attr/helpers/conductance_reference.py

Lines changed: 19 additions & 32 deletions
@@ -1,14 +1,16 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import cast, Tuple, Union
+from typing import Callable, cast, Optional, Tuple, Union
 
 import numpy as np
 import torch
+import torch.nn as nn
 from captum._utils.gradient import (
     apply_gradient_requirements,
     undo_gradient_requirements,
 )
+from captum._utils.typing import ModuleOrModuleList
 from captum.attr._utils.approximation_methods import approximation_parameters
 from captum.attr._utils.attribution import LayerAttribution
 from captum.attr._utils.common import _reshape_and_sum
@@ -29,8 +31,9 @@
 
 
 class ConductanceReference(LayerAttribution):
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, forward_func, layer) -> None:
+    def __init__(
+        self, forward_func: Callable[..., Tensor], layer: ModuleOrModuleList
+    ) -> None:
         r"""
         Args
 
@@ -42,21 +45,16 @@ def __init__(self, forward_func, layer) -> None:
 
     def _conductance_grads(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        forward_fn,
-        # pyre-fixme[2]: Parameter must be annotated.
-        input,
-        # pyre-fixme[2]: Parameter must be annotated.
-        target_ind=None,
+        forward_fn: Callable[..., Tensor],
+        input: Tensor,
+        target_ind: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor, int]:
         with torch.autograd.set_grad_enabled(True):
             # Set a forward hook on specified module and run forward pass to
             # get output tensor size.
             saved_tensor = None
 
-            # pyre-fixme[3]: Return type must be annotated.
-            # pyre-fixme[2]: Parameter must be annotated.
-            def forward_hook(module, inp, out):
+            def forward_hook(module: nn.Module, inp: Tensor, out: Tensor) -> None:
                 nonlocal saved_tensor
                 saved_tensor = out
 
@@ -67,8 +65,8 @@ def forward_hook(module, inp, out):
             # The hidden layer tensor is assumed to have dimension (num_hidden, ...)
             # where the product of the dimensions >= 1 correspond to the total
             # number of hidden neurons in the layer.
-            layer_size = tuple(cast(Tensor, saved_tensor).size())[1:]
-            layer_units = int(np.prod(layer_size))
+            layer_size: Tuple[int, ...] = tuple(cast(Tensor, saved_tensor).size())[1:]
+            layer_units: int = int(np.prod(layer_size))
 
             # Remove unnecessary forward hook.
             hook.remove()
@@ -77,11 +75,7 @@ def forward_hook(module, inp, out):
             # just the gradient of each hidden unit with respect to input.
             saved_grads = None
 
-            # pyre-fixme[53]: Captured variable `layer_size` is not annotated.
-            # pyre-fixme[53]: Captured variable `layer_units` is not annotated.
-            # pyre-fixme[3]: Return type must be annotated.
-            # pyre-fixme[2]: Parameter must be annotated.
-            def backward_hook(grads):
+            def backward_hook(grads: Tensor) -> Tensor:
                 nonlocal saved_grads
                 saved_grads = grads
                 zero_mat = torch.zeros((1,) + layer_size)
@@ -101,9 +95,9 @@ def backward_hook(grads):
             # tensor. Save backward hook in order to remove hook appropriately.
             back_hook = None
 
-            # pyre-fixme[3]: Return type must be annotated.
-            # pyre-fixme[2]: Parameter must be annotated.
-            def forward_hook_register_back(module, inp, out):
+            def forward_hook_register_back(
+                module: nn.Module, inp: Tensor, out: Tensor
+            ) -> None:
                 nonlocal back_hook
                 back_hook = out.register_hook(backward_hook)
 
@@ -132,11 +126,9 @@ def forward_hook_register_back(module, inp, out):
 
     def attribute(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        inputs,
+        inputs: Tensor,
         baselines: Union[None, int, Tensor] = None,
-        # pyre-fixme[2]: Parameter must be annotated.
-        target=None,
+        target: Optional[Tensor] = None,
         n_steps: int = 500,
         method: str = "riemann_trapezoid",
     ) -> Tensor:
@@ -172,10 +164,6 @@ def attribute(
 
         # compute scaled inputs from baseline to final input.
         scaled_features = torch.cat(
-            # pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
-            # typing.Tuple[Tensor, ...]]` but got `List[float]`.
-            # pyre-fixme[58]: `+` is not supported for operand types `Union[int,
-            # torch._tensor.Tensor]` and `float`.
             [baselines + alpha * (inputs - baselines) for alpha in alphas],
             dim=0,
         )
@@ -214,6 +202,5 @@ def attribute(
             scaled_grads.view(mid_layer_gradients.shape) * summed_input_grads,
             n_steps,
             inputs.shape[0],
-            # pyre-fixme[6]: For 4th argument expected `Tuple[int, ...]` but got `Size`.
-            mid_layer_gradients.shape[1:],
+            tuple(mid_layer_gradients.shape[1:]),
         )
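
With the annotations in place, the helper's surface reads directly from the signatures above. A minimal usage sketch, assuming a toy nn.Sequential model and a per-example target tensor (this helper lives under captum/testing as a test reference, so the call pattern here is illustrative rather than a documented API):

import torch
import torch.nn as nn

from captum.testing.attr.helpers.conductance_reference import ConductanceReference

# Toy setup (assumed): attribute conductance to the hidden ReLU layer of a small MLP.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
hidden_layer = model[1]

inputs = torch.randn(8, 3)

# forward_func: Callable[..., Tensor]; layer: ModuleOrModuleList (per the new annotations).
cond_ref = ConductanceReference(model, hidden_layer)

attributions = cond_ref.attribute(
    inputs,
    baselines=0,                               # int baseline is allowed by the annotation
    target=torch.zeros(8, dtype=torch.long),   # assumed: one class index per example
    n_steps=50,
    method="riemann_trapezoid",
)
print(attributions.shape)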
