Skip to content

Commit a28b25a

Browse files
authored
More linting (#163)
* notbooks: * linting * update configs * linting * linting * linting * linting * cover * lint * lint * lint * lint * lint * lint * lint * lint * lint * lint * lint
1 parent 564210e commit a28b25a

File tree

10 files changed

+196
-44
lines changed

10 files changed

+196
-44
lines changed

pymatgen/analysis/defects/ccd.py

-1
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,6 @@ def _get_ediff(self, output_order="skb") -> npt.NDArray:
380380
rearrangement here so that we have a single point of failure.
381381
382382
Args:
383-
band_structure: The band structure of the relaxed defect calculation.
384383
output_order: The order of the output. Defaults to "skb" (spin, kpoint, band]).
385384
You can also use "bks" (band, kpoint, spin).
386385

pymatgen/analysis/defects/core.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .utils import get_plane_spacing
1919

2020
if TYPE_CHECKING:
21+
from numpy.typing import ArrayLike
2122
from pymatgen.core import Structure
2223
from pymatgen.symmetry.structure import SymmetrizedStructure
2324

@@ -312,11 +313,7 @@ def _has_oxi(struct):
312313

313314
@property
314315
def symmetrized_structure(self) -> SymmetrizedStructure:
315-
"""Returns the multiplicity of a defect site within the structure.
316-
317-
This is required for concentration analysis and confirms that defect_site is a
318-
site in bulk_structure.
319-
"""
316+
"""Get the symmetrized version of the bulk structure."""
320317
sga = SpacegroupAnalyzer(
321318
self.structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance
322319
)
@@ -895,7 +892,19 @@ def get_vacancy(structure: Structure, isite: int, **kwargs) -> Vacancy:
895892
return Vacancy(structure=structure, site=site, **kwargs)
896893

897894

898-
def _set_selective_dynamics(structure, site_pos, relax_radius):
895+
def _set_selective_dynamics(
896+
structure: Structure, site_pos: ArrayLike, relax_radius: float | str | None
897+
):
898+
"""Set the selective dynamics behavior.
899+
900+
Allow atoms to move for sites within a given radius of a given site,
901+
all other atoms are fixed. Modify the structure in place.
902+
903+
Args:
904+
structure: The structure to set the selective dynamics.
905+
site_pos: The center of the relaxation sphere.
906+
relax_radius: The radius of the relaxation sphere.
907+
"""
899908
if relax_radius is None:
900909
return
901910
if relax_radius == "auto":
@@ -974,11 +983,15 @@ def _get_mapped_sites(uc_structure: Structure, sc_structure: Structure, r=0.001)
974983
return mapped_site_indices
975984

976985

977-
def center_structure(structure, ref_fpos) -> Structure:
986+
def center_structure(structure: Structure, ref_fpos: ArrayLike) -> Structure:
978987
"""Shift the sites around a center.
979988
980989
Move all the sites in the structure so that they
981990
are in the periodic image closest to the reference fractional position.
991+
992+
Args:
993+
structure: The structure to be centered.
994+
ref_fpos: The reference fractional position that will be set to the center.
982995
"""
983996
struct = structure.copy()
984997
for idx, d_site in enumerate(struct):
@@ -997,8 +1010,7 @@ def _get_el_changes_from_structures(defect_sc: Structure, bulk_sc: Structure) ->
9971010
bulk_sc: The bulk structure.
9981011
9991012
Returns:
1000-
str: The name of the defect, if the defect is a complex, the names of the
1001-
individual defects are separated by "+".
1013+
dict: A dictionary representing the species changes in creating the defect.
10021014
"""
10031015

10041016
def _check_int(n):

pymatgen/analysis/defects/finder.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def get_site_groups(struct, symprec=0.01, angle_tolerance=5.0) -> List[SiteGroup
217217
return site_groups
218218

219219

220-
def get_soap_vec(struct: "Structure") -> "NDArray":
220+
def get_soap_vec(struct: "Structure") -> NDArray:
221221
"""Get the SOAP vector for each site in the structure.
222222
223223
Args:
@@ -237,17 +237,32 @@ def get_soap_vec(struct: "Structure") -> "NDArray":
237237
return vecs
238238

239239

