Skip to content

Commit 2a6fcaf

Browse files
committed
Fix save_reduce issue on local machine
No idea why this fixes it.
1 parent ed17172 commit 2a6fcaf

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

reptile/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from utils import ParamDict as P
1919

2020
Weights = P
21-
criterion = F.mse_loss
21+
criterion = F.l1_loss
2222

2323
CUDA_AVAILABLE = torch.cuda.is_available()
2424

@@ -93,7 +93,8 @@ def train_batch(x: Tensor, y: Tensor, model: Model, opt) -> None:
9393
"""Statefully train model on single batch."""
9494
x, y = cuda(Variable(x)), cuda(Variable(y))
9595

96-
loss = criterion(model(x), y)
96+
# TODO figure out why ray breaks if I just declare criterion at the top.
97+
loss = F.mse_loss(model(x), y)
9798

9899
opt.zero_grad()
99100
loss.backward()

0 commit comments

Comments
 (0)