Skip to content

Commit ae7ad4b

Browse files
committed
formatting
1 parent aaa10fa commit ae7ad4b

File tree

7 files changed

+34
-31
lines changed

7 files changed

+34
-31
lines changed

psiflow/data/dataset.py

+3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class Dataset:
4343
4444
This class provides methods for manipulating and analyzing collections of atomic structures.
4545
"""
46+
4647
extxyz: psiflow._DataFuture
4748

4849
def __init__(
@@ -415,6 +416,7 @@ def _concatenate_multiple(*args: list[np.ndarray]) -> list[np.ndarray]:
415416
Note:
416417
This function is wrapped as a Parsl app and executed using the default_threads executor.
417418
"""
419+
418420
def pad_arrays(
419421
arrays: list[np.ndarray],
420422
pad_dimension: int = 1,
@@ -621,6 +623,7 @@ class Computable:
621623
outputs (ClassVar[tuple[str, ...]]): Names of output quantities.
622624
batch_size (ClassVar[Optional[int]]): Default batch size for computation.
623625
"""
626+
624627
outputs: ClassVar[tuple[str, ...]] = ()
625628
batch_size: ClassVar[Optional[int]] = None
626629

psiflow/execution.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def from_config(
131131
provider_cls = SlurmProvider
132132
provider_kwargs = kwargs.pop("slurm") # do not allow empty dict
133133
provider_kwargs["init_blocks"] = 0
134-
if not 'exclusive' in provider_kwargs:
135-
provider_kwargs['exclusive'] = False
134+
if "exclusive" not in provider_kwargs:
135+
provider_kwargs["exclusive"] = False
136136
else:
137137
provider_cls = LocalProvider # noqa: F405
138138
provider_kwargs = kwargs.pop("local", {})
@@ -452,7 +452,7 @@ def from_config(
452452
max_idletime: float = 20,
453453
internal_tasks_max_threads: int = 10,
454454
default_threads: int = 4,
455-
htex_address: str = '127.0.0.1',
455+
htex_address: str = "127.0.0.1",
456456
zip_staging: Optional[bool] = None,
457457
container_uri: Optional[str] = None,
458458
container_engine: str = "apptainer",
@@ -552,11 +552,14 @@ def from_config(
552552
context = ExecutionContext(config, definitions, path / "context_dir")
553553

554554
if make_symlinks:
555-
src, dest = Path.cwd() / f'psiflow_log', path / 'parsl.log'
555+
src, dest = Path.cwd() / "psiflow_log", path / "parsl.log"
556556
_create_symlink(src, dest)
557-
src, dest = Path.cwd() / f'psiflow_submit_scripts', path / '000' / 'submit_scripts'
557+
src, dest = (
558+
Path.cwd() / "psiflow_submit_scripts",
559+
path / "000" / "submit_scripts",
560+
)
558561
_create_symlink(src, dest, is_dir=True)
559-
src, dest = Path.cwd() / f'psiflow_task_logs', path / '000' / 'task_logs'
562+
src, dest = Path.cwd() / "psiflow_task_logs", path / "000" / "task_logs"
560563
_create_symlink(src, dest, is_dir=True)
561564

562565
return context
@@ -684,5 +687,3 @@ def _create_symlink(src: Path, dest: Path, is_dir: bool = False) -> None:
684687
else:
685688
dest.touch(exist_ok=True)
686689
src.symlink_to(dest, target_is_directory=is_dir)
687-
688-

psiflow/geometry.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -274,26 +274,30 @@ def from_string(cls, s: str, natoms: Optional[int] = None) -> Optional[Geometry]
274274

275275
# read and format per_atom data
276276
column_indices = {}
277-
if 'Properties' in comment_dict:
278-
properties = comment_dict['Properties'].split(':')
277+
if "Properties" in comment_dict:
278+
properties = comment_dict["Properties"].split(":")
279279
count = 0
280280
for i in range(len(properties) // 3):
281281
name = properties[3 * i]
282282
ncolumns = int(properties[3 * i + 2])
283283
column_indices[name] = count
284284
count += ncolumns
285-
assert 'pos' in column_indices # positions need to be there
285+
assert "pos" in column_indices # positions need to be there
286286

287287
per_atom = np.recarray(natoms, dtype=per_atom_dtype)
288288
per_atom.forces[:] = np.nan
289-
POS_INDEX = column_indices.get('pos', 1)
290-
FORCES_INDEX = column_indices.get('forces', None)
289+
POS_INDEX = column_indices.get("pos", 1)
290+
FORCES_INDEX = column_indices.get("forces", None)
291291
for i in range(natoms):
292292
values = lines[i + 1].split()
293293
per_atom.numbers[i] = chemical_symbols.index(values[0])
294-
per_atom.positions[i, :] = [float(_) for _ in values[POS_INDEX:POS_INDEX + 3]]
294+
per_atom.positions[i, :] = [
295+
float(_) for _ in values[POS_INDEX : POS_INDEX + 3]
296+
]
295297
if FORCES_INDEX is not None:
296-
per_atom.forces[i, :] = [float(_) for _ in values[FORCES_INDEX:FORCES_INDEX + 3]]
298+
per_atom.forces[i, :] = [
299+
float(_) for _ in values[FORCES_INDEX : FORCES_INDEX + 3]
300+
]
297301

298302
order = {}
299303
for key, value in comment_dict.items():

psiflow/learning.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from psiflow.models import Model
1818
from psiflow.reference import Reference, evaluate
1919
from psiflow.sampling import SimulationOutput, Walker, sample
20-
from psiflow.utils.apps import boolean_or, setup_logger, unpack_i, isnan
20+
from psiflow.utils.apps import boolean_or, isnan, setup_logger, unpack_i
2121

2222
logger = setup_logger(__name__)
2323

psiflow/reference/_cp2k.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def set_global_section(cp2k_input_dict: dict, properties: tuple):
6868
global_dict = cp2k_input_dict["global"]
6969

7070
# override low/silent print levels
71-
level = global_dict.pop('print_level', 'MEDIUM')
72-
if level in ['SILENT', 'LOW']:
73-
global_dict['print_level'] = 'MEDIUM'
71+
level = global_dict.pop("print_level", "MEDIUM")
72+
if level in ["SILENT", "LOW"]:
73+
global_dict["print_level"] = "MEDIUM"
7474

7575
if properties == ("energy",):
7676
global_dict["run_type"] = "ENERGY"
@@ -156,11 +156,11 @@ def _prepare_input(
156156
if "forces" in properties:
157157
cp2k_input_dict["force_eval"]["print"] = {"FORCES": {}}
158158
cp2k_input_str = dict_to_str(cp2k_input_dict)
159-
with open(outputs[0], 'w') as f:
159+
with open(outputs[0], "w") as f:
160160
f.write(cp2k_input_str)
161161

162162

163-
prepare_input = python_app(_prepare_input, executors=['default_threads'])
163+
prepare_input = python_app(_prepare_input, executors=["default_threads"])
164164

165165

166166
# typeguarding for some reason incompatible with WQ
@@ -175,14 +175,9 @@ def cp2k_singlepoint_pre(
175175
cd_command = "cd $mytmpdir"
176176
cp_command = "cp {} cp2k.inp".format(inputs[0].filepath)
177177

178-
command_list = [
179-
tmp_command,
180-
cd_command,
181-
cp_command,
182-
cp2k_command
183-
]
178+
command_list = [tmp_command, cd_command, cp_command, cp2k_command]
184179

185-
return ' && '.join(command_list)
180+
return " && ".join(command_list)
186181

187182

188183
@typeguard.typechecked
@@ -242,7 +237,7 @@ def wrapped_app_pre(geometry, stdout: str, stderr: str):
242237
geometry,
243238
cp2k_input_dict=self.cp2k_input_dict,
244239
properties=tuple(self.outputs),
245-
outputs=[psiflow.context().new_file('cp2k_', '.inp')],
240+
outputs=[psiflow.context().new_file("cp2k_", ".inp")],
246241
)
247242
return app_pre(
248243
cp2k_command=cp2k_command,

psiflow/reference/reference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _nan_if_unsuccessful(
5555
return result
5656

5757

58-
nan_if_unsuccessful = python_app(_nan_if_unsuccessful, executors=['default_threads'])
58+
nan_if_unsuccessful = python_app(_nan_if_unsuccessful, executors=["default_threads"])
5959

6060

6161
@join_app

psiflow/utils/apps.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,4 @@ def _isnan(a: Union[float, np.ndarray]) -> bool:
141141
return bool(np.any(np.isnan(a)))
142142

143143

144-
isnan = python_app(_isnan, executors=['default_threads'])
144+
isnan = python_app(_isnan, executors=["default_threads"])

0 commit comments

Comments
 (0)