Skip to content

Commit 5456812

Browse files
fix type hinting / mypy
1 parent aaf7ffc commit 5456812

File tree

2 files changed

+49
-59
lines changed

2 files changed

+49
-59
lines changed

emmet-core/emmet/core/math.py

+16-20
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,25 @@
11
"""Define types used for array-like data."""
2-
from typing import TypeVar
32

4-
Vector3D = TypeVar("Vector3D", bound=tuple[float, float, float])
5-
Vector3D.__doc__ = "Real space vector" # type: ignore
3+
Vector3D = tuple[float, float, float]
4+
"""Real space vector."""
65

7-
Matrix3D = TypeVar("Matrix3D", bound=tuple[Vector3D, Vector3D, Vector3D])
8-
Matrix3D.__doc__ = "Real space Matrix" # type: ignore
6+
Matrix3D = tuple[Vector3D, Vector3D, Vector3D]
7+
"""Real space Matrix."""
98

10-
Vector6D = TypeVar("Vector6D", bound=tuple[float, float, float, float, float, float])
11-
Vector6D.__doc__ = "6D Voigt matrix component" # type: ignore
9+
Vector6D = tuple[float, float, float, float, float, float]
10+
"""6D Voigt matrix component."""
1211

13-
MatrixVoigt = TypeVar(
14-
"MatrixVoigt",
15-
bound=tuple[Vector6D, Vector6D, Vector6D, Vector6D, Vector6D, Vector6D],
16-
)
17-
MatrixVoigt.__doc__ = "Voigt representation of a 3x3x3x3 tensor" # type: ignore
12+
MatrixVoigt = tuple[Vector6D, Vector6D, Vector6D, Vector6D, Vector6D, Vector6D]
13+
""""Voigt representation of a 3x3x3x3 tensor."""
1814

19-
Tensor3R = TypeVar("Tensor3R", bound=list[list[list[float]]])
20-
Tensor3R.__doc__ = "Generic tensor of rank 3" # type: ignore
15+
Tensor3R = list[list[list[float]]]
16+
"""Generic tensor of rank 3."""
2117

22-
Tensor4R = TypeVar("Tensor4R", bound=list[list[list[list[float]]]])
23-
Tensor4R.__doc__ = "Generic tensor of rank 4" # type: ignore
18+
Tensor4R = list[list[list[list[float]]]]
19+
"""Generic tensor of rank 4."""
2420

25-
ListVector3D = TypeVar("ListVector3D", bound=list[float])
26-
ListVector3D.__doc__ = "Real space vector as list" # type: ignore
21+
ListVector3D = list[float]
22+
"""Real space vector as list."""
2723

28-
ListMatrix3D = TypeVar("ListMatrix3D", bound=list[ListVector3D])
29-
ListMatrix3D.__doc__ = "Real space Matrix as list" # type: ignore
24+
ListMatrix3D = list[ListVector3D]
25+
"""Real space Matrix as list."""

emmet-core/emmet/core/structure_replicas.py

+33-39
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ def to_pymatgen(self) -> Any:
3535
@classmethod
3636
def from_dict(cls, dct: dict[str, Any]) -> Self:
3737
"""MSONable-like function to create this object from a dict."""
38-
raise NotImplementedError
38+
return cls(**dct)
3939

4040
def as_dict(self) -> dict[str, Any]:
4141
"""MSONable-like function to create dict representation of this object."""
42-
raise NotImplementedError
42+
return self.model_dump()
4343

4444

4545
class SiteProperties(Enum):
@@ -183,35 +183,26 @@ def __str__(self):
183183
return self.name
184184

185185

186-
class LightLattice(tuple):
186+
class LatticeReplica(EmmetReplica):
187187
"""Low memory representation of a Lattice as a tuple of a 3x3 matrix."""
188188

189-
def __new__(cls, matrix):
190-
"""Overset __new__ to define new tuple instance."""
191-
lattice_matrix = np.array(matrix)
192-
if lattice_matrix.shape != (3, 3):
193-
raise ValueError("Lattice matrix must be 3x3.")
194-
return super(LightLattice, cls).__new__(
195-
cls, tuple([tuple(v) for v in lattice_matrix.tolist()])
196-
)
197-
198-
def as_dict(self) -> dict[str, list | str]:
199-
"""Define MSONable-like as_dict."""
200-
return {"@class": self.__class__, "@module": self.__module__, "matrix": self}
189+
matrix: Matrix3D = Field(
190+
description="The matrix represenation of the lattice, with a, b, and c as rows."
191+
)
201192

202193
@classmethod
203-
def from_dict(cls, dct: dict) -> Self:
204-
"""Define MSONable-like from_dict."""
205-
return cls(dct["matrix"])
194+
def from_pymatgen(cls, pmg_obj: Lattice) -> Self:
195+
"""Create a LatticeReplica from a pymatgen .Lattice."""
196+
return cls(matrix=pmg_obj.matrix)
206197