240-
def get_site_vecs(struct: "Structure"):
241-
"""Get the SiteVec representation of each site in the structure."""
240+
def get_site_vecs(struct: Structure) -> List[SiteVec]:
241+
"""Get the SiteVec representation of each site in the structure.
242+
243+
Args:
244+
struct: Structure object to compute the site vectors (SOAP).
245+
246+
Returns:
247+
List[SiteVec]: List of SiteVec representing each site in the structure.
248+
"""
242249
vecs = get_soap_vec(struct)
243-
site_vecs = []
244-
for i, site in enumerate(struct):
245-
site_vecs.append(SiteVec(species=site.species_string, site=site, vec=vecs[i]))
246-
return site_vecs
250+
return [
251+
SiteVec(species=site.species_string, site=site, vec=vecs[i])
252+
for i, site in enumerate(struct)
253+
]
247254

248255

249256
def cosine_similarity(vec1, vec2) -> float:
250-
"""Cosine similarity between two vectors."""
257+
"""Cosine similarity between two vectors.
258+
259+
Args:
260+
vec1: First vector
261+
vec2: Second vector
262+
263+
Returns:
264+
float: Cosine similarity between the two vectors
265+
"""
251266
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
252267

253268

@@ -278,11 +293,19 @@ def best_match(sv: SiteVec, sgs: List[SiteGroup]) -> Tuple[SiteGroup, float]:
278293
return best_match, best_similarity
279294

280295

281-
def _get_broundary(arr, n_max=16, n_skip=3):
296+
def _get_broundary(arr, n_max=16, n_skip=3) -> int:
282297
"""Get the boundary index for the high-distortion indices.
283298
284299
Assuming arr is sorted in reverse order,
285300
find the biggest value drop in arr[n_skip:n_max].
301+
302+
Args:
303+
arr: List of numbers
304+
n_max: Maximum index to consider
305+
n_skip: Number of indices to skip
306+
307+
Returns:
308+
int: The boundary index
286309
"""
287310
sub_arr = np.array(arr[n_skip:n_max])
288311
diffs = sub_arr[1:] - sub_arr[:-1]
@@ -291,7 +314,7 @@ def _get_broundary(arr, n_max=16, n_skip=3):
291314

292315
def get_weighted_average_position(
293316
lattice: Lattice, frac_positions: ArrayLike, weights: ArrayLike | None = None
294-
) -> "NDArray":
317+
) -> NDArray:
295318
"""Get the weighted average position of a set of positions in frac coordinates.
296319
297320
The algorithm starts at position with the highest weight, and gradually moves

pymatgen/analysis/defects/generators.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,20 @@ def _space_group_analyzer(self, structure: Structure) -> SpacegroupAnalyzer:
4949
"This generator is using the `SpaceGroupAnalyzer` and requires `symprec` and `angle_tolerance` to be set."
5050
)
5151

52+
def generate(self, *args, **kwargs) -> Generator[Defect, None, None]:
53+
"""Generate a defect.
54+
55+
Args:
56+
*args: Additional positional arguments.
57+
**kwargs: Additional keyword arguments.
58+
59+
Returns:
60+
Generator[Defect, None, None]: Generator that yields a list of ``Defect`` objects.
61+
"""
62+
raise NotImplementedError
63+
5264
def get_defects(self, *args, **kwargs) -> list[Defect]:
53-
"""Call the generator and convert the results into a list."""
65+
"""Alias for self.generate."""
5466
return list(self.generate(*args, **kwargs))
5567

5668

@@ -254,7 +266,7 @@ def generate(
254266
insertions: The insertions to be made given as a dictionary {"Mg": [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]}.
255267
multiplicities: The multiplicities of the insertions to be made given as a dictionary {"Mg": [1, 2]}.
256268
equivalent_positions: The equivalent positions of the each inserted species given as a dictionary.
257-
Note that they should typically be the same but we allow for more flexibility.
269+
Note that they should typically be the same but we allow for more flexibility here.
258270
**kwargs: Additional keyword arguments for the ``Interstitial`` constructor.
259271
260272
Returns:
@@ -418,6 +430,8 @@ class ChargeInterstitialGenerator(InterstitialGenerator):
418430
min_dist: Minimum to atoms in the host structure
419431
avg_radius: The radius around each local minima used to evaluate the average charge.
420432
max_avg_charge: The maximum average charge to accept.
433+
max_insertions: The maximum number of insertion sites to consider.
434+
Will choose the sites with the lowest average charge.
421435
"""
422436

