Skip to content

Commit ad2ee66

Browse files
committed
Mypy fixes
1 parent 97ceea9 commit ad2ee66

File tree

5 files changed

+152
-391
lines changed

5 files changed

+152
-391
lines changed

emmet-builders/emmet/builders/materials/elasticity.py

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ def ensure_index(self):
7373
self.elasticity.ensure_index("material_id")
7474
self.elasticity.ensure_index("last_updated")
7575

76-
def get_items(
77-
self,
78-
) -> Generator[Tuple[str, Dict[str, str], List[Dict]], None, None]:
76+
def get_items(self,) -> Generator[Tuple[str, Dict[str, str], List[Dict]], None, None]:
7977
"""
8078
Gets all items to process into elasticity docs.
8179
@@ -89,9 +87,7 @@ def get_items(
8987

9088
self.ensure_index()
9189

92-
cursor = self.materials.query(
93-
criteria=self.query, properties=["material_id", "calc_types", "task_ids"]
94-
)
90+
cursor = self.materials.query(criteria=self.query, properties=["material_id", "calc_types", "task_ids"])
9591

9692
# query for tasks
9793
query = self.query.copy()
@@ -120,9 +116,7 @@ def get_items(
120116

121117
yield material_id, calc_types, tasks
122118

123-
def process_item(
124-
self, item: Tuple[MPID, Dict[str, str], List[Dict]]
125-
) -> Union[Dict, None]:
119+
def process_item(self, item: Tuple[MPID, Dict[str, str], List[Dict]]) -> Union[Dict, None]:
126120
"""
127121
Process all tasks belong to the same material into an elasticity doc.
128122
@@ -161,23 +155,15 @@ def process_item(
161155

162156
# select one task for each set of optimization tasks with the same lattice
163157
opt_grouped_tmp = group_by_parent_lattice(opt_tasks, mode="opt")
164-
opt_grouped = [
165-
(lattice, filter_opt_tasks_by_time(tasks, self.logger))
166-
for lattice, tasks in opt_grouped_tmp
167-
]
158+
opt_grouped = [(lattice, filter_opt_tasks_by_time(tasks, self.logger)) for lattice, tasks in opt_grouped_tmp]
168159

169160
# for deformed tasks with the same lattice, select one if there are multiple
170161
# tasks with the same deformation
171162
deform_grouped = group_by_parent_lattice(deform_tasks, mode="deform")
172-
deform_grouped = [
173-
(lattice, filter_deform_tasks_by_time(tasks))
174-
for lattice, tasks in deform_grouped
175-
]
163+
deform_grouped = [(lattice, filter_deform_tasks_by_time(tasks)) for lattice, tasks in deform_grouped]
176164

177165
# select opt and deform tasks for fitting
178-
final_opt, final_deform = select_final_opt_deform_tasks(
179-
opt_grouped, deform_grouped, self.logger
180-
)
166+
final_opt, final_deform = select_final_opt_deform_tasks(opt_grouped, deform_grouped, self.logger)
181167
if final_opt is None or final_deform is None:
182168
return None
183169

@@ -187,11 +173,7 @@ def process_item(
187173
deform_task_ids = []
188174
deform_dir_names = []
189175
for doc in final_deform:
190-
deforms.append(
191-
Deformation(
192-
doc["transmuter"]["transformation_params"][0]["deformation"]
193-
)
194-
)
176+
deforms.append(Deformation(doc["transmuter"]["transformation_params"][0]["deformation"]))
195177
# -0.1 to convert to GPa from kBar and s
196178
stresses.append(-0.1 * Stress(doc["output"]["stress"]))
197179
deform_task_ids.append(doc["task_id"])
@@ -226,9 +208,7 @@ def update_targets(self, items: List[Dict]):
226208

227209

228210
def filter_opt_tasks(
229-
tasks: List[Dict],
230-
calc_types: Dict[str, str],
231-
target_calc_type: str = CalcType.GGA_Structure_Optimization,
211+
tasks: List[Dict], calc_types: Dict[str, str], target_calc_type: str = CalcType.GGA_Structure_Optimization,
232212
) -> List[Dict]:
233213
"""
234214
Filter optimization tasks, by
@@ -240,9 +220,7 @@ def filter_opt_tasks(
240220

241221

242222
def filter_deform_tasks(
243-
tasks: List[Dict],
244-
calc_types: Dict[str, str],
245-
target_calc_type: str = CalcType.GGA_Deformation,
223+
tasks: List[Dict], calc_types: Dict[str, str], target_calc_type: str = CalcType.GGA_Deformation,
246224
) -> List[Dict]:
247225
"""
248226
Filter deformation tasks, by
@@ -254,18 +232,13 @@ def filter_deform_tasks(
254232
for t in tasks:
255233
if calc_types[str(t["task_id"])] == target_calc_type:
256234
transforms = t["transmuter"]["transformations"]
257-
if (
258-
len(transforms) == 1
259-
and transforms[0] == "DeformStructureTransformation"
260-
):
235+
if len(transforms) == 1 and transforms[0] == "DeformStructureTransformation":
261236
deform_tasks.append(t)
262237

263238
return deform_tasks
264239

265240

266-
def filter_by_incar_settings(
267-
tasks: List[Dict], incar_settings: Dict[str, Any] = None
268-
) -> List[Dict]:
241+
def filter_by_incar_settings(tasks: List[Dict], incar_settings: Optional[Dict[str, Any]] = None) -> List[Dict]:
269242
"""
270243
Filter tasks by incar parameters.
271244
"""
@@ -338,9 +311,7 @@ def filter_opt_tasks_by_time(tasks: List[Dict], logger) -> Dict:
338311
return selected
339312

340313

341-
def filter_deform_tasks_by_time(
342-
tasks: List[Dict], deform_comp_tol: float = 1e-5
343-
) -> List[Dict]:
314+
def filter_deform_tasks_by_time(tasks: List[Dict], deform_comp_tol: float = 1e-5) -> List[Dict]:
344315
"""
345316
For deformation tasks with the same deformation, select the latest completed one.
346317
@@ -416,10 +387,7 @@ def select_final_opt_deform_tasks(
416387
tasks.extend(pair[1])
417388

418389
ids = [t["task_id"] for t in tasks]
419-
logger.warning(
420-
f"Cannot find optimization and deformation tasks that match by lattice "
421-
f"for tasks {ids}"
422-
)
390+
logger.warning(f"Cannot find optimization and deformation tasks that match by lattice " f"for tasks {ids}")
423391

424392
final_opt_task = None
425393
final_deform_tasks = None

0 commit comments

Comments
 (0)