Skip to content

Commit 75158ad

Browse files
committed
make sure failed walkers / failed single point evaluations still show up as reset in walker table
1 parent ee8d5b5 commit 75158ad

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

psiflow/learning.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from psiflow.models import Model
1818
from psiflow.reference import Reference, evaluate
1919
from psiflow.sampling import SimulationOutput, Walker, sample
20-
from psiflow.utils.apps import boolean_or, setup_logger, unpack_i
20+
from psiflow.utils.apps import boolean_or, setup_logger, unpack_i, isnan
2121

2222
logger = setup_logger(__name__)
2323

@@ -80,7 +80,11 @@ def evaluate_outputs(
8080
errors[i],
8181
np.array(error_thresholds_for_reset, dtype=float),
8282
)
83-
reset = boolean_or(error_discard, error_reset)
83+
reset = boolean_or(
84+
error_discard,
85+
error_reset,
86+
isnan(errors[i]),
87+
)
8488

8589
_ = assign_identifier(state, identifier, error_discard)
8690
assigned = unpack_i(_, 0)

psiflow/utils/apps.py

+8
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,11 @@ def _concatenate(*arrays: np.ndarray) -> np.ndarray:
134134

135135

136136
concatenate = python_app(_concatenate, executors=["default_threads"])
137+
138+
139+
@typeguard.typechecked
140+
def _isnan(a: Union[float, np.ndarray]) -> bool:
141+
return bool(np.any(np.isnan(a)))
142+
143+
144+
isnan = python_app(_isnan, executors=['default_threads'])

tests/test_learning.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,6 @@ def test_evaluate_outputs(dataset):
158158
outputs[3].state = new_nullstate()
159159
outputs[7].status = 2 # should be null state
160160

161-
# resets for 3 and 7 happen in sample() method, not in evaluate_outputs!
162-
163161
identifier = 3
164162
identifier, data, resets = evaluate_outputs(
165163
outputs,
@@ -185,11 +183,7 @@ def test_evaluate_outputs(dataset):
185183
error_thresholds_for_discard=[0.0, 0.0],
186184
metrics=Metrics(),
187185
)
188-
for i in range(10):
189-
if i not in [3, 7]:
190-
assert resets[i].result() # already reset
191-
else:
192-
assert not resets[i].result()
186+
assert all([r.result() for r in resets])
193187

194188

195189
def test_wandb():

0 commit comments

Comments
 (0)