File tree 2 files changed +27
-1
lines changed
2 files changed +27
-1
lines changed Original file line number Diff line number Diff line change @@ -415,14 +415,33 @@ def _concatenate_multiple(*args: list[np.ndarray]) -> list[np.ndarray]:
415
415
Note:
416
416
This function is wrapped as a Parsl app and executed using the default_threads executor.
417
417
"""
418
+ def pad_arrays (
419
+ arrays : list [np .ndarray ],
420
+ pad_dimension : int = 1 ,
421
+ ) -> list [np .ndarray ]:
422
+ ndims = np .array ([len (a .shape ) for a in arrays ])
423
+ assert np .all (ndims == ndims [0 ])
424
+ assert np .all (pad_dimension < ndims )
425
+
426
+ pad_size = max ([a .shape [pad_dimension ] for a in arrays ])
427
+ for i in range (len (arrays )):
428
+ shape = list (arrays [i ].shape )
429
+ shape [pad_dimension ] = pad_size - shape [pad_dimension ]
430
+ padding = np .zeros (tuple (shape )) + np .nan
431
+ arrays [i ] = np .concatenate ((arrays [i ], padding ), axis = pad_dimension )
432
+ return arrays
433
+
418
434
narrays = len (args [0 ])
419
435
for arg in args :
420
436
assert isinstance (arg , list )
421
437
assert all ([len (a ) == narrays for a in args ])
422
438
423
439
concatenated = []
424
440
for i in range (narrays ):
425
- concatenated .append (np .concatenate ([arg [i ] for arg in args ]))
441
+ arrays = [arg [i ] for arg in args ]
442
+ if len (arrays [0 ].shape ) > 1 :
443
+ pad_arrays (arrays )
444
+ concatenated .append (np .concatenate (tuple (arrays )))
426
445
return concatenated
427
446
428
447
Original file line number Diff line number Diff line change 5
5
from parsl .data_provider .files import File # type: ignore
6
6
7
7
import psiflow
8
+ from psiflow .data import Dataset
8
9
from psiflow .functions import (
9
10
EinsteinCrystalFunction ,
10
11
HarmonicFunction ,
@@ -254,6 +255,12 @@ def test_hamiltonian_arithmetic(dataset):
254
255
assert hamiltonian == hamiltonian + zero
255
256
assert 2 * hamiltonian + zero == 2 * hamiltonian
256
257
258
+ geometries = [dataset [i ].result () for i in [0 , - 1 ]]
259
+ natoms = [len (geometry ) for geometry in geometries ]
260
+ forces = zero .compute (geometries , 'forces' , batch_size = 1 ).result ()
261
+ assert np .all (forces [0 , :natoms [0 ]] == 0.0 )
262
+ assert np .all (forces [- 1 , :natoms [1 ]] == 0.0 )
263
+
257
264
258
265
def test_subtract (dataset ):
259
266
einstein = EinsteinCrystal (dataset [0 ], force_constant = 1.0 )
You can’t perform that action at this time.
0 commit comments