Skip to content

Commit 8213bef

Browse files
committed
Add fixture for fake measurements
1 parent 47d37cf commit 8213bef

4 files changed

+36
-29
lines changed

tests/conftest.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@
7777
)
7878
from baybe.utils.basic import hilberts_factory
7979
from baybe.utils.boolean import strtobool
80-
from baybe.utils.dataframe import add_fake_measurements, add_parameter_noise
80+
from baybe.utils.dataframe import (
81+
add_fake_measurements,
82+
add_parameter_noise,
83+
create_fake_input,
84+
)
8185
from baybe.utils.random import temporary_seed
8286

8387
# Hypothesis settings
@@ -164,6 +168,18 @@ def fixture_batch_size(request):
164168
return request.param
165169

166170

171+
@pytest.fixture(name="n_fake_measurements")
172+
def fixture_n_fake_measurements(batch_size):
173+
"""Number of rows for :func:`baybe.utils.dataframe.create_fake_input`."""
174+
return batch_size
175+
176+
177+
@pytest.fixture(name="fake_measurements")
178+
def fixture_fake_measurements(parameters, targets, batch_size):
179+
"""Artificially created valid measurements."""
180+
return create_fake_input(parameters, targets, batch_size)
181+
182+
167183
@pytest.fixture(
168184
params=[5, pytest.param(8, marks=pytest.mark.slow)],
169185
name="n_grid_points",

tests/test_input_output.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from baybe.recommenders import BotorchRecommender
99
from baybe.searchspace import SearchSpace
1010
from baybe.targets import NumericalTarget
11-
from baybe.utils.dataframe import add_fake_measurements, create_fake_input
11+
from baybe.utils.dataframe import add_fake_measurements
1212

1313

1414
@pytest.mark.parametrize(
@@ -43,14 +43,12 @@
4343
],
4444
)
4545
@pytest.mark.parametrize("n_grid_points", [5], ids=["g5"])
46-
def test_bad_parameter_input_value(campaign, bad_val, batch_size):
46+
def test_bad_parameter_input_value(campaign, bad_val, fake_measurements):
4747
"""Test attempting to read in an invalid parameter value."""
48-
rec = create_fake_input(campaign.parameters, campaign.targets, batch_size)
49-
5048
# Add an invalid value
51-
rec[campaign.parameters[0].name].iloc[0] = bad_val
49+
fake_measurements[campaign.parameters[0].name].iloc[0] = bad_val
5250
with pytest.raises((ValueError, TypeError)):
53-
campaign.add_measurements(rec)
51+
campaign.add_measurements(fake_measurements)
5452

5553

5654
@pytest.mark.parametrize(

tests/test_pending_experiments.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from baybe.utils.basic import get_subclasses
2222
from baybe.utils.dataframe import (
2323
add_parameter_noise,
24-
create_fake_input,
2524
)
2625
from baybe.utils.random import temporary_seed
2726

@@ -115,13 +114,12 @@
115114
],
116115
)
117116
@pytest.mark.parametrize("n_grid_points", [8], ids=["grid8"])
118-
def test_pending_points(campaign, batch_size):
117+
def test_pending_points(campaign, batch_size, fake_measurements):
119118
"""Test there is no recommendation overlap if pending experiments are specified."""
120119
warnings.filterwarnings("ignore", category=UnusedObjectWarning)
121120

122121
# Add some initial measurements
123-
rec = create_fake_input(campaign.parameters, campaign.targets, batch_size)
124-
campaign.add_measurements(rec)
122+
campaign.add_measurements(fake_measurements)
125123

126124
# Get recommendations and set them as pending experiments while getting another set
127125
# Fix the random seed for each recommend call to limit influence of randomness in
@@ -156,24 +154,23 @@ def test_pending_points(campaign, batch_size):
156154
)
157155
@pytest.mark.parametrize("n_grid_points", [5], ids=["g5"])
158156
@pytest.mark.parametrize("batch_size", [3], ids=["b3"])
159-
def test_invalid_acqf(searchspace, recommender, objective, batch_size, acqf):
157+
def test_invalid_acqf(searchspace, objective, batch_size, acqf, fake_measurements):
160158
"""Test exception raised for acqfs that don't support pending experiments."""
161159
recommender = TwoPhaseMetaRecommender(
162160
recommender=BotorchRecommender(acquisition_function=acqf)
163161
)
164162

165163
# Create fake measurements and pending experiments
166-
rec1 = create_fake_input(searchspace.parameters, objective.targets, batch_size)
167-
rec2 = rec1.copy()
168-
add_parameter_noise(rec2, searchspace.parameters)
164+
fake_pending_experiments = fake_measurements.copy()
165+
add_parameter_noise(fake_pending_experiments, searchspace.parameters)
169166

170167
with pytest.raises(IncompatibleAcquisitionFunctionError):
171168
recommender.recommend(
172169
batch_size,
173170
searchspace,
174171
objective,
175-
measurements=rec1,
176-
pending_experiments=rec2,
172+
measurements=fake_measurements,
173+
pending_experiments=fake_pending_experiments,
177174
)
178175

179176

@@ -215,18 +212,18 @@ def test_invalid_input(
215212
batch_size,
216213
invalid_pending_value,
217214
parameter_names,
215+
fake_measurements,
218216
):
219217
"""Test exception raised for invalid pending experiments input."""
220218
# Create fake measurements and pending experiments
221-
rec1 = create_fake_input(searchspace.parameters, objective.targets, batch_size)
222-
rec2 = rec1.copy()
223-
rec2[parameter_names[0]] = invalid_pending_value
219+
fake_pending_experiments = fake_measurements.copy()
220+
fake_pending_experiments[parameter_names[0]] = invalid_pending_value
224221

225222
with pytest.raises((ValueError, TypeError), match="parameter"):
226223
recommender.recommend(
227224
batch_size,
228225
searchspace,
229226
objective,
230-
measurements=rec1,
231-
pending_experiments=rec2,
227+
measurements=fake_measurements,
228+
pending_experiments=fake_pending_experiments,
232229
)

tests/test_surrogate.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,21 @@
22

33
from unittest.mock import patch
44

5-
from baybe.recommenders.pure.nonpredictive.sampling import RandomRecommender
65
from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate
7-
from baybe.utils.dataframe import add_fake_measurements
86

97

108
@patch.object(GaussianProcessSurrogate, "_fit")
11-
def test_caching(patched, searchspace, objective):
9+
def test_caching(patched, searchspace, objective, fake_measurements):
1210
"""A second fit call with the same context does not trigger retraining."""
1311
# Prepare the setting
14-
measurements = RandomRecommender().recommend(3, searchspace, objective)
15-
add_fake_measurements(measurements, objective.targets)
1612
surrogate = GaussianProcessSurrogate()
1713

1814
# First call
19-
surrogate.fit(searchspace, objective, measurements)
15+
surrogate.fit(searchspace, objective, fake_measurements)
2016
patched.assert_called()
2117

2218
patched.reset_mock()
2319

2420
# Second call
25-
surrogate.fit(searchspace, objective, measurements)
21+
surrogate.fit(searchspace, objective, fake_measurements)
2622
patched.assert_not_called()

0 commit comments

Comments
 (0)