11#!/usr/bin/env python3
22
33# pyre-strict
4- from typing import cast , Tuple , Union
4+ from typing import Callable , cast , Optional , Tuple , Union
55
66import numpy as np
77import torch
8+ import torch .nn as nn
89from captum ._utils .gradient import (
910 apply_gradient_requirements ,
1011 undo_gradient_requirements ,
1112)
13+ from captum ._utils .typing import ModuleOrModuleList
1214from captum .attr ._utils .approximation_methods import approximation_parameters
1315from captum .attr ._utils .attribution import LayerAttribution
1416from captum .attr ._utils .common import _reshape_and_sum
2931
3032
3133class ConductanceReference (LayerAttribution ):
32- # pyre-fixme[2]: Parameter must be annotated.
33- def __init__ (self , forward_func , layer ) -> None :
34+ def __init__ (
35+ self , forward_func : Callable [..., Tensor ], layer : ModuleOrModuleList
36+ ) -> None :
3437 r"""
3538 Args
3639
@@ -42,21 +45,16 @@ def __init__(self, forward_func, layer) -> None:
4245
4346 def _conductance_grads (
4447 self ,
45- # pyre-fixme[2]: Parameter must be annotated.
46- forward_fn ,
47- # pyre-fixme[2]: Parameter must be annotated.
48- input ,
49- # pyre-fixme[2]: Parameter must be annotated.
50- target_ind = None ,
48+ forward_fn : Callable [..., Tensor ],
49+ input : Tensor ,
50+ target_ind : Optional [Tensor ] = None ,
5151 ) -> Tuple [Tensor , Tensor , int ]:
5252 with torch .autograd .set_grad_enabled (True ):
5353 # Set a forward hook on specified module and run forward pass to
5454 # get output tensor size.
5555 saved_tensor = None
5656
57- # pyre-fixme[3]: Return type must be annotated.
58- # pyre-fixme[2]: Parameter must be annotated.
59- def forward_hook (module , inp , out ):
57+ def forward_hook (module : nn .Module , inp : Tensor , out : Tensor ) -> None :
6058 nonlocal saved_tensor
6159 saved_tensor = out
6260
@@ -67,8 +65,8 @@ def forward_hook(module, inp, out):
6765 # The hidden layer tensor is assumed to have dimension (num_hidden, ...)
6866 # where the product of the dimensions >= 1 correspond to the total
6967 # number of hidden neurons in the layer.
70- layer_size = tuple (cast (Tensor , saved_tensor ).size ())[1 :]
71- layer_units = int (np .prod (layer_size ))
68+ layer_size : Tuple [ int , ...] = tuple (cast (Tensor , saved_tensor ).size ())[1 :]
69+ layer_units : int = int (np .prod (layer_size ))
7270
7371 # Remove unnecessary forward hook.
7472 hook .remove ()
@@ -77,11 +75,7 @@ def forward_hook(module, inp, out):
7775 # just the gradient of each hidden unit with respect to input.
7876 saved_grads = None
7977
80- # pyre-fixme[53]: Captured variable `layer_size` is not annotated.
81- # pyre-fixme[53]: Captured variable `layer_units` is not annotated.
82- # pyre-fixme[3]: Return type must be annotated.
83- # pyre-fixme[2]: Parameter must be annotated.
84- def backward_hook (grads ):
78+ def backward_hook (grads : Tensor ) -> Tensor :
8579 nonlocal saved_grads
8680 saved_grads = grads
8781 zero_mat = torch .zeros ((1 ,) + layer_size )
@@ -101,9 +95,9 @@ def backward_hook(grads):
10195 # tensor. Save backward hook in order to remove hook appropriately.
10296 back_hook = None
10397
104- # pyre-fixme[3]: Return type must be annotated.
105- # pyre-fixme[2]: Parameter must be annotated.
106- def forward_hook_register_back ( module , inp , out ) :
98+ def forward_hook_register_back (
99+ module : nn . Module , inp : Tensor , out : Tensor
100+ ) -> None :
107101 nonlocal back_hook
108102 back_hook = out .register_hook (backward_hook )
109103
@@ -132,11 +126,9 @@ def forward_hook_register_back(module, inp, out):
132126
133127 def attribute (
134128 self ,
135- # pyre-fixme[2]: Parameter must be annotated.
136- inputs ,
129+ inputs : Tensor ,
137130 baselines : Union [None , int , Tensor ] = None ,
138- # pyre-fixme[2]: Parameter must be annotated.
139- target = None ,
131+ target : Optional [Tensor ] = None ,
140132 n_steps : int = 500 ,
141133 method : str = "riemann_trapezoid" ,
142134 ) -> Tensor :
@@ -172,10 +164,6 @@ def attribute(
172164
173165 # compute scaled inputs from baseline to final input.
174166 scaled_features = torch .cat (
175- # pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
176- # typing.Tuple[Tensor, ...]]` but got `List[float]`.
177- # pyre-fixme[58]: `+` is not supported for operand types `Union[int,
178- # torch._tensor.Tensor]` and `float`.
179167 [baselines + alpha * (inputs - baselines ) for alpha in alphas ],
180168 dim = 0 ,
181169 )
@@ -214,6 +202,5 @@ def attribute(
214202 scaled_grads .view (mid_layer_gradients .shape ) * summed_input_grads ,
215203 n_steps ,
216204 inputs .shape [0 ],
217- # pyre-fixme[6]: For 4th argument expected `Tuple[int, ...]` but got `Size`.
218- mid_layer_gradients .shape [1 :],
205+ tuple (mid_layer_gradients .shape [1 :]),
219206 )
0 commit comments