Skip to content

Commit ab6e300

Browse files
caic99pre-commit-ci[bot]Copilot
authored
perf: use torch.topk to construct nlist (#4751)
This pull request updates the `build_neighbor_list` function in `deepmd/pt/utils/nlist.py` to improve its performance and correctness when selecting neighbors. The key change involves replacing `torch.sort` with `torch.topk` to optimize the selection of the nearest neighbors. <details><summary>Details</summary> <p> Before: 16.4ms <img width="495" alt="image" src="https://github.com/user-attachments/assets/e6b0c091-b11b-491a-b2fd-e0aadc6c35ba" /> After: 3.4ms <img width="303" alt="image" src="https://github.com/user-attachments/assets/81931c76-60eb-46a5-9934-0d837b954df9" /> Step time goes from 212.5ms to 200.5ms, ~5% speed-up. </p> </details> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Enhanced efficiency and clarity in neighbor selection, improving performance in identifying closest neighbors while maintaining the same user experience. - **Tests** - Improved reliability of neighbor list tests by validating neighbor sets regardless of order. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chun Cai <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <[email protected]>
1 parent 5b1bbc2 commit ab6e300

File tree

2 files changed

+31
-17
lines changed

2 files changed

+31
-17
lines changed

deepmd/pt/utils/nlist.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -93,33 +93,43 @@ def build_neighbor_list(
9393
9494
"""
9595
batch_size = coord.shape[0]
96-
coord = coord.view(batch_size, -1)
9796
nall = coord.shape[1] // 3
9897
# fill virtual atoms with large coords so they are not neighbors of any
9998
# real atom.
10099
if coord.numel() > 0:
101100
xmax = torch.max(coord) + 2.0 * rcut
102101
else:
103102
xmax = torch.zeros(1, dtype=coord.dtype, device=coord.device) + 2.0 * rcut
103+
104+
coord_xyz = coord.view(batch_size, nall, 3)
104105
# nf x nall
105106
is_vir = atype < 0
106-
coord1 = torch.where(
107-
is_vir[:, :, None], xmax, coord.view(batch_size, nall, 3)
108-
).view(batch_size, nall * 3)
107+
# batch_size x nall x 3
108+
vcoord_xyz = torch.where(is_vir[:, :, None], xmax, coord_xyz)
109109
if isinstance(sel, int):
110110
sel = [sel]
111-
# nloc x 3
112-
coord0 = coord1[:, : nloc * 3]
113-
# nloc x nall x 3
114-
diff = coord1.view([batch_size, -1, 3]).unsqueeze(1) - coord0.view(
115-
[batch_size, -1, 3]
116-
).unsqueeze(2)
117-
assert list(diff.shape) == [batch_size, nloc, nall, 3]
111+
112+
# Get the coordinates for the local atoms (first nloc atoms)
113+
# batch_size x nloc x 3
114+
vcoord_local_xyz = vcoord_xyz[:, :nloc, :]
115+
116+
# Calculate displacement vectors.
117+
diff = vcoord_xyz.unsqueeze(1) - vcoord_local_xyz.unsqueeze(2)
118+
assert diff.shape == (batch_size, nloc, nall, 3)
118119
# nloc x nall
119120
rr = torch.linalg.norm(diff, dim=-1)
120121
# if central atom has two zero distances, sorting sometimes can not exclude itself
121-
rr -= torch.eye(nloc, nall, dtype=rr.dtype, device=rr.device).unsqueeze(0)
122-
rr, nlist = torch.sort(rr, dim=-1)
122+
# The following operation makes rr[b, i, i] = -1.0 (assuming original self-distance is 0)
123+
# so that self-atom is sorted first.
124+
diag_len = min(nloc, nall)
125+
idx = torch.arange(diag_len, device=rr.device, dtype=torch.int)
126+
rr[:, idx, idx] -= 1.0
127+
128+
nsel = sum(sel)
129+
nnei = rr.shape[-1]
130+
top_k = min(nsel + 1, nnei)
131+
rr, nlist = torch.topk(rr, top_k, largest=False)
132+
123133
# nloc x (nall-1)
124134
rr = rr[:, :, 1:]
125135
nlist = nlist[:, :, 1:]

source/tests/pt/model/test_nlist.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,17 @@ def test_build_multiple_nlist(self) -> None:
152152
nlists[get_multiple_nlist_key(rcuts[dd], nsels[dd])].shape[-1],
153153
nsels[dd],
154154
)
155+
156+
# Since the nlist is created using unstable sort,
157+
# we check if the set of indices in the nlist matches,
158+
# regardless of the order
155159
torch.testing.assert_close(
156-
nlists[get_multiple_nlist_key(rcuts[0], nsels[0])],
157-
nlist0,
160+
nlists[get_multiple_nlist_key(rcuts[0], nsels[0])].sort(dim=-1).values,
161+
nlist0.sort(dim=-1).values,
158162
)
159163
torch.testing.assert_close(
160-
nlists[get_multiple_nlist_key(rcuts[1], nsels[1])],
161-
nlist2,
164+
nlists[get_multiple_nlist_key(rcuts[1], nsels[1])].sort(dim=-1).values,
165+
nlist2.sort(dim=-1).values,
162166
)
163167

164168
def test_extend_coord(self) -> None:

0 commit comments

Comments
 (0)