Skip to content

Commit 54ccfcb

Browse files
jsawrukfacebook-github-bot
authored andcommitted
Annotate Loss Function type in Arnoldi Influence (meta-pytorch#1555)
Summary: Pull Request resolved: meta-pytorch#1555 title The `Callable` type for a loss function appears to be `[Tensor, Tensor] -> [Tensor]`. See https://fburl.com/code/ieym5p0c for invocation Reviewed By: cyrjano Differential Revision: D74016988 fbshipit-source-id: bca74ea9320300ce32c02a4e4c1ff35fa4ec6717
1 parent eb75534 commit 54ccfcb

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

captum/influence/_core/arnoldi_influence_function.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,12 @@ def __init__(
304304
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
305305
checkpoints_load_func: Callable = _load_flexible_state_dict,
306306
layers: Optional[List[str]] = None,
307-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
308-
loss_fn: Optional[Union[Module, Callable]] = None,
307+
loss_fn: Optional[Union[Module, Callable[[Tensor, Tensor], Tensor]]] = None,
309308
batch_size: Union[int, None] = 1,
310309
hessian_dataset: Optional[Union[Dataset, DataLoader]] = None,
311-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
312-
test_loss_fn: Optional[Union[Module, Callable]] = None,
310+
test_loss_fn: Optional[
311+
Union[Module, Callable[[Tensor, Tensor], Tensor]]
312+
] = None,
313313
sample_wise_grads_per_batch: bool = False,
314314
projection_dim: int = 50,
315315
seed: int = 0,
@@ -755,11 +755,12 @@ def compute_intermediate_quantities(
755755
return_device = torch.device("cpu") if return_on_cpu else self.model_device
756756

757757
# choose the correct loss function and reduction type based on `test`
758-
loss_fn = self.test_loss_fn if test else self.loss_fn
758+
loss_fn: Optional[Union[Module, Callable[[Tensor, Tensor], Tensor]]] = (
759+
self.test_loss_fn if test else self.loss_fn
760+
)
759761
reduction_type = self.test_reduction_type if test else self.reduction_type
760762

761763
# define a helper function that returns the embeddings for a batch
762-
# pyre-fixme[53]: Captured variable `loss_fn` is not annotated.
763764
# pyre-fixme[53]: Captured variable `reduction_type` is not annotated.
764765
# pyre-fixme[3]: Return type must be annotated.
765766
# pyre-fixme[2]: Parameter must be annotated.

0 commit comments

Comments
 (0)