Skip to content

Commit 830de5b

Browse files
committed
avoid .item sync
1 parent ab6e300 commit 830de5b

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def _cal_hg_dynamic(
370370
# n_edge x e_dim
371371
flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
372372
# n_edge x 3 x e_dim
373-
flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape(
373+
flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(
374374
-1, 3 * e_dim
375375
)
376376
# nf x nloc x 3 x e_dim
@@ -745,12 +745,14 @@ def forward(
745745
nb, nloc, nnei = nlist.shape
746746
nall = node_ebd_ext.shape[1]
747747
node_ebd = node_ebd_ext[:, :nloc, :]
748-
n_edge = int(nlist_mask.sum().item())
749748
assert (nb, nloc) == node_ebd.shape[:2]
750749
if not self.use_dynamic_sel:
751750
assert (nb, nloc, nnei, 3) == h2.shape
751+
n_edge = None
752752
else:
753-
assert (n_edge, 3) == h2.shape
753+
# n_edge = int(nlist_mask.sum().item())
754+
# assert (n_edge, 3) == h2.shape
755+
n_edge = h2.shape[0]
754756
del a_nlist # may be used in the future
755757

756758
n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]

0 commit comments

Comments
 (0)