Skip to content

Commit 0153da8

Browse files
committed
fix bug where output concatenation in batch_apply fails with states with different numbers of atoms
1 parent 0dfbb35 commit 0153da8

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

psiflow/data/dataset.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -415,14 +415,33 @@ def _concatenate_multiple(*args: list[np.ndarray]) -> list[np.ndarray]:
415415
Note:
416416
This function is wrapped as a Parsl app and executed using the default_threads executor.
417417
"""
418+
def pad_arrays(
419+
arrays: list[np.ndarray],
420+
pad_dimension: int = 1,
421+
) -> list[np.ndarray]:
422+
ndims = np.array([len(a.shape) for a in arrays])
423+
assert np.all(ndims == ndims[0])
424+
assert np.all(pad_dimension < ndims)
425+
426+
pad_size = max([a.shape[pad_dimension] for a in arrays])
427+
for i in range(len(arrays)):
428+
shape = list(arrays[i].shape)
429+
shape[pad_dimension] = pad_size - shape[pad_dimension]
430+
padding = np.zeros(tuple(shape)) + np.nan
431+
arrays[i] = np.concatenate((arrays[i], padding), axis=pad_dimension)
432+
return arrays
433+
418434
narrays = len(args[0])
419435
for arg in args:
420436
assert isinstance(arg, list)
421437
assert all([len(a) == narrays for a in args])
422438

423439
concatenated = []
424440
for i in range(narrays):
425-
concatenated.append(np.concatenate([arg[i] for arg in args]))
441+
arrays = [arg[i] for arg in args]
442+
if len(arrays[0].shape) > 1:
443+
pad_arrays(arrays)
444+
concatenated.append(np.concatenate(tuple(arrays)))
426445
return concatenated
427446

428447

tests/test_function.py

+7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from parsl.data_provider.files import File # type: ignore
66

77
import psiflow
8+
from psiflow.data import Dataset
89
from psiflow.functions import (
910
EinsteinCrystalFunction,
1011
HarmonicFunction,
@@ -254,6 +255,12 @@ def test_hamiltonian_arithmetic(dataset):
254255
assert hamiltonian == hamiltonian + zero
255256
assert 2 * hamiltonian + zero == 2 * hamiltonian
256257

258+
geometries = [dataset[i].result() for i in [0, -1]]
259+
natoms = [len(geometry) for geometry in geometries]
260+
forces = zero.compute(geometries, 'forces', batch_size=1).result()
261+
assert np.all(forces[0, :natoms[0]] == 0.0)
262+
assert np.all(forces[-1, :natoms[1]] == 0.0)
263+
257264

258265
def test_subtract(dataset):
259266
einstein = EinsteinCrystal(dataset[0], force_constant=1.0)

0 commit comments

Comments
 (0)