6
6
7
7
import itertools
8
8
import warnings
9
- from inspect import signature
10
9
from itertools import product
11
10
from unittest import mock
12
11
@@ -114,10 +113,8 @@ class TestOptimizeAcqf(BotorchTestCase):
114
113
@mock .patch ("botorch.generation.gen.gen_candidates_torch" )
115
114
@mock .patch ("botorch.optim.optimize.gen_batch_initial_conditions" )
116
115
@mock .patch ("botorch.optim.optimize.gen_candidates_scipy" )
117
- @mock .patch ("botorch.optim.utils.common.signature" )
118
116
def test_optimize_acqf_joint (
119
117
self ,
120
- mock_signature ,
121
118
mock_gen_candidates_scipy ,
122
119
mock_gen_batch_initial_conditions ,
123
120
mock_gen_candidates_torch ,
@@ -134,10 +131,6 @@ def test_optimize_acqf_joint(
134
131
mock_gen_candidates_scipy ,
135
132
mock_gen_candidates_torch ,
136
133
):
137
- if mock_gen_candidates == mock_gen_candidates_torch :
138
- mock_signature .return_value = signature (gen_candidates_torch )
139
- else :
140
- mock_signature .return_value = signature (gen_candidates_scipy )
141
134
142
135
mock_gen_batch_initial_conditions .return_value = torch .zeros (
143
136
num_restarts , q , 3 , device = self .device , dtype = dtype
@@ -264,12 +257,14 @@ def test_optimize_acqf_joint(
264
257
)
265
258
266
259
@mock .patch ("botorch.optim.optimize.gen_batch_initial_conditions" )
267
- @mock .patch ("botorch.optim.optimize.gen_candidates_scipy" )
268
- @mock .patch ("botorch.generation.gen.gen_candidates_torch" )
269
- @mock .patch ("botorch.optim.utils.common.signature" )
260
+ @mock .patch (
261
+ "botorch.optim.optimize.gen_candidates_scipy" , wraps = gen_candidates_scipy
262
+ )
263
+ @mock .patch (
264
+ "botorch.generation.gen.gen_candidates_torch" , wraps = gen_candidates_torch
265
+ )
270
266
def test_optimize_acqf_sequential (
271
267
self ,
272
- mock_signature ,
273
268
mock_gen_candidates_torch ,
274
269
mock_gen_candidates_scipy ,
275
270
mock_gen_batch_initial_conditions ,
@@ -278,11 +273,6 @@ def test_optimize_acqf_sequential(
278
273
for mock_gen_candidates , timeout_sec in product (
279
274
[mock_gen_candidates_scipy , mock_gen_candidates_torch ], [None , 1e-4 ]
280
275
):
281
- if mock_gen_candidates == mock_gen_candidates_torch :
282
- mock_signature .return_value = signature (gen_candidates_torch )
283
- else :
284
- mock_signature .return_value = signature (gen_candidates_scipy )
285
- mock_gen_candidates .__name__ = "gen_candidates"
286
276
q = 3
287
277
num_restarts = 2
288
278
raw_samples = 10
@@ -1019,16 +1009,12 @@ def nlc4(x):
1019
1009
raw_samples = 16 ,
1020
1010
)
1021
1011
1022
- @mock .patch ("botorch.generation.gen.gen_candidates_torch" )
1023
1012
@mock .patch ("botorch.optim.optimize.gen_batch_initial_conditions" )
1024
1013
@mock .patch ("botorch.optim.optimize.gen_candidates_scipy" )
1025
- @mock .patch ("botorch.optim.utils.common.signature" )
1026
1014
def test_optimize_acqf_non_linear_constraints_sequential (
1027
1015
self ,
1028
- mock_signature ,
1029
1016
mock_gen_candidates_scipy ,
1030
1017
mock_gen_batch_initial_conditions ,
1031
- mock_gen_candidates_torch ,
1032
1018
):
1033
1019
def nlc (x ):
1034
1020
return 4 * x [..., 2 ] - 5
@@ -1037,90 +1023,63 @@ def nlc(x):
1037
1023
num_restarts = 2
1038
1024
raw_samples = 10
1039
1025
options = {}
1040
- for mock_gen_candidates in (
1041
- mock_gen_candidates_torch ,
1042
- mock_gen_candidates_scipy ,
1043
- ):
1044
- if mock_gen_candidates == mock_gen_candidates_torch :
1045
- mock_signature .return_value = signature (gen_candidates_torch )
1046
- else :
1047
- mock_signature .return_value = signature (gen_candidates_scipy )
1048
- for dtype in (torch .float , torch .double ):
1049
- mock_acq_function = MockAcquisitionFunction ()
1050
- mock_gen_batch_initial_conditions .side_effect = [
1051
- torch .zeros (num_restarts , 1 , 3 , device = self .device , dtype = dtype )
1052
- for _ in range (q )
1053
- ]
1054
- gcs_return_vals = [
1055
- (
1056
- torch .tensor (
1057
- [[[1.0 , 2.0 , 3.0 ]]], device = self .device , dtype = dtype
1058
- ),
1059
- torch .tensor ([i ], device = self .device , dtype = dtype ),
1060
- )
1061
- # for nonlinear inequality constraints the batch_limit variable is
1062
- # currently set to 1 by default and hence gen_candidates_scipy is
1063
- # called num_restarts*q times
1064
- for i in range (num_restarts * q )
1065
- ]
1066
- mock_gen_candidates .side_effect = gcs_return_vals
1067
- expected_candidates = torch .cat (
1068
- [cands [0 ] for cands , _ in gcs_return_vals [::num_restarts ]], dim = - 2
1026
+
1027
+ for dtype in (torch .float , torch .double ):
1028
+ mock_acq_function = MockAcquisitionFunction ()
1029
+ mock_gen_batch_initial_conditions .side_effect = [
1030
+ torch .zeros (num_restarts , 1 , 3 , device = self .device , dtype = dtype )
1031
+ for _ in range (q )
1032
+ ]
1033
+ gcs_return_vals = [
1034
+ (
1035
+ torch .tensor ([[[1.0 , 2.0 , 3.0 ]]], device = self .device , dtype = dtype ),
1036
+ torch .tensor ([i ], device = self .device , dtype = dtype ),
1069
1037
)
1070
- bounds = torch .stack (
1071
- [
1072
- torch .zeros (3 , device = self .device , dtype = dtype ),
1073
- 4 * torch .ones (3 , device = self .device , dtype = dtype ),
1074
- ]
1038
+ # for nonlinear inequality constraints the batch_limit variable is
1039
+ # currently set to 1 by default and hence gen_candidates_scipy is
1040
+ # called num_restarts*q times
1041
+ for i in range (num_restarts * q )
1042
+ ]
1043
+ mock_gen_candidates_scipy .side_effect = gcs_return_vals
1044
+ expected_candidates = torch .cat (
1045
+ [cands [0 ] for cands , _ in gcs_return_vals [::num_restarts ]], dim = - 2
1046
+ )
1047
+ bounds = torch .stack (
1048
+ [
1049
+ torch .zeros (3 , device = self .device , dtype = dtype ),
1050
+ 4 * torch .ones (3 , device = self .device , dtype = dtype ),
1051
+ ]
1052
+ )
1053
+ with warnings .catch_warnings (record = True ) as ws :
1054
+ candidates , acq_value = optimize_acqf (
1055
+ acq_function = mock_acq_function ,
1056
+ bounds = bounds ,
1057
+ q = q ,
1058
+ num_restarts = num_restarts ,
1059
+ raw_samples = raw_samples ,
1060
+ options = options ,
1061
+ nonlinear_inequality_constraints = [nlc ],
1062
+ sequential = True ,
1063
+ ic_generator = mock_gen_batch_initial_conditions ,
1064
+ gen_candidates = mock_gen_candidates_scipy ,
1075
1065
)
1076
- with warnings .catch_warnings (record = True ) as ws :
1077
- candidates , acq_value = optimize_acqf (
1078
- acq_function = mock_acq_function ,
1079
- bounds = bounds ,
1080
- q = q ,
1081
- num_restarts = num_restarts ,
1082
- raw_samples = raw_samples ,
1083
- options = options ,
1084
- nonlinear_inequality_constraints = [nlc ],
1085
- sequential = True ,
1086
- ic_generator = mock_gen_batch_initial_conditions ,
1087
- gen_candidates = mock_gen_candidates ,
1088
- )
1089
- if mock_gen_candidates == mock_gen_candidates_torch :
1090
- self .assertEqual (len (ws ), 3 )
1091
- message = (
1092
- "Keyword arguments ['nonlinear_inequality_constraints']"
1093
- " will be ignored because they are not allowed parameters for"
1094
- " function gen_candidates. Allowed parameters are "
1095
- " ['initial_conditions', 'acquisition_function', "
1096
- "'lower_bounds', 'upper_bounds', 'optimizer', 'options',"
1097
- " 'callback', 'fixed_features', 'timeout_sec']."
1098
- )
1099
- expected_warning_raised = (
1100
- issubclass (w .category , UserWarning )
1101
- and message == str (w .message )
1102
- for w in ws
1103
- )
1104
- self .assertTrue (expected_warning_raised )
1105
- # check message
1106
- else :
1107
- self .assertEqual (len (ws ), 0 )
1108
- self .assertTrue (torch .equal (candidates , expected_candidates ))
1109
- # Extract the relevant entries from gcs_return_vals to
1110
- # perform comparison with.
1111
- self .assertTrue (
1112
- torch .equal (
1113
- acq_value ,
1114
- torch .cat (
1115
- [
1116
- expected_acq_value
1117
- for _ , expected_acq_value in gcs_return_vals [
1118
- num_restarts - 1 :: num_restarts
1119
- ]
1066
+ self .assertEqual (len (ws ), 0 )
1067
+ self .assertTrue (torch .equal (candidates , expected_candidates ))
1068
+ # Extract the relevant entries from gcs_return_vals to
1069
+ # perform comparison with.
1070
+ self .assertTrue (
1071
+ torch .equal (
1072
+ acq_value ,
1073
+ torch .cat (
1074
+ [
1075
+ expected_acq_value
1076
+ for _ , expected_acq_value in gcs_return_vals [
1077
+ num_restarts - 1 :: num_restarts
1120
1078
]
1121
- ),
1079
+ ]
1122
1080
),
1123
- )
1081
+ ),
1082
+ )
1124
1083
1125
1084
def test_constraint_caching (self ):
1126
1085
def nlc (x ):
0 commit comments