@@ -23,9 +23,28 @@ def build_dataloader(
2323 num_workers : int = 0 ,
2424 pin_memory : bool = False ,
2525 batch_converter : bool = True ,
26+ max_natoms_per_batch : int = 4096 ,
2627):
2728 """
2829 Build a dataloader given a list of atoms
30+
31+ Args:
32+ atoms: list of Atoms objects
33+ energies: list of energies corresponding to the atoms
34+ forces: list of forces corresponding to the atoms
35+ stresses: list of stresses corresponding to the atoms
36+ cutoff: cutoff distance for graph construction
37+ threebody_cutoff: cutoff distance for three-body interactions
38+ batch_size: number of samples per batch
39+ model_type: type of model to use
40+ shuffle: whether to shuffle the data
41+ only_inference: whether to only perform inference
42+ num_workers: number of worker processes for data loading
43+ pin_memory: whether to pin memory
44+ batch_converter: whether to use batch converter
45+ max_natoms_per_batch: maximum number of atoms per batch in the batch converter; only used if batch_converter is True.
46+ Do not confuse this with batch_size, which is the number of samples per batch in the final dataloader.
47+ max_natoms_per_batch instead controls how many atoms are processed together when constructing the graphs.
2948 """
3049
3150 if not batch_converter :
@@ -76,7 +95,7 @@ def build_dataloader(
7695 energy = energies ,
7796 forces = forces ,
7897 stresses = stresses ,
79- max_natoms_per_batch = 4096 ,
98+ max_natoms_per_batch = max_natoms_per_batch ,
8099 )
81100 else :
82101 raise NotImplementedError (f"model type not supported: { model_type } " )
0 commit comments