Skip to content

Commit 05e8abb

Browse files
Add interface to icet SQS tools through SQSTransformation (materialsproject#3593)
* Add icet interface to io
* Fix Icet ClusterSpace generation for multi-site systems
* Add test for icet SQS transformation
* Add test for icet monte carlo
* Correct type annotations for icet.ClusterSpace
* Add icet to optional requirements
* Use monty.dev.requires for IcetSQS not installed err msg
* Remove requires decorator, add import test, fix failed tests
* Link https://icet.materialsmodeling.org from class doc str
* Simplify IcetSQS class
* Refactor setting SQSTransformation.icet_sqs_kwargs defaults

---------

Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 3f89175 commit 05e8abb

File tree

5 files changed

+438
-31
lines changed

5 files changed

+438
-31
lines changed

pymatgen/io/icet.py

+328
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
from __future__ import annotations
2+
3+
import multiprocessing as multiproc
4+
import warnings
5+
from string import ascii_uppercase
6+
from time import time
7+
from typing import TYPE_CHECKING
8+
9+
from pymatgen.command_line.mcsqs_caller import Sqs
10+
from pymatgen.core import Structure
11+
from pymatgen.io.ase import AseAtomsAdaptor
12+
13+
try:
14+
from icet import ClusterSpace
15+
from icet.tools import enumerate_structures
16+
from icet.tools.structure_generation import _get_sqs_cluster_vector, _validate_concentrations, generate_sqs
17+
from mchammer.calculators import compare_cluster_vectors
18+
except ImportError:
19+
ClusterSpace = None
20+
21+
22+
if TYPE_CHECKING:
23+
from typing import Any
24+
25+
from _icet import _ClusterSpace
26+
from ase import Atoms
27+
28+
29+
class IcetSQS:
    """Interface to the Icet library of SQS structure generation tools.

    https://icet.materialsmodeling.org
    """

    # Keyword arguments that are forwarded for each supported search method.
    sqs_kwarg_names: dict[str, tuple[str, ...]] = {
        "monte_carlo": (
            "include_smaller_cells",
            "pbc",
            "T_start",
            "T_stop",
            "n_steps",
            "optimality_weight",
            "random_seed",
            "tol",
        ),
        "enumeration": ("include_smaller_cells", "pbc", "optimality_weight", "tol"),
    }
    # Defaults merged underneath the user-supplied sqs_kwargs.
    _sqs_kwarg_defaults: dict[str, Any] = {
        "optimality_weight": None,
        "tol": 1.0e-5,
        "include_smaller_cells": False,  # for consistency with ATAT
        "pbc": (True, True, True),
    }
    sqs_methods: tuple[str, ...] = ("enumeration", "monte_carlo")

    def __init__(
        self,
        structure: Structure,
        scaling: int,
        instances: int | None,
        cluster_cutoffs: dict[int, float],
        sqs_method: str | None = None,
        sqs_kwargs: dict | None = None,
    ) -> None:
        """
        Instantiate an IcetSQS interface.

        Args:
            structure (Structure): disordered structure to compute SQS
            scaling (int): SQS supercell contains scaling * len(structure) sites
            instances (int): number of parallel SQS jobs to run. If falsy
                (None or 0), multiprocessing.cpu_count() is used.
            cluster_cutoffs (dict): dict of cluster size (pairs, triplets, ...) and
                the size of the cluster. The dict is not modified.
        Kwargs:
            sqs_method (str or None): if a str, one of ("enumeration", "monte_carlo")
                If None, default to "enumeration" for a supercell of < 24 sites, and
                "monte_carlo" otherwise.
            sqs_kwargs (dict): kwargs to pass to the icet SQS generators.
                See self.sqs_kwarg_names for possible options. Keys that do not
                apply to the chosen method are dropped with a warning.

        Raises:
            ImportError: if icet could not be imported.
            ValueError: if sqs_method is not one of self.sqs_methods.
        """
        if ClusterSpace is None:
            raise ImportError("IcetSQS requires the icet package. Use `pip install icet`")

        self._structure = structure
        self.scaling = scaling
        self.instances = instances or multiproc.cpu_count()

        self._get_site_composition()

        # The peculiar way that icet works requires a copy of the
        # disordered structure, but without any fractionally-occupied sites
        # Essentially the host structure
        _ordered_structure = structure.copy()

        original_composition = _ordered_structure.composition.as_dict()
        dummy_comp = next(iter(_ordered_structure.composition))
        _ordered_structure.replace_species(
            {species: dummy_comp for species in original_composition if species != dummy_comp}
        )
        self._ordered_atoms = AseAtomsAdaptor.get_atoms(_ordered_structure)

        self.cutoffs_list = self._build_cutoffs_list(cluster_cutoffs)

        # For safety, enumeration works well on 1 core for ~< 24 sites/cell
        # The bottleneck is **generation** of the structures via enumeration,
        # less checking their SQS objective.
        # Beyond ~24 sites/cell, monte carlo is more efficient
        sqs_method = sqs_method or ("enumeration" if self.scaling * len(self._structure) < 24 else "monte_carlo")

        # Validate before the method name is used as a dict key below, so an
        # unknown method gives this message instead of a bare KeyError.
        if sqs_method not in self.sqs_methods:
            raise ValueError(f"Unknown {sqs_method=}! Must be one of {self.sqs_methods}")

        # User kwargs take precedence over the class-level defaults.
        merged_kwargs = {**self._sqs_kwarg_defaults, **(sqs_kwargs or {})}
        allowed_kwargs = self.sqs_kwarg_names[sqs_method]

        # Sorted so that the warning text is deterministic.
        unrecognized_kwargs = sorted(key for key in merged_kwargs if key not in allowed_kwargs)
        if unrecognized_kwargs:
            warnings.warn(f"Ignoring unrecognized icet {sqs_method} kwargs: {', '.join(unrecognized_kwargs)}")

        self.sqs_kwargs = {key: value for key, value in merged_kwargs.items() if key in allowed_kwargs}

        if sqs_method == "monte_carlo":
            self.sqs_getter = self.monte_carlo_sqs_structures
            if self.sqs_kwargs.get("random_seed") is None:
                # No seed given: derive one from the current time.
                self.sqs_kwargs["random_seed"] = int(1e6 * time())
        else:
            self.sqs_getter = self.enumerate_sqs_structures

        # Kwargs forwarded to the objective function. Only None means "not set";
        # explicit falsy values such as optimality_weight=0 must be passed on.
        self._sqs_obj_kwargs = {}
        for key in ("optimality_weight", "tol"):
            value = self.sqs_kwargs.get(key, self._sqs_kwarg_defaults[key])
            if value is not None:
                self._sqs_obj_kwargs[key] = value

        cluster_space = self._get_cluster_space()
        self.target_concentrations = _validate_concentrations(
            concentrations=self.composition, cluster_space=cluster_space
        )
        self.sqs_vector = _get_sqs_cluster_vector(
            cluster_space=cluster_space, target_concentrations=self.target_concentrations
        )

    @staticmethod
    def _build_cutoffs_list(cluster_cutoffs: dict[int, float]) -> list[float]:
        """Return the cutoffs for cluster orders 2..max(order) as a dense list.

        Orders missing from cluster_cutoffs are padded with 0.0. The input dict
        is left untouched; an empty dict yields an empty list.
        """
        highest_order = max(cluster_cutoffs, default=1)
        return [cluster_cutoffs.get(order, 0.0) for order in range(2, highest_order + 1)]

    def run(self) -> Sqs:
        """
        Run the SQS search with icet.

        Returns:
            pymatgen Sqs object

        Raises:
            RuntimeError: if the search produced no candidate structures.
        """
        sqs_structures = self.sqs_getter()
        if not sqs_structures:
            raise RuntimeError("icet SQS search returned no candidate structures.")

        # Convert every candidate from ase Atoms to a pymatgen Structure.
        for entry in sqs_structures:
            entry["structure"] = AseAtomsAdaptor.get_structure(entry["structure"])
        sqs_structures = sorted(sqs_structures, key=lambda entry: entry["objective_function"])

        best = sqs_structures[0]
        return Sqs(
            bestsqs=best["structure"],
            objective_function=best["objective_function"],
            allsqs=sqs_structures,
            directory="./",
            clusters=str(self._get_cluster_space()),
        )

    def _get_site_composition(self) -> None:
        """
        Store the Icet-format composition of the structure in self.composition.

        Each distinct site composition is assigned the next unused uppercase
        letter, in order of first appearance, e.g., In_x Ga_1-x As becomes:
            {
                "A": {"In": x, "Ga": 1 - x},
                "B": {"As": 1}
            }

        Raises:
            ValueError: if there are more distinct site compositions than
                uppercase letters.
        """
        self.composition: dict[str, dict] = {}
        for site in self._structure:
            site_comp = site.species.as_dict()
            if site_comp in self.composition.values():
                continue
            # Label by the number of sublattices found so far, not by site index,
            # so that labels are consecutive even when sites share a composition.
            n_sublattices = len(self.composition)
            if n_sublattices >= len(ascii_uppercase):
                raise ValueError(
                    f"Cannot label more than {len(ascii_uppercase)} distinct site compositions with uppercase letters."
                )
            self.composition[ascii_uppercase[n_sublattices]] = site_comp

    def _get_cluster_space(self) -> ClusterSpace:
        """Generate the ClusterSpace object for icet."""
        # One list of allowed species per site of the disordered structure.
        chemical_symbols = [list(site.species.as_dict()) for site in self._structure]
        return ClusterSpace(structure=self._ordered_atoms, cutoffs=self.cutoffs_list, chemical_symbols=chemical_symbols)

    def get_icet_sqs_obj(self, material: Atoms | Structure, cluster_space: _ClusterSpace | None = None) -> float:
        """
        Get the SQS objective function.

        Args:
            material (ase Atoms or pymatgen Structure) : structure to
                compute SQS objective function.
        Kwargs:
            cluster_space (ClusterSpace) : ClusterSpace of the SQS search.
                Built from this object if not given.

        Returns:
            float : the SQS objective function
        """
        if isinstance(material, Structure):
            material = AseAtomsAdaptor.get_atoms(material)

        cluster_space = cluster_space or self._get_cluster_space()
        return compare_cluster_vectors(
            cv_1=cluster_space.get_cluster_vector(material),
            cv_2=self.sqs_vector,
            orbit_data=cluster_space.orbit_data,
            **self._sqs_obj_kwargs,
        )

    def enumerate_sqs_structures(self, cluster_space: _ClusterSpace | None = None) -> list:
        """
        Generate an SQS by enumeration of all possible arrangements.

        Adapted from icet.tools.structure_generation.generate_sqs_by_enumeration
        to accommodate multiprocessing.

        Kwargs:
            cluster_space (ClusterSpace) : ClusterSpace of the SQS search.
                Built from this object if not given.

        Returns:
            list : a list of dicts of the form: {
                    "structure": SQS structure,
                    "objective_function": SQS objective function,
                }
                with one entry per worker process that received at least one
                enumerated structure.

        Raises:
            ValueError: if the accumulated concentrations do not sum to 1
                within the "tol" kwarg.
        """
        # Translate concentrations to the format required for concentration
        # restricted enumeration
        cr: dict[str, tuple] = {}
        cluster_space = cluster_space or self._get_cluster_space()
        sub_lattices = cluster_space.get_sublattices(cluster_space.primitive_structure)
        for sl in sub_lattices:
            mult_factor = len(sl.indices) / len(cluster_space.primitive_structure)
            if sl.symbol in self.target_concentrations:
                sl_conc = self.target_concentrations[sl.symbol]
            else:
                sl_conc = {sl.chemical_symbols[0]: 1.0}
            for species, value in sl_conc.items():
                c = value * mult_factor
                # Lower and upper bounds are identical: the concentration is fixed.
                lower, upper = cr.get(species, (0.0, 0.0))
                cr[species] = (lower + c, upper + c)

        # Check to be sure...
        c_sum = sum(c[0] for c in cr.values())
        if abs(c_sum - 1) >= self.sqs_kwargs["tol"]:
            raise ValueError(f"Site occupancies sum to {c_sum} instead of 1!")

        sizes = list(range(1, self.scaling + 1)) if self.sqs_kwargs["include_smaller_cells"] else [self.scaling]

        # Prepare primitive structure with the right boundary conditions
        prim = cluster_space.primitive_structure
        prim.set_pbc(self.sqs_kwargs["pbc"])

        structures = enumerate_structures(prim, sizes, cluster_space.chemical_symbols, concentration_restrictions=cr)

        # Deal the enumerated structures out round-robin, one chunk per worker.
        chunks: list[list[Atoms]] = [[] for _ in range(self.instances)]
        for idx, structure in enumerate(structures):
            chunks[idx % self.instances].append(structure)

        # The context manager shuts the manager process down once results are copied out.
        with multiproc.Manager() as manager:
            working_list = manager.list()
            processes = []
            for chunk in chunks:
                if not chunk:
                    # Fewer structures than workers: nothing to evaluate here.
                    continue
                process = multiproc.Process(
                    target=self._get_best_sqs_from_list,
                    args=(chunk, working_list),
                )
                processes.append(process)
                process.start()

            for process in processes:
                process.join()

            return list(working_list)

    def _get_best_sqs_from_list(self, structures: list[Atoms], output_list: list[dict]) -> None:
        """
        Find best SQS structure from list of SQS structures.

        The best candidate (lowest objective function) is appended to
        output_list. Nothing is appended when structures is empty, so that
        output_list never contains an entry without a structure.

        Args:
            structures (list of ase Atoms) : list of SQS structures
            output_list (list of dicts) : shared list between
                multiprocessing processes to store best SQS objects.
        """
        if not structures:
            return

        best_sqs: dict[str, Any] = {"structure": None, "objective_function": 1.0e20}
        # Build the cluster space once and reuse it for every candidate.
        cluster_space = self._get_cluster_space()
        for structure in structures:
            objective = self.get_icet_sqs_obj(structure, cluster_space=cluster_space)
            if objective < best_sqs["objective_function"]:
                best_sqs = {"structure": structure, "objective_function": objective}

        if best_sqs["structure"] is not None:
            output_list.append(best_sqs)

    def _single_monte_carlo_sqs_run(self, seed_offset: int = 0):
        """Run a single Monte Carlo SQS search with Icet.

        Args:
            seed_offset (int): added to the configured "random_seed" so that
                parallel runs can use distinct but reproducible seeds.
                Defaults to 0, i.e. the configured seed is used unchanged.

        Returns:
            dict with keys "structure" (the generated structure) and
            "objective_function" (its SQS objective function).
        """
        cluster_space = self._get_cluster_space()

        # Work on a copy so the shared kwargs are never modified.
        mc_kwargs = dict(self.sqs_kwargs)
        if seed_offset and mc_kwargs.get("random_seed") is not None:
            mc_kwargs["random_seed"] += seed_offset

        sqs_structure = generate_sqs(
            cluster_space=cluster_space,
            max_size=self.scaling,
            target_concentrations=self.target_concentrations,
            **mc_kwargs,
        )
        return {
            "structure": sqs_structure,
            "objective_function": self.get_icet_sqs_obj(sqs_structure, cluster_space=cluster_space),
        }

    def monte_carlo_sqs_structures(self) -> list:
        """Run `self.instances` Monte Carlo SQS search with Icet.

        Run number i uses "random_seed" + i, so that the parallel searches do
        not all start from the same seed while remaining reproducible.

        Returns:
            list of dicts as returned by _single_monte_carlo_sqs_run, one per run.
        """
        with multiproc.Pool(self.instances) as pool:
            return pool.map(self._single_monte_carlo_sqs_run, range(self.instances))

0 commit comments

Comments
 (0)