423437
def __init__(

pymatgen/analysis/defects/plotting/optics.py

+48-5
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,25 @@ def _plot_eigs(
183183
x_width: float = 0.3,
184184
**kwargs,
185185
) -> None:
186-
"""Plot the eigenvalues."""
186+
"""Plot the eigenvalues.
187+
188+
Args:
189+
d_eigs:
190+
The dictionary of eigenvalues for the defect state. In the format of
191+
(iband, ikpt, ispin) -> eigenvalue
192+
e_fermi:
193+
The bands above and below the Fermi level will be colored differently.
194+
If not provided, they will all be colored the same.
195+
ax:
196+
The matplotlib axis object to plot on.
197+
x0:
198+
The x coordinate of the center of the set of lines representing the eigenvalues.
199+
x_width:
200+
The width of the set of lines representing the eigenvalues.
201+
**kwargs:
202+
Keyword arguments to pass to `matplotlib.pyplot.hlines`.
203+
For example, `linestyles`, `alpha`, etc.
204+
"""
187205
if ax is None: # pragma: no cover
188206
ax = plt.gca()
189207

@@ -215,7 +233,7 @@ def _plot_matrix_elements(
215233
arrow_width=0.1,
216234
cmap=None,
217235
norm=None,
218-
):
236+
) -> tuple[list[tuple], plt.cm, plt.Normalize]:
219237
"""Plot arrow for the transition from the defect state to all other states.
220238
221239
Args:
@@ -242,13 +260,21 @@ def _plot_matrix_elements(
242260
The cartesian direction of the WAVDER tensor to sum over for the plot.
243261
If not provided, all the absolute values of the matrix for all
244262
three diagonal entries will be summed.
263+
264+
Returns:
265+
plot_data:
266+
A list of tuples in the format of (iband, ikpt, ispin, eigenvalue, matrix element)
267+
cmap:
268+
The matplotlib color map used.
269+
norm:
270+
The matplotlib normalization used.
245271
"""
246272
if ax is None: # pragma: no cover
247273
ax = plt.gca()
248274
ax.set_aspect("equal")
249275
jb, jkpt, jspin = next(filter(lambda x: x[0] == defect_band_index, d_eig.keys()))
250276
y0 = d_eig[jb, jkpt, jspin]
251-
plot_data = []
277+
plot_data: list[tuple] = []
252278
for (ib, ik, ispin), eig in d_eig.items():
253279
A = 0
254280
for idir, jdir in ijdirs:
@@ -289,8 +315,25 @@ def _plot_matrix_elements(
289315
return plot_data, cmap, norm
290316

291317

292-
def _get_dataframe(d_eigs, me_plot_data) -> pd.DataFrame:
293-
"""Convert the eigenvalue and matrix element data into a pandas dataframe."""
318+
def _get_dataframe(d_eigs: dict, me_plot_data: list[tuple]) -> pd.DataFrame:
319+
"""Convert the eigenvalue and matrix element data into a pandas dataframe.
320+
321+
Args:
322+
d_eigs:
323+
The dictionary of eigenvalues for the defect state. In the format of
324+
(iband, ikpt, ispin) -> eigenvalue
325+
me_plot_data:
326+
A list of tuples in the format of (iband, ikpt, ispin, eigenvalue, matrix element)
327+
328+
Returns:
329+
A pandas dataframe with the following columns:
330+
ib: The band index of the state the arrow is pointing to.
331+
jb: The band index of the defect state.
332+
kpt: The kpoint index of the state the arrow is pointing to.
333+
spin: The spin index of the state the arrow is pointing to.
334+
eig: The eigenvalue of the state the arrow is pointing to.
335+
M.E.: The matrix element of the transition.
336+
"""
294337
_, ikpt, ispin = next(iter(d_eigs.keys()))
295338
df = pd.DataFrame(
296339
me_plot_data,

pymatgen/analysis/defects/plotting/phases.py

-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def _convex_hull_2d(
102102
points: list[dict],
103103
x_element: Element,
104104
y_element: Element,
105-
tol: float = 0.001,
106105
competing_phases: list = None,
107106
) -> list[dict]:
108107
"""Compute the convex hull of a set of points in 2D.

pymatgen/analysis/defects/recombination.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,14 @@
3030

3131
@njit(cache=True)
3232
def fact(n: int) -> float: # pragma: no cover
33-
"""Compute the factorial of n."""
33+
"""Compute the factorial of n.
34+
35+
Args:
36+
n: The number to compute the factorial of.
37+
38+
Returns:
39+
The factorial of n.
40+
"""
3441
if n > 20:
3542
return LOOKUP_TABLE[-1] * np.prod(
3643
np.array(list(range(21, n + 1)), dtype=np.double)
@@ -40,7 +47,15 @@ def fact(n: int) -> float: # pragma: no cover
4047

4148
@njit(cache=True)
4249
def herm(x: float, n: int) -> float: # pragma: no cover
43-
"""Recursive definition of hermite polynomial."""
50+
"""Recursive definition of hermite polynomial.
51+
52+
Args:
53+
x: The value to evaluate the hermite polynomial at.
54+
n: The order of the hermite polynomial.
55+
56+
Returns:
57+
The value of the hermite polynomial at x.
58+
"""
4459
if n == 0:
4560
return 1.0
4661
if n == 1:

pymatgen/analysis/defects/supercells.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# from pymatgen.io.ase import AseAtomsAdaptor
1313

1414
if TYPE_CHECKING:
15-
import numpy as np
15+
from numpy.typing import ArrayLike, NDArray
1616
from pymatgen.core import Structure
1717

1818
__author__ = "Jimmy-Xuan Shen"
@@ -29,7 +29,7 @@ def get_sc_fromstruct(
2929
max_atoms: int = 240,
3030
min_length: float = 10.0,
3131
force_diagonal: bool = False,
32-
) -> np.ndarray | np.array | None:
32+
) -> NDArray | ArrayLike | None:
3333
"""Generate the best supercell from a unitcell.
3434
3535
The CubicSupercellTransformation from PMG is much faster but don't iterate over as
@@ -92,7 +92,7 @@ def _cubic_cell(
9292
max_atoms: int = 240,
9393
min_length: float = 10.0,
9494
force_diagonal: bool = False,
95-
) -> np.ndarray | None:
95+
) -> NDArray | None:
9696
"""Generate the best supercell from a unit cell.
9797
9898
This is done using the pymatgen CubicSupercellTransformation class.
@@ -125,23 +125,35 @@ def _cubic_cell(
125125
return cst.transformation_matrix
126126

127127

128-
def _ase_cubic(base_struture, min_atoms: int = 80, max_atoms: int = 240):
128+
def _ase_cubic(base_structure, min_atoms: int = 80, max_atoms: int = 240):
129+
"""Generate the best supercell from a unit cell.
130+
131+
Use ASE's find_optimal_cell_shape function to find the best supercell.
132+
133+
Args:
134+
base_structure: structure of the unit cell
135+
max_atoms: Maximum number of atoms allowed in the supercell.
136+
min_atoms: Minimum number of atoms allowed in the supercell.
137+
138+
Returns:
139+
3x3 matrix: supercell matrix
140+
"""
129141
from ase.build import find_optimal_cell_shape, get_deviation_from_optimal_cell_shape
130142
from pymatgen.io.ase import AseAtomsAdaptor
131143

132144
_logger.warn("ASE cubic supercell generation.")
133145

134146
aaa = AseAtomsAdaptor()
135-
ase_atoms = aaa.get_atoms(base_struture)
136-
lower = math.ceil(min_atoms / base_struture.num_sites)
137-
upper = math.floor(max_atoms / base_struture.num_sites)
147+
ase_atoms = aaa.get_atoms(base_structure)
148+
lower = math.ceil(min_atoms / base_structure.num_sites)
149+
upper = math.floor(max_atoms / base_structure.num_sites)
138150
min_dev = (float("inf"), None)
139151
for size in range(lower, upper + 1):
140152
_logger.warn(f"Trying size {size} out of {upper}.")
141153
sc = find_optimal_cell_shape(
142154
ase_atoms.cell, target_size=size, target_shape="sc"
143155
)
144-
sc_cell = aaa.get_atoms(base_struture * sc).cell
156+
sc_cell = aaa.get_atoms(base_structure * sc).cell
145157
deviation = get_deviation_from_optimal_cell_shape(sc_cell, target_shape="sc")
146158
min_dev = min(min_dev, (deviation, sc))
147159
if min_dev[1] is None:

0 commit comments

Comments
 (0)