Skip to content

Commit 0234182

Browse files
Improve Keep Redundant Spaces algorithm for PatchedPhaseDiagram (materialsproject#3900)
* fix: old algorithm to deduplicate spaces didn't find the minimum subset * test: direct test for remove_redundant_spaces static method * doc: clean up old comments, add details explaining why patchedphasediagram as_dict doesn't save computations due to shared memory id issue. * pre-commit auto-fixes * lint: spelling --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 454aa5e commit 0234182

File tree

2 files changed

+84
-38
lines changed

2 files changed

+84
-38
lines changed

src/pymatgen/analysis/phase_diagram.py

+61-35
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,6 @@ def get_decomp_and_phase_separation_energy(
903903
}
904904

905905
# NOTE calling PhaseDiagram is only reasonable if the composition has fewer than 5 elements
906-
# TODO can we call PatchedPhaseDiagram here?
907906
inner_hull = PhaseDiagram(reduced_space)
908907

909908
competing_entries = inner_hull.stable_entries | {*self._get_stable_entries_in_space(entry_elems)}
@@ -1036,8 +1035,6 @@ def get_critical_compositions(self, comp1, comp2):
10361035
if np.all(c1 == c2):
10371036
return [comp1.copy(), comp2.copy()]
10381037

1039-
# NOTE made into method to facilitate inheritance of this method
1040-
# in PatchedPhaseDiagram if approximate solution can be found.
10411038
intersections = self._get_simplex_intersections(c1, c2)
10421039

10431040
# find position along line
@@ -1619,29 +1616,21 @@ def __init__(
16191616
# Add the elemental references
16201617
inds.extend([min_entries.index(el) for el in el_refs.values()])
16211618

1622-
self.qhull_entries = tuple(min_entries[idx] for idx in inds)
1619+
qhull_entries = tuple(min_entries[idx] for idx in inds)
16231620
# make qhull spaces frozensets since they become keys to self.pds dict and frozensets are hashable
16241621
# prevent repeating elements in chemical space and avoid the ordering problem (i.e. Fe-O == O-Fe automatically)
1625-
self._qhull_spaces = tuple(frozenset(entry.elements) for entry in self.qhull_entries)
1622+
qhull_spaces = tuple(frozenset(entry.elements) for entry in qhull_entries)
16261623

16271624
# Get all unique chemical spaces
1628-
spaces = {s for s in self._qhull_spaces if len(s) > 1}
1625+
spaces = {s for s in qhull_spaces if len(s) > 1}
16291626

16301627
# Remove redundant chemical spaces
1631-
if not keep_all_spaces and len(spaces) > 1:
1632-
max_size = max(len(s) for s in spaces)
1633-
1634-
systems = set()
1635-
# NOTE reduce the number of comparisons by only comparing to larger sets
1636-
for idx in range(2, max_size + 1):
1637-
test = (s for s in spaces if len(s) == idx)
1638-
refer = (s for s in spaces if len(s) > idx)
1639-
systems |= {t for t in test if not any(t.issubset(r) for r in refer)}
1640-
1641-
spaces = systems
1628+
spaces = self.remove_redundant_spaces(spaces, keep_all_spaces)
16421629

16431630
# TODO comprhys: refactor to have self._compute method to allow serialization
1644-
self.spaces = sorted(spaces, key=len, reverse=False) # Calculate pds for smaller dimension spaces first
1631+
self.spaces = sorted(spaces, key=len, reverse=True) # Calculate pds for smaller dimension spaces last
1632+
self.qhull_entries = qhull_entries
1633+
self._qhull_spaces = qhull_spaces
16451634
self.pds = dict(self._get_pd_patch_for_space(s) for s in tqdm(self.spaces, disable=not verbose))
16461635
self.all_entries = all_entries
16471636
self.el_refs = el_refs
@@ -1675,7 +1664,19 @@ def __contains__(self, item: frozenset[Element]) -> bool:
16751664
return item in self.pds
16761665

16771666
def as_dict(self) -> dict[str, Any]:
1678-
"""
1667+
"""Write the entries and elements used to construct the PatchedPhaseDiagram
1668+
to a dictionary.
1669+
1670+
NOTE unlike PhaseDiagram the computation involved in constructing the
1671+
PatchedPhaseDiagram is not saved on serialisation. This is done because
1672+
hierarchically calling the `PhaseDiagram.as_dict()` method would break the
1673+
link in memory between entries in overlapping patches leading to a
1674+
ballooning of the amount of memory used.
1675+
1676+
NOTE For memory efficiency the best way to store patched phase diagrams is
1677+
via pickling. As this allows all the entries in overlapping patches to share
1678+
the same id in memory when unpickling.
1679+
16791680
Returns:
16801681
dict[str, Any]: MSONable dictionary representation of PatchedPhaseDiagram.
16811682
"""
@@ -1688,7 +1689,18 @@ def as_dict(self) -> dict[str, Any]:
16881689

16891690
@classmethod
16901691
def from_dict(cls, dct: dict) -> Self:
1691-
"""
1692+
"""Reconstruct PatchedPhaseDiagram from dictionary serialisation.
1693+
1694+
NOTE unlike PhaseDiagram the computation involved in constructing the
1695+
PatchedPhaseDiagram is not saved on serialisation. This is done because
1696+
hierarchically calling the `PhaseDiagram.as_dict()` method would break the
1697+
link in memory between entries in overlapping patches leading to a
1698+
ballooning of the amount of memory used.
1699+
1700+
NOTE For memory efficiency the best way to store patched phase diagrams is
1701+
via pickling. As this allows all the entries in overlapping patches to share
1702+
the same id in memory when unpickling.
1703+
16921704
Args:
16931705
dct (dict): dictionary representation of PatchedPhaseDiagram.
16941706
@@ -1699,9 +1711,23 @@ def from_dict(cls, dct: dict) -> Self:
16991711
elements = [Element.from_dict(elem) for elem in dct["elements"]]
17001712
return cls(entries, elements)
17011713

1714+
@staticmethod
1715+
def remove_redundant_spaces(spaces, keep_all_spaces=False):
1716+
if keep_all_spaces or len(spaces) <= 1:
1717+
return spaces
1718+
1719+
# Sort spaces by size in descending order and pre-compute lengths
1720+
sorted_spaces = sorted(spaces, key=len, reverse=True)
1721+
1722+
result = []
1723+
for i, space_i in enumerate(sorted_spaces):
1724+
if not any(space_i.issubset(larger_space) for larger_space in sorted_spaces[:i]):
1725+
result.append(space_i)
1726+
1727+
return result
1728+
17021729
# NOTE following methods are inherited unchanged from PhaseDiagram:
17031730
# __repr__,
1704-
# as_dict,
17051731
# all_entries_hulldata,
17061732
# unstable_entries,
17071733
# stable_entries,
@@ -1771,8 +1797,6 @@ def get_equilibrium_reaction_energy(self, entry: Entry) -> float:
17711797
"""
17721798
return self.get_phase_separation_energy(entry, stable_only=True)
17731799

1774-
# NOTE the following functions are not implemented for PatchedPhaseDiagram
1775-
17761800
def get_decomp_and_e_above_hull(
17771801
self,
17781802
entry: PDEntry,
@@ -1787,6 +1811,20 @@ def get_decomp_and_e_above_hull(
17871811
entry=entry, allow_negative=allow_negative, check_stable=check_stable, on_error=on_error
17881812
)
17891813

1814+
def _get_pd_patch_for_space(self, space: frozenset[Element]) -> tuple[frozenset[Element], PhaseDiagram]:
1815+
"""
1816+
Args:
1817+
space (frozenset[Element]): chemical space of the form A-B-X.
1818+
1819+
Returns:
1820+
space, PhaseDiagram for the given chemical space
1821+
"""
1822+
space_entries = [e for e, s in zip(self.qhull_entries, self._qhull_spaces) if space.issuperset(s)]
1823+
1824+
return space, PhaseDiagram(space_entries)
1825+
1826+
# NOTE the following functions are not implemented for PatchedPhaseDiagram
1827+
17901828
def _get_facet_and_simplex(self):
17911829
"""Not Implemented - See PhaseDiagram."""
17921830
raise NotImplementedError("_get_facet_and_simplex() not implemented for PatchedPhaseDiagram")
@@ -1835,18 +1873,6 @@ def get_chempot_range_stability_phase(self):
18351873
"""Not Implemented - See PhaseDiagram."""
18361874
raise NotImplementedError("get_chempot_range_stability_phase() not implemented for PatchedPhaseDiagram")
18371875

1838-
def _get_pd_patch_for_space(self, space: frozenset[Element]) -> tuple[frozenset[Element], PhaseDiagram]:
1839-
"""
1840-
Args:
1841-
space (frozenset[Element]): chemical space of the form A-B-X.
1842-
1843-
Returns:
1844-
space, PhaseDiagram for the given chemical space
1845-
"""
1846-
space_entries = [e for e, s in zip(self.qhull_entries, self._qhull_spaces) if space.issuperset(s)]
1847-
1848-
return space, PhaseDiagram(space_entries)
1849-
18501876

18511877
class ReactionDiagram:
18521878
"""

tests/analysis/test_phase_diagram.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import collections
44
import unittest
55
import unittest.mock
6+
from itertools import combinations
67
from numbers import Number
78
from unittest import TestCase
89

@@ -709,6 +710,7 @@ def setUp(self):
709710

710711
self.pd = PhaseDiagram(entries=self.entries)
711712
self.ppd = PatchedPhaseDiagram(entries=self.entries)
713+
self.ppd_all = PatchedPhaseDiagram(entries=self.entries, keep_all_spaces=True)
712714

713715
# novel entries not in any of the patches
714716
self.novel_comps = [Composition("H5C2OP"), Composition("V2PH4C")]
@@ -756,7 +758,11 @@ def test_dimensionality(self):
756758

757759
# test dims of sub PDs
758760
dim_counts = collections.Counter(pd.dim for pd in self.ppd.pds.values())
759-
assert dim_counts == {3: 7, 2: 6, 4: 2}
761+
assert dim_counts == {4: 2, 3: 2}
762+
763+
# test dims of sub PDs
764+
dim_counts = collections.Counter(pd.dim for pd in self.ppd_all.pds.values())
765+
assert dim_counts == {2: 8, 3: 7, 4: 2}
760766

761767
def test_get_hull_energy(self):
762768
for comp in self.novel_comps:
@@ -772,7 +778,7 @@ def test_get_decomp_and_e_above_hull(self):
772778
assert np.isclose(e_above_hull_pd, e_above_hull_ppd)
773779

774780
def test_repr(self):
775-
assert repr(self.ppd) == str(self.ppd) == "PatchedPhaseDiagram covering 15 sub-spaces"
781+
assert repr(self.ppd) == str(self.ppd) == "PatchedPhaseDiagram covering 4 sub-spaces"
776782

777783
def test_as_from_dict(self):
778784
ppd_dict = self.ppd.as_dict()
@@ -810,7 +816,8 @@ def test_getitem(self):
810816
pd = self.ppd[chem_space]
811817
assert isinstance(pd, PhaseDiagram)
812818
assert chem_space in pd._qhull_spaces
813-
assert str(pd) == "V-C phase diagram\n4 stable phases: \nC, V, V6C5, V2C"
819+
assert len(str(pd)) == 186
820+
assert str(pd).startswith("V-H-C-O phase diagram\n25 stable phases:")
814821

815822
with pytest.raises(KeyError, match="frozenset"):
816823
self.ppd[frozenset(map(Element, "HBCNOFPS"))]
@@ -830,6 +837,19 @@ def test_setitem_and_delitem(self):
830837
assert self.ppd[unlikely_chem_space] == self.pd
831838
del self.ppd[unlikely_chem_space] # test __delitem__() and restore original state
832839

840+
def test_remove_redundant_spaces(self):
841+
spaces = tuple(frozenset(entry.elements) for entry in self.ppd.qhull_entries)
842+
# NOTE this is 5 not 4 as "He" is a non redundant space that gets dropped for other reasons
843+
assert len(self.ppd.remove_redundant_spaces(spaces)) == 5
844+
845+
test = (
846+
list(combinations(range(1, 7), 4))
847+
+ list(combinations(range(1, 10), 2))
848+
+ list(combinations([1, 4, 7, 9, 2], 5))
849+
)
850+
test = [frozenset(t) for t in test]
851+
assert len(self.ppd.remove_redundant_spaces(test)) == 30
852+
833853

834854
class TestReactionDiagram(TestCase):
835855
def setUp(self):

0 commit comments

Comments
 (0)