Skip to content

Commit

Permalink
fix XYZ reading bug
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Oct 30, 2024
1 parent e79ae8e commit 81b73a4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
19 changes: 16 additions & 3 deletions psiflow/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,27 @@ def from_string(cls, s: str, natoms: Optional[int] = None) -> Optional[Geometry]
comment_dict = key_val_str_to_dict_regex(comment)

# read and format per_atom data
column_indices = {}
if 'Properties' in comment_dict:
properties = comment_dict['Properties'].split(':')
count = 0
for i in range(len(properties) // 3):
name = properties[3 * i]
ncolumns = int(properties[3 * i + 2])
column_indices[name] = count
count += ncolumns
assert 'pos' in column_indices # positions need to be there

per_atom = np.recarray(natoms, dtype=per_atom_dtype)
per_atom.forces[:] = np.nan
POS_INDEX = column_indices.get('pos', 1)
FORCES_INDEX = column_indices.get('forces', None)
for i in range(natoms):
values = lines[i + 1].split()
per_atom.numbers[i] = chemical_symbols.index(values[0])
per_atom.positions[i, :] = [float(_) for _ in values[1:4]]
if len(values) > 4:
per_atom.forces[i, :] = [float(_) for _ in values[4:7]]
per_atom.positions[i, :] = [float(_) for _ in values[POS_INDEX:POS_INDEX + 3]]
if FORCES_INDEX is not None:
per_atom.forces[i, :] = [float(_) for _ in values[FORCES_INDEX:FORCES_INDEX + 3]]

order = {}
for key, value in comment_dict.items():
Expand Down
14 changes: 9 additions & 5 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,25 @@ def test_readwrite_cycle(dataset, tmp_path):
assert "test" in states[2].order

s = """3
energy=3.0 phase=c7eq
C 0 1 2
O 1 2 3
F 4 5 6
energy=3.0 phase=c7eq Properties=species:S:1:pos:R:3:momenta:R:3:forces:R:3
C 0 1 2 3 4 5 6 7 8
O 1 2 3 4 5 6 7 8 9
F 2 3 4 5 6 7 8 9 10
"""
geometry = Geometry.from_string(s, natoms=None)
assert len(geometry) == 3
assert geometry.energy == 3.0
assert geometry.phase == "c7eq"
assert not geometry.periodic
assert np.all(np.isnan(geometry.per_atom.forces))
assert np.all(np.logical_not(np.isnan(geometry.per_atom.forces)))
assert np.allclose(
geometry.per_atom.numbers,
np.array([6, 8, 9]),
)
assert np.allclose(
geometry.per_atom.forces,
np.array([[6,7,8], [7,8,9], [8,9,10]]),
)
s = """7
O 0.269073490000000 0.952731530000000 0.639899630000000
Expand Down

0 comments on commit 81b73a4

Please sign in to comment.