Skip to content

Commit 1276fee

Browse files
committed
Add and update tests in test_bounds.py
1 parent a57ae06 commit 1276fee

2 files changed

Lines changed: 15 additions & 4 deletions

File tree

src/optimagic/parameters/tree_conversion.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@ def get_tree_converter(
4848
_registry = get_registry(extended=True)
4949
_params_vec, _params_treedef = tree_flatten(params, registry=_registry)
5050
_params_vec = np.array(_params_vec).astype(float)
51+
# We cannot propagate None-valued bounds until the conversion and converter code is
52+
# updated to handle None bounds.
5153
_lower, _upper = get_internal_bounds(
5254
params=params,
5355
bounds=bounds,
5456
registry=_registry,
57+
propagate_none=False,
5558
)
5659

5760
if add_soft_bounds:
@@ -60,6 +63,7 @@ def get_tree_converter(
6063
bounds=bounds,
6164
registry=_registry,
6265
add_soft_bounds=add_soft_bounds,
66+
propagate_none=True,
6367
)
6468
else:
6569
_soft_lower, _soft_upper = None, None

tests/optimagic/parameters/test_bounds.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,15 @@ def test_get_bounds_with_upper_bounds(pytree_params):
120120
assert_array_equal(got_upper, expected_upper)
121121

122122

123-
def test_get_bounds_numpy(array_params):
124-
got_lower, got_upper = get_internal_bounds(array_params)
123+
def test_get_bounds_numpy_propagate_none_true(array_params):
124+
got_lower, got_upper = get_internal_bounds(array_params, propagate_none=True)
125+
126+
assert got_lower is None
127+
assert got_upper is None
128+
129+
130+
def test_get_bounds_numpy_propagate_none_false(array_params):
131+
got_lower, got_upper = get_internal_bounds(array_params, propagate_none=False)
125132

126133
expected = np.array([np.inf, np.inf])
127134

@@ -139,7 +146,7 @@ def test_get_bounds_numpy_error(array_params):
139146
)
140147

141148

142-
def test_get_fast_path_bounds_none_propagate_true():
149+
def test_get_fast_path_bounds_propagate_none_true():
143150
got_lower, got_upper = _get_fast_path_bounds(
144151
params=np.array([1, 2, 3]),
145152
bounds=Bounds(lower=None, upper=None),
@@ -149,7 +156,7 @@ def test_get_fast_path_bounds_none_propagate_true():
149156
assert got_upper is None
150157

151158

152-
def test_get_fast_path_bounds_none_propagate_false():
159+
def test_get_fast_path_bounds_propagate_none_false():
153160
got_lower, got_upper = _get_fast_path_bounds(
154161
params=np.array([1, 2, 3]),
155162
bounds=Bounds(lower=None, upper=None),

0 commit comments

Comments
 (0)