207-
def copy(self) -> Self:
208-
"""Return a new copy of LightLattice."""
209-
return LightLattice(self)
198+
def to_pymatgen(self) -> Lattice:
199+
"""Create a pymatgen .Lattice."""
200+
return Lattice(self.matrix)
210201

211202
@property
212203
def volume(self) -> float:
213204
"""Get the volume enclosed by the direct lattice vectors."""
214-
return abs(np.linalg.det(self))
205+
return abs(np.linalg.det(self.matrix))
215206

216207

217208
class ElementReplica(EmmetReplica):
@@ -294,20 +285,23 @@ def from_pymatgen(cls, pmg_obj: Element | PeriodicSite) -> Self:
294285
element=ElementSymbol(
295286
next(iter(pmg_obj.species.remove_charges().as_dict()))
296287
),
297-
lattice=LightLattice(pmg_obj.lattice.matrix),
288+
lattice=LatticeReplica.from_pymatgen(pmg_obj.lattice),
298289
frac_coords=pmg_obj.frac_coords,
299290
cart_coords=pmg_obj.coords,
300291
)
301292

302-
def to_pymatgen(self) -> PeriodicSite:
303-
"""Create a PeriodicSite from a ElementReplica."""
304-
return PeriodicSite(
305-
self.element.name,
306-
self.frac_coords,
307-
Lattice(self.lattice),
308-
coords_are_cartesian=False,
309-
properties=self.properties,
310-
)
293+
def to_pymatgen(self) -> PeriodicSite | Element:
294+
"""Create an Element or PeriodicSite from a ElementReplica."""
295+
296+
if self.lattice and self.frac_coords:
297+
return PeriodicSite(
298+
self.element.name,
299+
self.frac_coords,
300+
Lattice(self.lattice),
301+
coords_are_cartesian=False,
302+
properties=self.properties,
303+
)
304+
return Element(self.element.name)
311305

312306
@property
313307
def species(self) -> dict[str, int]:
@@ -395,7 +389,7 @@ class StructureReplica(BaseModel):
395389
396390
Parameters
397391
-----------
398-
lattice : LightLattice
392+
lattice : LatticeReplica
399393
A 3x3 tuple of the lattice vectors, with a, b, and c as subsequent rows.
400394
species : list[ElementReplica]
401395
A list of elements in the structure
@@ -408,7 +402,7 @@ class StructureReplica(BaseModel):
408402
The total charge on the structure.
409403
"""
410404

411-
lattice: LightLattice = Field(description="The lattice in 3x3 matrix form.")
405+
lattice: LatticeReplica = Field(description="The lattice in 3x3 matrix form.")
412406
species: list[ElementReplica] = Field(description="The elements in the structure.")
413407
frac_coords: ListMatrix3D = Field(
414408
description="The direct coordinates of the sites in the structure."
@@ -436,7 +430,7 @@ def __getitem__(self, idx: int | slice) -> ElementReplica | list[ElementReplica]
436430
return self.sites[idx]
437431
raise IndexError("Index must be an integer or slice!")
438432

439-
def __iter__(self) -> Iterator[ElementReplica]:
433+
def __iter__(self) -> Iterator[ElementReplica]: # type: ignore[override]
440434
"""Permit list-like iteration on the sites in StructureReplica."""
441435
yield from self.sites
442436

@@ -471,8 +465,8 @@ def from_pymatgen(cls, pmg_obj: Structure) -> Self:
471465
"Currently, `StructureReplica` is intended to represent only ordered materials."
472466
)
473467

474-
lattice = LightLattice(pmg_obj.lattice.matrix)
475-
properties = [{} for _ in range(len(pmg_obj))]
468+
lattice = LatticeReplica.from_pymatgen(pmg_obj.lattice)
469+
properties: list[dict[str, Any]] = [{} for _ in range(len(pmg_obj))]
476470
for idx, site in enumerate(pmg_obj):
477471
for k in ("charge", "magmom", "velocities", "selective_dynamics"):
478472
if (prop := site.properties.get(k)) is not None:
@@ -499,13 +493,13 @@ def from_pymatgen(cls, pmg_obj: Structure) -> Self:
499493
def to_pymatgen(self) -> Structure:
500494
"""Convert to a pymatgen .Structure."""
501495
return Structure.from_sites(
502-
[site.to_periodic_site() for site in self], charge=self.charge
496+
[site.to_pymatgen() for site in self], charge=self.charge # type: ignore[misc]
503497
)
504498

505499
@classmethod
506500
def from_poscar(cls, poscar_path: str | Path) -> Self:
507501
"""Define convenience method to create a StructureReplica from a VASP POSCAR."""
508-
return cls.from_structure(Poscar.from_file(poscar_path).structure)
502+
return cls.from_pymatgen(Poscar.from_file(poscar_path).structure)
509503

510504
def __str__(self):
511505
"""Define format for printing a Structure."""

0 commit comments

Comments
 (0)