From 8c3d8236e07f2cbdb0d17a6efbaeeefc573fd059 Mon Sep 17 00:00:00 2001 From: Marcel Mueller Date: Mon, 18 Nov 2024 09:59:31 +0100 Subject: [PATCH] `atlist` to `ati` and back conversion (#81) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * implement atlist to ati conversion Signed-off-by: Marcel Müller * add conversion test Signed-off-by: Marcel Müller --------- Signed-off-by: Marcel Müller --- src/mindlessgen/molecules/__init__.py | 4 ++- src/mindlessgen/molecules/molecule.py | 26 ++++++++++++++++++ test/test_molecules/test_molecule.py | 38 ++++++++++++++++++++++++++- 3 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/mindlessgen/molecules/__init__.py b/src/mindlessgen/molecules/__init__.py index fabdeb8..5787094 100644 --- a/src/mindlessgen/molecules/__init__.py +++ b/src/mindlessgen/molecules/__init__.py @@ -2,7 +2,7 @@ This module contains all molecule-related functionality. """ -from .molecule import Molecule, PSE_NUMBERS, PSE_SYMBOLS +from .molecule import Molecule, PSE_NUMBERS, PSE_SYMBOLS, ati_to_atlist, atlist_to_ati from .generate_molecule import ( generate_random_molecule, generate_coordinates, @@ -40,5 +40,7 @@ "get_alkaline_earth_metals", "PSE_NUMBERS", "PSE_SYMBOLS", + "ati_to_atlist", + "atlist_to_ati", "postprocess_mol", ] diff --git a/src/mindlessgen/molecules/molecule.py b/src/mindlessgen/molecules/molecule.py index 905a878..e497962 100644 --- a/src/mindlessgen/molecules/molecule.py +++ b/src/mindlessgen/molecules/molecule.py @@ -641,3 +641,29 @@ def set_name_from_formula(self) -> None: # add a random hash to the name hashname = hashlib.sha256(self.rng.bytes(32)).hexdigest()[:6] self.name = f"{molname}_{hashname}" + + +def ati_to_atlist(ati: np.ndarray) -> np.ndarray: + """ + Convert the atomic number per index to the array with the number of atoms of each element. + + :param ati: The atomic number per index. + :return: The array with the number of atoms of each element. + """ + atlist = np.zeros(103, dtype=int) + for atomtype in ati: + atlist[atomtype] += 1 + return atlist + + +def atlist_to_ati(atlist: np.ndarray) -> np.ndarray: + """ + Convert the array with the number of atoms of each element to the atomic number per index. + + :param atlist: The array with the number of atoms of each element. + :return: The atomic number per index. + """ + ati = np.array([], dtype=int) + for i, num in enumerate(atlist): + ati = np.append(ati, np.full(shape=num, fill_value=i)) + return ati diff --git a/test/test_molecules/test_molecule.py b/test/test_molecules/test_molecule.py index ca8b000..2bd1a4d 100644 --- a/test/test_molecules/test_molecule.py +++ b/test/test_molecules/test_molecule.py @@ -1,7 +1,7 @@ from pathlib import Path import numpy as np import pytest -from mindlessgen.molecules.molecule import Molecule # type: ignore +from mindlessgen.molecules.molecule import Molecule, ati_to_atlist, atlist_to_ati # type: ignore # load the molecule: C2H4N1O1Au1 @@ -249,3 +249,39 @@ def test_set_name_from_formula(): mol.set_name_from_formula() assert "C2H4_" in mol.name # Ensure the generated name includes the correct formula + + +def test_atlist_to_ati(): + """ + Test the atlist_to_ati function. + """ + atlist = np.zeros(103, dtype=int) + atlist[0] = 1 + atlist[4] = 2 + atlist[5] = 2 + atlist[6] = 1 + atlist[7] = 1 + atlist[43] = 2 + atlist[58] = 1 + + ati = atlist_to_ati(atlist) + np.testing.assert_array_equal(ati, np.array([0, 4, 4, 5, 5, 6, 7, 43, 43, 58])) + + +def test_ati_to_atlist(): + """ + Test the ati_to_atlist function. + """ + ati = np.array([0, 4, 4, 5, 5, 6, 7, 43, 43, 58]) + + atlist = ati_to_atlist(ati) + expected_atlist = np.zeros(103, dtype=int) + expected_atlist[0] = 1 + expected_atlist[4] = 2 + expected_atlist[5] = 2 + expected_atlist[6] = 1 + expected_atlist[7] = 1 + expected_atlist[43] = 2 + expected_atlist[58] = 1 + + np.testing.assert_array_equal(atlist, expected_atlist)