There was an error while loading. Please reload this page.
1 parent 55c663f commit 263f91cCopy full SHA for 263f91c
distributed/FSDP/utils/train_utils.py
@@ -72,7 +72,7 @@ def validation(model, rank, world_size, val_loader):
72
model.eval()
73
correct = 0
74
local_rank = int(os.environ['LOCAL_RANK'])
75
- fsdp_loss = torch.zeros(3).to(local_rank)
+ fsdp_loss = torch.zeros(2).to(local_rank)
76
if rank == 0:
77
inner_pbar = tqdm.tqdm(
78
range(len(val_loader)), colour="green", desc="Validation Epoch"
0 commit comments