Skip to content

Commit 9484568

Browse files
committed
removed the max_num_neighbors_threshold parameter, we will not use it anyway
1 parent 7db21ee commit 9484568

File tree

3 files changed

+8
-105
lines changed

3 files changed

+8
-105
lines changed

src/mattersim/datasets/utils/converter.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,14 +380,12 @@ def __init__(
380380
twobody_cutoff: float = 5.0,
381381
has_threebody: bool = True,
382382
threebody_cutoff: float = 4.0,
383-
max_num_neighbors_threshold: int = int(1e6),
384383
device: str | torch.device | None = None,
385384
):
386385
self.model_type = model_type
387386
self.twobody_cutoff = twobody_cutoff
388387
self.threebody_cutoff = threebody_cutoff
389388
self.has_threebody = has_threebody
390-
self.max_num_neighbors_threshold = max_num_neighbors_threshold
391389
if device is None:
392390
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
393391
elif isinstance(device, str):
@@ -544,7 +542,6 @@ def convert(
544542
cell=cell,
545543
natoms=natoms,
546544
radius=self.twobody_cutoff,
547-
max_num_neighbors_threshold=self.max_num_neighbors_threshold,
548545
)
549546
edge_indices = torch.cat(
550547
(edge_indices[1].unsqueeze(0), edge_indices[0].unsqueeze(0)), dim=0

src/mattersim/datasets/utils/ocp_graph_utils.py

Lines changed: 8 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
"""
2-
Code derived from the OCP codebase:
3-
https://github.com/Open-Catalyst-Project/ocp
4-
5-
Copyright (c) Facebook, Inc. and its affiliates.
6-
7-
This source code is licensed under the MIT license found in
8-
https://github.com/Open-Catalyst-Project/ocp/blob/main/LICENSE.md.
2+
Code modified from the OCP codebase: https://github.com/Open-Catalyst-Project/ocp
93
"""
104

115
import sys
@@ -21,7 +15,6 @@ def radius_graph_pbc(
2115
natoms: torch.Tensor,
2216
cell: torch.Tensor,
2317
radius: float,
24-
max_num_neighbors_threshold: int,
2518
max_cell_images_per_dim: int = sys.maxsize,
2619
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
2720
"""Function computing the graph in periodic boundary conditions on a (batched) set of
@@ -41,7 +34,6 @@ def radius_graph_pbc(
4134
cell (Tensor): atomic cell. Has shape
4235
:obj:`[n_structures, 3, 3]`
4336
radius (float): cutoff radius distance
44-
max_num_neighbors_threshold (int): Maximum number of neighbours to consider.
4537
4638
Returns:
4739
edge_index (IntTensor): index of atoms in edges. Has shape
@@ -194,22 +186,13 @@ def radius_graph_pbc(
194186
cell_offsets = cell_offsets.view(-1, 3)
195187
atom_distance_squared = torch.masked_select(atom_distance_squared, mask)
196188

197-
mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask(
198-
natoms=natoms,
199-
index=index1,
200-
atom_distance_squared=atom_distance_squared,
201-
max_num_neighbors_threshold=max_num_neighbors_threshold,
202-
)
203-
204-
if not torch.all(mask_num_neighbors):
205-
# Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
206-
index1 = torch.masked_select(index1, mask_num_neighbors)
207-
index2 = torch.masked_select(index2, mask_num_neighbors)
208-
atom_distance_squared = torch.masked_select(atom_distance_squared, mask_num_neighbors)
209-
cell_offsets = torch.masked_select(
210-
cell_offsets.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3)
211-
)
212-
cell_offsets = cell_offsets.view(-1, 3)
189+
# Compute num_neighbors_image (number of edges per structure in the batch)
190+
num_atoms = natoms.sum()
191+
ones = index1.new_ones(1).expand_as(index1)
192+
num_neighbors = segment_coo(ones, index1, dim_size=num_atoms)
193+
image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
194+
image_indptr[1:] = torch.cumsum(natoms, dim=0)
195+
num_neighbors_image = segment_csr(num_neighbors, image_indptr)
213196

214197
edge_index = torch.stack((index2, index1))
215198
# shifts = -torch.matmul(unit_cell, data.cell).view(-1, 3)
@@ -224,74 +207,3 @@ def radius_graph_pbc(
224207
torch.sqrt(atom_distance_squared),
225208
)
226209

227-
228-
def get_max_neighbors_mask(
229-
natoms: torch.Tensor,
230-
index: torch.Tensor,
231-
atom_distance_squared: torch.Tensor,
232-
max_num_neighbors_threshold: int,
233-
) -> tuple[torch.Tensor, torch.Tensor]:
234-
"""
235-
Give a mask that filters out edges so that each atom has at most
236-
`max_num_neighbors_threshold` neighbors.
237-
Assumes that `index` is sorted.
238-
"""
239-
device = natoms.device
240-
num_atoms = natoms.sum()
241-
242-
# Get number of neighbors
243-
# segment_coo assumes sorted index
244-
ones = index.new_ones(1).expand_as(index)
245-
num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
246-
max_num_neighbors = num_neighbors.max()
247-
num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold)
248-
249-
# Get number of (thresholded) neighbors per image
250-
image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
251-
image_indptr[1:] = torch.cumsum(natoms, dim=0)
252-
num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)
253-
254-
# If max_num_neighbors is below the threshold, return early
255-
if max_num_neighbors <= max_num_neighbors_threshold or max_num_neighbors_threshold <= 0:
256-
mask_num_neighbors = torch.tensor([True], dtype=torch.bool, device=device).expand_as(index)
257-
return mask_num_neighbors, num_neighbors_image
258-
259-
# Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors.
260-
# Fill with infinity so we can easily remove unused distances later.
261-
distance_sort = torch.full(
262-
[int((num_atoms * max_num_neighbors).long().item())], np.inf, device=device
263-
)
264-
265-
# Create an index map to map distances from atom_distance to distance_sort
266-
# index_sort_map assumes index to be sorted
267-
index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors
268-
index_neighbor_offset_expand = torch.repeat_interleave(index_neighbor_offset, num_neighbors)
269-
index_sort_map = (
270-
index * max_num_neighbors
271-
+ torch.arange(len(index), device=device)
272-
- index_neighbor_offset_expand
273-
)
274-
distance_sort.index_copy_(0, index_sort_map, atom_distance_squared)
275-
distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
276-
277-
# Sort neighboring atoms based on distance
278-
distance_sort, index_sort = torch.sort(distance_sort, dim=1)
279-
# Select the max_num_neighbors_threshold neighbors that are closest
280-
distance_sort = distance_sort[:, :max_num_neighbors_threshold]
281-
index_sort = index_sort[:, :max_num_neighbors_threshold]
282-
283-
# Offset index_sort so that it indexes into index
284-
index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand(
285-
-1, max_num_neighbors_threshold
286-
)
287-
# Remove "unused pairs" with infinite distances
288-
mask_finite = torch.isfinite(distance_sort)
289-
index_sort = torch.masked_select(index_sort, mask_finite)
290-
291-
# At this point index_sort contains the index into index of the
292-
# closest max_num_neighbors_threshold neighbors per atom
293-
# Create a mask to remove all pairs not in index_sort
294-
mask_num_neighbors = torch.zeros(len(index), device=device, dtype=torch.bool)
295-
mask_num_neighbors.index_fill_(0, index_sort, torch.tensor(True))
296-
297-
return mask_num_neighbors, num_neighbors_image

tests/datasets/test_batch_graph_converter.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def test_default_initialization(self):
5656
assert converter.twobody_cutoff == 5.0
5757
assert converter.threebody_cutoff == 4.0
5858
assert converter.has_threebody is True
59-
assert converter.max_num_neighbors_threshold == int(1e6)
6059

6160
def test_custom_cutoffs(self):
6261
"""Test initialization with custom cutoff values."""
@@ -88,11 +87,6 @@ def test_unsupported_model_type(self):
8887
with pytest.raises(NotImplementedError):
8988
BatchGraphConverter(model_type="unsupported")
9089

91-
def test_max_num_neighbors_threshold(self):
92-
"""Test custom max_num_neighbors_threshold."""
93-
converter = BatchGraphConverter(max_num_neighbors_threshold=500)
94-
assert converter.max_num_neighbors_threshold == 500
95-
9690

9791
# =============================================================================
9892
# Convert Method Tests

0 commit comments

Comments
 (0)