Skip to content

Commit a1821f9

Browse files
committed
Merge remote-tracking branch 'origin/main' into runner-callbacks
2 parents bf1a2c6 + 28d4c35 commit a1821f9

33 files changed

+458
-146
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ repos:
2121
rev: 5.10.1
2222
hooks:
2323
- id: isort
24-
- repo: https://gitlab.com/pycqa/flake8
24+
- repo: https://github.com/pycqa/flake8
2525
rev: 3.9.2
2626
hooks:
2727
- id: flake8

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def peak(x, a=0.01):
7575

7676

7777
learner = Learner1D(peak, bounds=(-1, 1))
78-
runner = Runner(learner, goal=lambda l: l.loss() < 0.01)
78+
runner = Runner(learner, loss_goal=0.01)
7979
runner.live_info()
8080
runner.live_plot()
8181
```

adaptive/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from contextlib import suppress
22

3-
from adaptive import learner, runner, utils
43
from adaptive._version import __version__
54
from adaptive.learner import (
65
AverageLearner,
@@ -22,6 +21,8 @@
2221
)
2322
from adaptive.runner import AsyncRunner, BlockingRunner, Runner
2423

24+
from adaptive import learner, runner, utils # isort:skip
25+
2526
__all__ = [
2627
"learner",
2728
"runner",

adaptive/learner/data_saver.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def _to_key(x):
2020
return tuple(x.values) if x.values.size > 1 else x.item()
2121

2222

23-
class DataSaver:
23+
class DataSaver(BaseLearner):
2424
"""Save extra data associated with the values that need to be learned.
2525
2626
Parameters
@@ -50,6 +50,18 @@ def new(self) -> DataSaver:
5050
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
5151
return DataSaver(self.learner.new(), self.arg_picker)
5252

53+
@copy_docstring_from(BaseLearner.ask)
54+
def ask(self, *args, **kwargs):
55+
return self.learner.ask(*args, **kwargs)
56+
57+
@copy_docstring_from(BaseLearner.loss)
58+
def loss(self, *args, **kwargs):
59+
return self.learner.loss(*args, **kwargs)
60+
61+
@copy_docstring_from(BaseLearner.remove_unfinished)
62+
def remove_unfinished(self, *args, **kwargs):
63+
return self.learner.remove_unfinished(*args, **kwargs)
64+
5365
def __getattr__(self, attr: str) -> Any:
5466
return getattr(self.learner, attr)
5567

0 commit comments

Comments
 (0)