Skip to content

Commit c510815

Browse files
fix code for view ops
1 parent 8fe0df6 commit c510815

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

deepmd/pd/model/descriptor/repflows.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,10 @@ def forward(
604604
_shapes = node_ebd.shape[1:]
605605
_shapes[1] = n_padding
606606
node_ebd = paddle.concat(
607-
[node_ebd, paddle.zeros(_shapes, dtype=node_ebd.dtype)],
607+
[
608+
node_ebd.squeeze(0),
609+
paddle.zeros(_shapes, dtype=node_ebd.dtype),
610+
],
608611
axis=1,
609612
)
610613
real_nloc = nloc
@@ -652,7 +655,7 @@ def forward(
652655
place=paddle.CPUPlace(),
653656
), # should be int of c++, placed on cpu
654657
)
655-
node_ebd_ext = ret[0].unsqueeze(0)
658+
node_ebd_ext = paddle.assign(ret).unsqueeze(0)
656659
if has_spin:
657660
node_ebd_real_ext, node_ebd_virtual_ext = paddle.split(
658661
node_ebd_ext, [n_dim, n_dim], axis=2

deepmd/pd/model/descriptor/repformers.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -536,11 +536,7 @@ def forward(
536536
place=paddle.CPUPlace(),
537537
), # should be int of c++, placed on cpu
538538
)
539-
# print(f"ret.shape = {ret.shape}")
540-
# print(f"ret[0].shape = ", ret[0].shape)
541-
g1_ext = ret.unsqueeze(0)
542-
# print(f"g1_ext.shape = ", g1_ext.shape)
543-
# exit()
539+
g1_ext = paddle.assign(ret).unsqueeze(0)
544540
if has_spin:
545541
g1_real_ext, g1_virtual_ext = paddle.split(
546542
g1_ext, [ng1, ng1], dim=2

0 commit comments

Comments
 (0)