17
17
from pymatgen .io .cif import CifParser
18
18
from pymatgen .io .vasp .inputs import Poscar
19
19
from pymatgen .io .vasp .sets import MPRelaxSet , VaspInputSet
20
+ from pymatgen .transformations .transformation_abc import AbstractTransformation
20
21
from pymatgen .util .provenance import StructureNL
21
22
22
23
if TYPE_CHECKING :
24
+ from collections .abc import Sequence
25
+
23
26
from pymatgen .alchemy .filters import AbstractStructureFilter
24
- from pymatgen .transformations .transformation_abc import AbstractTransformation
25
27
26
28
27
29
class TransformedStructure (MSONable ):
@@ -35,16 +37,15 @@ class TransformedStructure(MSONable):
35
37
def __init__ (
36
38
self ,
37
39
structure : Structure ,
38
- transformations : list [AbstractTransformation ] | None = None ,
40
+ transformations : AbstractTransformation | Sequence [AbstractTransformation ] | None = None ,
39
41
history : list [AbstractTransformation | dict [str , Any ]] | None = None ,
40
42
other_parameters : dict [str , Any ] | None = None ,
41
43
) -> None :
42
44
"""Initializes a transformed structure from a structure.
43
45
44
46
Args:
45
47
structure (Structure): Input structure
46
- transformations (list[Transformation]): List of transformations to
47
- apply.
48
+ transformations (list[Transformation]): List of transformations to apply.
48
49
history (list[Transformation]): Previous history.
49
50
other_parameters (dict): Additional parameters to be added.
50
51
"""
@@ -53,6 +54,8 @@ def __init__(
53
54
self .other_parameters = other_parameters or {}
54
55
self ._undone : list [tuple [AbstractTransformation | dict [str , Any ], Structure ]] = []
55
56
57
+ if isinstance (transformations , AbstractTransformation ):
58
+ transformations = [transformations ]
56
59
transformations = transformations or []
57
60
for trafo in transformations :
58
61
self .append_transformation (trafo )
@@ -86,8 +89,10 @@ def redo_next_change(self) -> None:
86
89
self .history .append (hist )
87
90
self .final_structure = struct
88
91
89
- def __getattr__ (self , name ) -> Any :
90
- return getattr (self .final_structure , name )
92
+ def __getattr__ (self , name : str ) -> Any :
93
+ # Don't use getattr(self.final_structure, name) here to avoid infinite recursion if name = "final_structure"
94
+ struct = self .__getattribute__ ("final_structure" )
95
+ return getattr (struct , name )
91
96
92
97
def __len__ (self ) -> int :
93
98
return len (self .history )
@@ -161,7 +166,9 @@ def append_filter(self, structure_filter: AbstractStructureFilter) -> None:
161
166
self .history .append (h_dict )
162
167
163
168
def extend_transformations (
164
- self , transformations : list [AbstractTransformation ], return_alternatives : bool = False
169
+ self ,
170
+ transformations : list [AbstractTransformation ],
171
+ return_alternatives : bool = False ,
165
172
) -> None :
166
173
"""Extends a sequence of transformations to the TransformedStructure.
167
174
@@ -209,7 +216,13 @@ def write_vasp_input(
209
216
json .dump (self .as_dict (), file )
210
217
211
218
def __str__ (self ) -> str :
212
- output = ["Current structure" , "------------" , str (self .final_structure ), "\n History" , "------------" ]
219
+ output = [
220
+ "Current structure" ,
221
+ "------------" ,
222
+ str (self .final_structure ),
223
+ "\n History" ,
224
+ "------------" ,
225
+ ]
213
226
for hist in self .history :
214
227
hist .pop ("input_structure" , None )
215
228
output .append (str (hist ))
@@ -290,7 +303,9 @@ def from_cif_str(
290
303
291
304
@classmethod
292
305
def from_poscar_str (
293
- cls , poscar_string : str , transformations : list [AbstractTransformation ] | None = None
306
+ cls ,
307
+ poscar_string : str ,
308
+ transformations : list [AbstractTransformation ] | None = None ,
294
309
) -> TransformedStructure :
295
310
"""Generates TransformedStructure from a poscar string.
296
311
0 commit comments