
Commit ba9f02f

fix(pt): fix zero inputs for LayerNorm (#4134)
Fix #4064.

## Summary by CodeRabbit

- **Bug Fixes**
  - Improved robustness of layer normalization by handling empty input tensors, ensuring consistent output without errors.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 6976fb7 commit ba9f02f

File tree

1 file changed: +5 −2 lines changed


deepmd/pt/model/network/layernorm.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -96,8 +96,11 @@ def forward(
         # variance = xx.var(dim=-1, unbiased=False, keepdim=True)
         # The following operation is the same as above, but will not raise error when using jit model to inference.
         # See https://github.com/pytorch/pytorch/issues/85792
-        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
-        yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        if xx.numel() > 0:
+            variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
+            yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        else:
+            yy = xx
         if self.matrix is not None and self.bias is not None:
             yy = yy * self.matrix + self.bias
         return yy
```
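For reference, the guarded path can be exercised on its own. The sketch below is a minimal, standalone rendering of the patched normalization step, assuming a plain function `layer_norm_forward` (a hypothetical name, not from the repository) that omits the learnable `self.matrix` / `self.bias` affine step; the shapes in the usage lines are illustrative only.

```python
import torch


def layer_norm_forward(xx: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Minimal sketch of the patched normalization step (no learnable
    matrix/bias here; the real LayerNorm module applies them afterwards)."""
    if xx.numel() > 0:
        # Same computation as the patched code: var_mean is used instead of
        # Tensor.var to avoid a JIT inference issue (pytorch/pytorch#85792).
        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
        yy = (xx - mean) / torch.sqrt(variance + eps)
    else:
        # Empty input: skip the statistics and return it unchanged,
        # preserving shape and dtype.
        yy = xx
    return yy


# An empty batch (e.g. zero frames/atoms) now passes through untouched.
empty = torch.zeros((0, 64))
print(layer_norm_forward(empty).shape)  # torch.Size([0, 64])

# Non-empty inputs follow the usual normalization path.
x = torch.randn(2, 64)
print(layer_norm_forward(x).shape)      # torch.Size([2, 64])
```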
