There was an error while loading. Please reload this page.
1 parent 78c3235 commit e3173f8Copy full SHA for e3173f8
common/log_weights.py
@@ -74,7 +74,7 @@ def log_ebc_norms(
74
: min(sample_size, emb_weight_tensor.shape[0])
75
]
76
# WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks
77
- # Change sample_size if the you observe frequent OOM errors or remove weight logging.
+ # Change sample_size if the user observe frequent OOM errors or remove weight logging.
78
norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32)
79
logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1)
80
norms = norms.mean().to(torch.device(f"cuda:{dist.get_rank()}"))
0 commit comments