Commit e15b97b

fix(pt): fix zero inputs for LayerNorm
Fix #4064.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent: 96ed5df
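The commit guards the normalization against zero-element inputs: when the tensor is empty there is nothing to normalize, so the input is passed through unchanged instead of being fed to `torch.var_mean`. A minimal sketch of the same guard as a standalone function — the function name, signature, and `eps` default are illustrative, not part of the commit:

```python
import torch


def safe_layer_norm(xx: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize along the last dim, passing empty tensors through unchanged."""
    if xx.numel() > 0:
        # torch.var_mean keeps the module scriptable under TorchScript,
        # unlike separate .var()/.mean() calls (see pytorch/pytorch#85792).
        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
        return (xx - mean) / torch.sqrt(variance + eps)
    # Zero elements: nothing to normalize, so return the input as-is,
    # mirroring the `yy = xx` branch in the commit.
    return xx
```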

File tree

1 file changed (+5 −2 lines)


deepmd/pt/model/network/layernorm.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -96,8 +96,11 @@ def forward(
         # variance = xx.var(dim=-1, unbiased=False, keepdim=True)
         # The following operation is the same as above, but will not raise error when using jit model to inference.
         # See https://github.com/pytorch/pytorch/issues/85792
-        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
-        yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        if xx.numel() > 0:
+            variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
+            yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        else:
+            yy = xx
         if self.matrix is not None and self.bias is not None:
             yy = yy * self.matrix + self.bias
         return yy
```
