11"""
2- Code derived from the OCP codebase:
3- https://github.com/Open-Catalyst-Project/ocp
4-
5- Copyright (c) Facebook, Inc. and its affiliates.
6-
7- This source code is licensed under the MIT license found in
8- https://github.com/Open-Catalyst-Project/ocp/blob/main/LICENSE.md.
2+ Code modified from the OCP codebase: https://github.com/Open-Catalyst-Project/ocp
93"""
104
115import sys
@@ -21,7 +15,6 @@ def radius_graph_pbc(
2115 natoms : torch .Tensor ,
2216 cell : torch .Tensor ,
2317 radius : float ,
24- max_num_neighbors_threshold : int ,
2518 max_cell_images_per_dim : int = sys .maxsize ,
2619) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
2720 """Function computing the graph in periodic boundary conditions on a (batched) set of
@@ -41,7 +34,6 @@ def radius_graph_pbc(
4134 cell (Tensor): atomic cell. Has shape
4235 :obj:`[n_structures, 3, 3]`
4336 radius (float): cutoff radius distance
44- max_num_neighbors_threshold (int): Maximum number of neighbours to consider.
4537
4638 Returns:
4739 edge_index (IntTensor): index of atoms in edges. Has shape
@@ -194,22 +186,13 @@ def radius_graph_pbc(
194186 cell_offsets = cell_offsets .view (- 1 , 3 )
195187 atom_distance_squared = torch .masked_select (atom_distance_squared , mask )
196188
197- mask_num_neighbors , num_neighbors_image = get_max_neighbors_mask (
198- natoms = natoms ,
199- index = index1 ,
200- atom_distance_squared = atom_distance_squared ,
201- max_num_neighbors_threshold = max_num_neighbors_threshold ,
202- )
203-
204- if not torch .all (mask_num_neighbors ):
205- # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
206- index1 = torch .masked_select (index1 , mask_num_neighbors )
207- index2 = torch .masked_select (index2 , mask_num_neighbors )
208- atom_distance_squared = torch .masked_select (atom_distance_squared , mask_num_neighbors )
209- cell_offsets = torch .masked_select (
210- cell_offsets .view (- 1 , 3 ), mask_num_neighbors .view (- 1 , 1 ).expand (- 1 , 3 )
211- )
212- cell_offsets = cell_offsets .view (- 1 , 3 )
189+ # Compute num_neighbors_image (number of edges per structure in the batch)
190+ num_atoms = natoms .sum ()
191+ ones = index1 .new_ones (1 ).expand_as (index1 )
192+ num_neighbors = segment_coo (ones , index1 , dim_size = num_atoms )
193+ image_indptr = torch .zeros (natoms .shape [0 ] + 1 , device = device , dtype = torch .long )
194+ image_indptr [1 :] = torch .cumsum (natoms , dim = 0 )
195+ num_neighbors_image = segment_csr (num_neighbors , image_indptr )
213196
214197 edge_index = torch .stack ((index2 , index1 ))
215198 # shifts = -torch.matmul(unit_cell, data.cell).view(-1, 3)
@@ -224,74 +207,3 @@ def radius_graph_pbc(
224207 torch .sqrt (atom_distance_squared ),
225208 )
226209
227-
228- def get_max_neighbors_mask (
229- natoms : torch .Tensor ,
230- index : torch .Tensor ,
231- atom_distance_squared : torch .Tensor ,
232- max_num_neighbors_threshold : int ,
233- ) -> tuple [torch .Tensor , torch .Tensor ]:
234- """
235- Give a mask that filters out edges so that each atom has at most
236- `max_num_neighbors_threshold` neighbors.
237- Assumes that `index` is sorted.
238- """
239- device = natoms .device
240- num_atoms = natoms .sum ()
241-
242- # Get number of neighbors
243- # segment_coo assumes sorted index
244- ones = index .new_ones (1 ).expand_as (index )
245- num_neighbors = segment_coo (ones , index , dim_size = num_atoms )
246- max_num_neighbors = num_neighbors .max ()
247- num_neighbors_thresholded = num_neighbors .clamp (max = max_num_neighbors_threshold )
248-
249- # Get number of (thresholded) neighbors per image
250- image_indptr = torch .zeros (natoms .shape [0 ] + 1 , device = device , dtype = torch .long )
251- image_indptr [1 :] = torch .cumsum (natoms , dim = 0 )
252- num_neighbors_image = segment_csr (num_neighbors_thresholded , image_indptr )
253-
254- # If max_num_neighbors is below the threshold, return early
255- if max_num_neighbors <= max_num_neighbors_threshold or max_num_neighbors_threshold <= 0 :
256- mask_num_neighbors = torch .tensor ([True ], dtype = torch .bool , device = device ).expand_as (index )
257- return mask_num_neighbors , num_neighbors_image
258-
259- # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors.
260- # Fill with infinity so we can easily remove unused distances later.
261- distance_sort = torch .full (
262- [int ((num_atoms * max_num_neighbors ).long ().item ())], np .inf , device = device
263- )
264-
265- # Create an index map to map distances from atom_distance to distance_sort
266- # index_sort_map assumes index to be sorted
267- index_neighbor_offset = torch .cumsum (num_neighbors , dim = 0 ) - num_neighbors
268- index_neighbor_offset_expand = torch .repeat_interleave (index_neighbor_offset , num_neighbors )
269- index_sort_map = (
270- index * max_num_neighbors
271- + torch .arange (len (index ), device = device )
272- - index_neighbor_offset_expand
273- )
274- distance_sort .index_copy_ (0 , index_sort_map , atom_distance_squared )
275- distance_sort = distance_sort .view (num_atoms , max_num_neighbors )
276-
277- # Sort neighboring atoms based on distance
278- distance_sort , index_sort = torch .sort (distance_sort , dim = 1 )
279- # Select the max_num_neighbors_threshold neighbors that are closest
280- distance_sort = distance_sort [:, :max_num_neighbors_threshold ]
281- index_sort = index_sort [:, :max_num_neighbors_threshold ]
282-
283- # Offset index_sort so that it indexes into index
284- index_sort = index_sort + index_neighbor_offset .view (- 1 , 1 ).expand (
285- - 1 , max_num_neighbors_threshold
286- )
287- # Remove "unused pairs" with infinite distances
288- mask_finite = torch .isfinite (distance_sort )
289- index_sort = torch .masked_select (index_sort , mask_finite )
290-
291- # At this point index_sort contains the index into index of the
292- # closest max_num_neighbors_threshold neighbors per atom
293- # Create a mask to remove all pairs not in index_sort
294- mask_num_neighbors = torch .zeros (len (index ), device = device , dtype = torch .bool )
295- mask_num_neighbors .index_fill_ (0 , index_sort , torch .tensor (True ))
296-
297- return mask_num_neighbors , num_neighbors_image
0 commit comments