Skip to content

Commit 597b481

Browse files
committed
switch to float64 dtype for functions/geometry
1 parent 6f5149a commit 597b481

File tree

4 files changed

+14
-14
lines changed

4 files changed

+14
-14
lines changed

configs/threadpool.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
---
2-
parsl_log_level: DEBUG
2+
parsl_log_level: WARNING
33
retries: 0
44
ModelEvaluation:
55
gpu: false

psiflow/geometry.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
per_atom_dtype = np.dtype(
1616
[
1717
("numbers", np.uint8),
18-
("positions", np.float32, (3,)),
19-
("forces", np.float32, (3,)),
18+
("positions", np.float64, (3,)),
19+
("forces", np.float64, (3,)),
2020
]
2121
)
2222

@@ -97,7 +97,7 @@ def __init__(
9797
identifier (Optional[int], optional): Unique identifier for the geometry. Defaults to None.
9898
"""
9999
self.per_atom = per_atom.astype(per_atom_dtype) # copies data
100-
self.cell = cell.astype(np.float32)
100+
self.cell = cell.astype(np.float64)
101101
assert self.cell.shape == (3, 3)
102102
if order is None:
103103
order = {}
@@ -613,28 +613,28 @@ def create_outputs(quantities: list[str], data: list[Geometry]) -> list[np.ndarr
613613
arrays = []
614614
for quantity in quantities:
615615
if quantity in ["positions", "forces"]:
616-
array = np.empty((nframes, max_natoms, 3), dtype=np.float32)
616+
array = np.empty((nframes, max_natoms, 3), dtype=np.float64)
617617
array[:] = np.nan
618618
elif quantity in ["cell", "stress"]:
619-
array = np.empty((nframes, 3, 3), dtype=np.float32)
619+
array = np.empty((nframes, 3, 3), dtype=np.float64)
620620
array[:] = np.nan
621621
elif quantity in ["numbers"]:
622622
array = np.empty((nframes, max_natoms), dtype=np.uint8)
623623
array[:] = 0
624624
elif quantity in ["energy", "delta", "per_atom_energy"]:
625-
array = np.empty((nframes,), dtype=np.float32)
625+
array = np.empty((nframes,), dtype=np.float64)
626626
array[:] = np.nan
627627
elif quantity in ["phase"]:
628628
array = np.empty((nframes,), dtype=(np.unicode_, max_phase))
629629
array[:] = ""
630630
elif quantity in ["logprob"]:
631-
array = np.empty((nframes, nprob), dtype=np.float32)
631+
array = np.empty((nframes, nprob), dtype=np.float64)
632632
array[:] = np.nan
633633
elif quantity in ["identifier"]:
634-
array = np.empty((nframes,), dtype=np.int32)
634+
array = np.empty((nframes,), dtype=np.int64)
635635
array[:] = -1
636636
elif quantity in order_names:
637-
array = np.empty((nframes,), dtype=np.float32)
637+
array = np.empty((nframes,), dtype=np.float64)
638638
array[:] = np.nan
639639
else:
640640
raise AssertionError("missing quantity in if/else")

psiflow/utils/io.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _dump_json(
109109

110110
def convert_to_list(array):
111111
if not type(array) is np.ndarray:
112-
if type(array) is np.float32:
112+
if type(array) is np.floating:
113113
return float(array)
114114
return array
115115
as_list = []
@@ -121,7 +121,7 @@ def convert_to_list(array):
121121
value = kwargs[name]
122122
if type(value) is np.ndarray:
123123
value = convert_to_list(value)
124-
elif type(value) is np.float32:
124+
elif type(value) is np.floating:
125125
value = float(value)
126126
kwargs[name] = value
127127
with open(outputs[0], "w") as f:

tests/test_function.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import json
22

33
import numpy as np
4-
from ase.units import kJ, mol
5-
from parsl.data_provider.files import File
4+
from ase.units import kJ, mol # type: ignore
5+
from parsl.data_provider.files import File # type: ignore
66

77
import psiflow
88
from psiflow.functions import (

0 commit comments

Comments
 (0)