Skip to content

Commit f5025e6

Browse files
committed
add max_natoms_per_batch option to build_dataloader function
1 parent bcc3ee9 commit f5025e6

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

src/mattersim/datasets/utils/build.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,28 @@ def build_dataloader(
2323
num_workers: int = 0,
2424
pin_memory: bool = False,
2525
batch_converter: bool = True,
26+
max_natoms_per_batch: int = 4096,
2627
):
2728
"""
2829
Build a dataloader given a list of atoms
30+
31+
Args:
32+
atoms: list of Atoms objects
33+
energies: list of energies corresponding to the atoms
34+
forces: list of forces corresponding to the atoms
35+
stresses: list of stresses corresponding to the atoms
36+
cutoff: cutoff distance for graph construction
37+
threebody_cutoff: cutoff distance for three-body interactions
38+
batch_size: number of samples per batch
39+
model_type: type of model to use
40+
shuffle: whether to shuffle the data
41+
only_inference: whether to only perform inference
42+
num_workers: number of worker processes for data loading
43+
pin_memory: whether to pin memory
44+
batch_converter: whether to use batch converter
45+
max_natoms_per_batch: maximum number of atoms per batch in the batch converter; only used if batch_converter is True.
46+
Not to be confused with batch_size, which is the number of samples per batch in the final dataloader.
47+
In contrast, max_natoms_per_batch controls the number of atoms used when constructing the graph.
2948
"""
3049

3150
if not batch_converter:
@@ -76,7 +95,7 @@ def build_dataloader(
7695
energy=energies,
7796
forces=forces,
7897
stresses=stresses,
79-
max_natoms_per_batch=4096,
98+
max_natoms_per_batch=max_natoms_per_batch,
8099
)
81100
else:
82101
raise NotImplementedError(f"model type not supported: {model_type}")

0 commit comments

Comments
 (0)