Skip to content

Commit

Permalink
atlist to ati and back conversion (#81)
Browse files Browse the repository at this point in the history
* implement atlist to ati conversion

Signed-off-by: Marcel Müller <[email protected]>

* add conversion test

Signed-off-by: Marcel Müller <[email protected]>

---------

Signed-off-by: Marcel Müller <[email protected]>
  • Loading branch information
marcelmbn authored Nov 18, 2024
1 parent 82aa3f5 commit 8c3d823
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/mindlessgen/molecules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
This module contains all molecule-related functionality.
"""

from .molecule import Molecule, PSE_NUMBERS, PSE_SYMBOLS
from .molecule import Molecule, PSE_NUMBERS, PSE_SYMBOLS, ati_to_atlist, atlist_to_ati
from .generate_molecule import (
generate_random_molecule,
generate_coordinates,
Expand Down Expand Up @@ -40,5 +40,7 @@
"get_alkaline_earth_metals",
"PSE_NUMBERS",
"PSE_SYMBOLS",
"ati_to_atlist",
"atlist_to_ati",
"postprocess_mol",
]
26 changes: 26 additions & 0 deletions src/mindlessgen/molecules/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,3 +641,29 @@ def set_name_from_formula(self) -> None:
# add a random hash to the name
hashname = hashlib.sha256(self.rng.bytes(32)).hexdigest()[:6]
self.name = f"{molname}_{hashname}"


def ati_to_atlist(ati: np.ndarray) -> np.ndarray:
"""
Convert the atomic number per index to the array with the number of atoms of each element.
:param ati: The atomic number per index.
:return: The array with the number of atoms of each element.
"""
atlist = np.zeros(103, dtype=int)
for atomtype in ati:
atlist[atomtype] += 1
return atlist


def atlist_to_ati(atlist: np.ndarray) -> np.ndarray:
"""
Convert the array with the number of atoms of each element to the atomic number per index.
:param atlist: The array with the number of atoms of each element.
:return: The atomic number per index.
"""
ati = np.array([], dtype=int)
for i, num in enumerate(atlist):
ati = np.append(ati, np.full(shape=num, fill_value=i))
return ati
38 changes: 37 additions & 1 deletion test/test_molecules/test_molecule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
import numpy as np
import pytest
from mindlessgen.molecules.molecule import Molecule # type: ignore
from mindlessgen.molecules.molecule import Molecule, ati_to_atlist, atlist_to_ati # type: ignore


# load the molecule: C2H4N1O1Au1
Expand Down Expand Up @@ -249,3 +249,39 @@ def test_set_name_from_formula():

mol.set_name_from_formula()
assert "C2H4_" in mol.name # Ensure the generated name includes the correct formula


def test_atlist_to_ati():
"""
Test the atlist_to_ati function.
"""
atlist = np.zeros(103, dtype=int)
atlist[0] = 1
atlist[4] = 2
atlist[5] = 2
atlist[6] = 1
atlist[7] = 1
atlist[43] = 2
atlist[58] = 1

ati = atlist_to_ati(atlist)
np.testing.assert_array_equal(ati, np.array([0, 4, 4, 5, 5, 6, 7, 43, 43, 58]))


def test_ati_to_atlist():
"""
Test the ati_to_atlist function.
"""
ati = np.array([0, 4, 4, 5, 5, 6, 7, 43, 43, 58])

atlist = ati_to_atlist(ati)
expected_atlist = np.zeros(103, dtype=int)
expected_atlist[0] = 1
expected_atlist[4] = 2
expected_atlist[5] = 2
expected_atlist[6] = 1
expected_atlist[7] = 1
expected_atlist[43] = 2
expected_atlist[58] = 1

np.testing.assert_array_equal(atlist, expected_atlist)

0 comments on commit 8c3d823

Please sign in to comment.