 import warnings
 from collections.abc import Callable
 from functools import partial
-from typing import Any, NoReturn
+from typing import Any, Mapping, NoReturn
 
 import numpy as np
 import numpy.typing as npt
@@ -64,7 +64,7 @@ def gen_candidates_scipy(
     equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
     nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
     options: dict[str, Any] | None = None,
-    fixed_features: dict[int, float | None] | None = None,
+    fixed_features: Mapping[int, float | Tensor] | None = None,
     timeout_sec: float | None = None,
     use_parallel_mode: bool | None = None,
 ) -> tuple[Tensor, Tensor]:
@@ -107,11 +107,11 @@ def gen_candidates_scipy(
             and SLSQP if inequality or equality constraints are present. If
             `with_grad=False`, then we use a two-point finite difference estimate
             of the gradient.
-        fixed_features: This is a dictionary of feature indices to values, where
+        fixed_features: A mapping of feature indices to values (float or Tensor), where
             all generated candidates will have features fixed to these values.
-            If the dictionary value is None, then that feature will just be
-            fixed to the clamped value and not optimized. Assumes values to be
-            compatible with lower_bounds and upper_bounds!
+            If a tensor is passed as a value, it should have shape `b` or `b x q`,
+            which fixes the feature to different values across the batch.
+            Assumes values to be compatible with lower_bounds and upper_bounds!
         timeout_sec: Timeout (in seconds) for `scipy.optimize.minimize` routine -
             if provided, optimization will stop after this many seconds and return
             the best solution found so far.
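# A minimal usage sketch of the tensor-valued `fixed_features` described above.
# `acq_function`, `lower`, and `upper` are assumed to already exist (a fitted
# acquisition function and `d`-dim bound tensors); they are not part of the diff.
import torch
from botorch.generation.gen import gen_candidates_scipy

b, q, d = 4, 2, 3  # four independent problems, two candidates each, three features
X0 = lower + (upper - lower) * torch.rand(b, q, d)

# Fix feature 0 to 0.5 everywhere; fix feature 2 to a different value per batch
# element (shape `b`). A `b x q` tensor would additionally vary across `q`.
fixed_features = {0: 0.5, 2: torch.linspace(0.1, 0.9, b)}

candidates, acq_values = gen_candidates_scipy(
    initial_conditions=X0,
    acquisition_function=acq_function,
    lower_bounds=lower,
    upper_bounds=upper,
    fixed_features=fixed_features,
)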
@@ -211,18 +211,17 @@ def f(x):
         timeout_sec=timeout_sec,
     )
 
+    f_np_wrapper = _get_f_np_wrapper(
+        clamped_candidates.shape,
+        initial_conditions.device,
+        initial_conditions.dtype,
+        with_grad,
+    )
+
     if not why_not_fast_path and use_parallel_mode is not False:
         if is_constrained:
             raise RuntimeWarning("Method L-BFGS-B cannot handle constraints.")
 
-        f_np_wrapper = _get_f_np_wrapper(
-            clamped_candidates.shape,
-            initial_conditions.device,
-            initial_conditions.dtype,
-            with_grad,
-            batched=True,
-        )
-
         batched_x0 = _arrayify(clamped_candidates).reshape(len(clamped_candidates), -1)
 
         l_bfgs_b_bounds = translate_bounds_for_lbfgsb(
@@ -242,6 +241,7 @@ def f(x):
             bounds=l_bfgs_b_bounds,
             # constraints=constraints,
             callback=options.get("callback", None),
+            pass_batch_indices=True,
             **minimize_options,
         )
         for res in results:
@@ -264,21 +264,38 @@ def f(x):
         else:
             logger.debug(msg)
 
-        f_np_wrapper = _get_f_np_wrapper(
-            clamped_candidates.shape,
-            initial_conditions.device,
-            initial_conditions.dtype,
-            with_grad,
-        )
+        if (
+            fixed_features
+            and any(
+                torch.is_tensor(ff) and ff.ndim > 0 for ff in fixed_features.values()
+            )
+            and max_optimization_problem_aggregation_size != 1
+        ):
+            raise UnsupportedError(
+                "Batch-shaped fixed features are not "
+                "supported when optimizing more than "
+                "one problem at a time."
+            )
 
         all_xs = []
         split_candidates = clamped_candidates.split(
             max_optimization_problem_aggregation_size
         )
-        for candidates_ in split_candidates:
-            # We optimize the candidates at hand as a single problem
+        for i, candidates_ in enumerate(split_candidates):
+            if fixed_features:
+                fixed_features_ = {
+                    k: ff[i : i + 1].item()
+                    # From the check above, we know that we treat only one candidate
+                    # at a time, so we can use index i.
+                    if torch.is_tensor(ff) and ff.ndim > 0
+                    else ff
+                    for k, ff in fixed_features.items()
+                }
+            else:
+                fixed_features_ = None
+
             _no_fixed_features = _remove_fixed_features_from_optimization(
-                fixed_features=fixed_features,
+                fixed_features=fixed_features_,
                 acquisition_function=acquisition_function,
                 initial_conditions=None,
                 d=initial_conditions_all_features.shape[-1],
@@ -296,7 +313,7 @@ def f(x):
 
             f_np_wrapper_ = partial(
                 f_np_wrapper,
-                fixed_features=fixed_features,
+                fixed_features=fixed_features_,
             )
 
             x0 = candidates_.flatten()
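# A self-contained sketch of the narrowing done in the loop above: a batch-shaped
# fixed feature is reduced to a plain float per sub-problem (the check above
# guarantees each sub-problem holds a single candidate when tensors are used).
import torch

fixed_features = {0: 0.5, 2: torch.tensor([0.1, 0.7, 0.9])}  # feature 2 varies per problem

for i in range(3):
    fixed_features_ = {
        k: ff[i : i + 1].item() if torch.is_tensor(ff) and ff.ndim > 0 else ff
        for k, ff in fixed_features.items()
    }
    # e.g. i=1 yields {0: 0.5, 2: 0.7}, i.e. a scalar-valued dict per sub-problem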
@@ -363,13 +380,14 @@ def f(x):
     return clamped_candidates, batch_acquisition
 
 
-def _get_f_np_wrapper(shapeX, device, dtype, with_grad, batched=False):
+def _get_f_np_wrapper(shapeX, device, dtype, with_grad):
     if with_grad:
 
         def f_np_wrapper(
             x: npt.NDArray,
             f: Callable,
-            fixed_features: dict[int, float] | None,
+            fixed_features: Mapping[int, float | Tensor] | None,
+            batch_indices: list[int] | None = None,
         ) -> tuple[float | npt.NDArray, npt.NDArray]:
             """Given a torch callable, compute value + grad given a numpy array."""
             if np.isnan(x).any():
@@ -387,8 +405,21 @@ def f_np_wrapper(
                 .contiguous()
                 .requires_grad_(True)
             )
+            if fixed_features is not None:
+                if batch_indices is not None:
+                    this_fixed_features = {
+                        k: ff[batch_indices]
+                        if torch.is_tensor(ff) and ff.ndim > 0
+                        else ff
+                        for k, ff in fixed_features.items()
+                    }
+                else:
+                    this_fixed_features = fixed_features
+            else:
+                this_fixed_features = None
+
             X_fix = fix_features(
-                X, fixed_features=fixed_features, replace_current_value=False
+                X, fixed_features=this_fixed_features, replace_current_value=False
             )
             # we compute the loss on the whole batch, under the assumption that f
             # treats multiple inputs in the 0th dimension as independent
@@ -409,7 +440,7 @@ def f_np_wrapper(
                 raise OptimizationGradientError(msg, current_x=x)
             fval = (
                 losses.detach().view(-1).cpu().numpy()
-                if batched
+                if batch_indices is not None
                 else loss.detach().item()
             )  # the view(-1) seems necessary as f might return a single scalar
             return fval, gradf
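# A self-contained sketch of the `batch_indices` narrowing in `f_np_wrapper` above:
# the batched L-BFGS-B routine (see `pass_batch_indices=True`) is presumed to pass
# the indices of the sub-problems it is currently evaluating, and tensor-valued
# fixed features are indexed down to that subset before `fix_features` is applied.
import torch

fixed_features = {1: torch.tensor([0.2, 0.4, 0.6, 0.8])}  # shape `b`, with b = 4
batch_indices = [0, 3]  # only problems 0 and 3 are still active

this_fixed_features = {
    k: ff[batch_indices] if torch.is_tensor(ff) and ff.ndim > 0 else ff
    for k, ff in fixed_features.items()
}
# this_fixed_features[1] is tensor([0.2, 0.8]), matching the sub-batch of X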
@@ -485,7 +516,7 @@ def gen_candidates_torch(
     optimizer: type[Optimizer] = torch.optim.Adam,
     options: dict[str, float | str] | None = None,
     callback: Callable[[int, Tensor, Tensor], NoReturn] | None = None,
-    fixed_features: dict[int, float | None] | None = None,
+    fixed_features: Mapping[int, float | Tensor] | None = None,
     timeout_sec: float | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Generate a set of candidates using a `torch.optim` optimizer.
@@ -507,9 +538,10 @@ def gen_candidates_torch(
             the loss and gradients, but before calling the optimizer.
         fixed_features: This is a dictionary of feature indices to values, where
             all generated candidates will have features fixed to these values.
-            If the dictionary value is None, then that feature will just be
-            fixed to the clamped value and not optimized. Assumes values to be
-            compatible with lower_bounds and upper_bounds!
+            If a float is passed, the feature is fixed to that value across `[b, q]`.
+            If a tensor is passed, it may have shape `[b, q]` or `[b]`; in the
+            latter case, the same value is used across the `q` dimension.
+            Assumes values to be compatible with lower_bounds and upper_bounds!
         timeout_sec: Timeout (in seconds) for optimization. If provided,
             `gen_candidates_torch` will stop after this many seconds and return
             the best solution found so far.
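# A minimal usage sketch for `gen_candidates_torch` with the tensor-valued fixed
# features described above; `acq_function`, `lower`, and `upper` are assumed to
# exist and are not part of the diff.
import torch
from botorch.generation.gen import gen_candidates_torch

b, q, d = 3, 2, 4
X0 = lower + (upper - lower) * torch.rand(b, q, d)

fixed_features = {
    1: 0.25,              # a float is fixed across the whole `b x q` batch
    3: torch.rand(b, q),  # a `b x q` tensor varies per batch element and per q
}

candidates, acq_values = gen_candidates_torch(
    initial_conditions=X0,
    acquisition_function=acq_function,
    lower_bounds=lower,
    upper_bounds=upper,
    fixed_features=fixed_features,
)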
@@ -533,46 +565,21 @@ def gen_candidates_torch(
             upper_bounds=bounds[1],
         )
     """
-    assert not fixed_features or not any(
-        torch.is_tensor(v) for v in fixed_features.values()
-    ), "`gen_candidates_torch` does not support tensor-valued fixed features."
-
     start_time = time.monotonic()
     options = options or {}
-
-    # if there are fixed features we may optimize over a domain of lower dimension
-    if fixed_features:
-        subproblem = _remove_fixed_features_from_optimization(
-            fixed_features=fixed_features,
-            acquisition_function=acquisition_function,
-            initial_conditions=initial_conditions,
-            d=initial_conditions.shape[-1],
-            lower_bounds=lower_bounds,
-            upper_bounds=upper_bounds,
-            inequality_constraints=None,
-            equality_constraints=None,
-            nonlinear_inequality_constraints=None,
-        )
-
-        # call the routine with no fixed_features
-        elapsed = time.monotonic() - start_time
-        clamped_candidates, batch_acquisition = gen_candidates_torch(
-            initial_conditions=subproblem.initial_conditions,
-            acquisition_function=subproblem.acquisition_function,
-            lower_bounds=subproblem.lower_bounds,
-            upper_bounds=subproblem.upper_bounds,
-            optimizer=optimizer,
-            options=options,
-            callback=callback,
-            fixed_features=None,
-            timeout_sec=timeout_sec - elapsed if timeout_sec else None,
-        )
-        clamped_candidates = subproblem.acquisition_function._construct_X_full(
-            clamped_candidates
-        )
-        return clamped_candidates, batch_acquisition
+    # We drop max_optimization_problem_aggregation_size, as it does not affect
+    # the first-order optimizers implemented in this method: it makes no
+    # difference here whether multiple optimization problems are combined into
+    # one or not.
+    options.pop("max_optimization_problem_aggregation_size", None)
 
     _clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
-    clamped_candidates = _clamp(initial_conditions).requires_grad_(True)
+    clamped_candidates = _clamp(initial_conditions)
+    if fixed_features:
+        clamped_candidates = clamped_candidates[
+            ...,
+            [i for i in range(clamped_candidates.shape[-1]) if i not in fixed_features],
+        ]
+    clamped_candidates = clamped_candidates.requires_grad_(True)
     _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
 
     i = 0
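# A generic sketch of the drop-and-reinsert pattern used above, written with plain
# torch ops rather than BoTorch's `fix_features` helper: only the free columns are
# handed to the optimizer, and the fixed values are spliced back in before each
# acquisition evaluation.
import torch

b, q, d = 2, 1, 4
X_full = torch.rand(b, q, d)
fixed_features = {1: 0.5, 3: torch.full((b, q), 0.9)}

free_cols = [i for i in range(d) if i not in fixed_features]
X_free = X_full[..., free_cols]  # shape `b x q x 2`; this is what gets optimized

# Reinsert: free columns come from X_free, fixed columns take their prescribed values.
X_eval = torch.empty(b, q, d)
X_eval[..., free_cols] = X_free
for k, v in fixed_features.items():
    X_eval[..., k] = v  # accepts a Python float or a `b x q` tensor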
@@ -583,7 +590,7 @@ def gen_candidates_torch(
         with torch.no_grad():
             X = _clamp(clamped_candidates).requires_grad_(True)
 
-        loss = -acquisition_function(X).sum()
+        loss = -acquisition_function(fix_features(X, fixed_features)).sum()
         grad = torch.autograd.grad(loss, X)[0]
         if callback:
             callback(i, loss, grad)
@@ -602,6 +609,7 @@ def assign_grad():
                 logger.info(f"Optimization timed out after {runtime} seconds.")
 
     clamped_candidates = _clamp(clamped_candidates)
+    clamped_candidates = fix_features(clamped_candidates, fixed_features)
     with torch.no_grad():
         batch_acquisition = acquisition_function(clamped_candidates)