@@ -468,6 +468,90 @@ def gen_batch_initial_conditions(
     return batch_initial_conditions


+def gen_optimal_input_initial_conditions(
+    acq_function: AcquisitionFunction,
+    bounds: Tensor,
+    q: int,
+    num_restarts: int,
+    raw_samples: int,
+    fixed_features: dict[int, float] | None = None,
+    options: dict[str, bool | float | int] | None = None,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+):
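+    r"""Generate initial conditions by sampling around estimated optima.
+
+    A `frac_random` fraction of the `raw_samples` candidates is drawn at
+    random from the feasible region; the remainder are Gaussian perturbations
+    of `acq_function.optimal_inputs`. `num_restarts` of the candidates are
+    then selected via Boltzmann sampling on the acquisition values.
+
+    Example:
+        >>> # assumes `qJES` is an acquisition function carrying sampled
+        >>> # `optimal_inputs` (e.g. qJointEntropySearch)
+        >>> Xinit = gen_optimal_input_initial_conditions(
+        >>>     qJES, bounds, q=3, num_restarts=10, raw_samples=512,
+        >>>     options={"frac_random": 0.25},
+        >>> )
+    """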
+    device = bounds.device
+    if not hasattr(acq_function, "optimal_inputs"):
+        raise AttributeError(
+            "gen_optimal_input_initial_conditions can only be used with "
+            "an AcquisitionFunction that has an optimal_inputs attribute."
+        )
+    options = options or {}
+    frac_random: float = options.get("frac_random", 0.0)
+    if not 0 <= frac_random <= 1:
+        raise ValueError(
+            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
+        )
+
+    batch_limit = options.get("batch_limit")
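+    # flatten the sampled optimal inputs into a `(num_optima, d)` tensor of
+    # suggested locations to perturb around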
+    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
+    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
+    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
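+    # a `frac_random` fraction of the raw samples is drawn uniformly from the
+    # feasible region; the remainder perturbs the sampled optimal inputs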
+    num_random = round(raw_samples * frac_random)
+    if num_random > 0:
+        X_rnd = sample_q_batches_from_polytope(
+            n=num_random,
+            q=q,
+            bounds=bounds,
+            n_burnin=options.get("n_burnin", 10000),
+            n_thinning=options.get("n_thinning", 32),
+            equality_constraints=equality_constraints,
+            inequality_constraints=inequality_constraints,
+        )
+        X = torch.cat((X, X_rnd))
+
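+    # generate the remaining candidates by perturbing the sampled optima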
+    if num_random < raw_samples:
+        X_perturbed = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * (raw_samples - num_random),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+            best_X=suggestions,
+        )
+        X_perturbed = X_perturbed.view(
+            raw_samples - num_random, q, bounds.shape[-1]
+        ).cpu()
+        X = torch.cat((X, X_perturbed))
+
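+    # optionally also sample around the best points observed so far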
+    if options.get("sample_around_best", False):
+        X_best = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * raw_samples,
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
+        X = torch.cat((X, X_best))
+
+    with torch.no_grad():
+        if batch_limit is None:
+            batch_limit = X.shape[0]
+        # Evaluate the acquisition function on `X` using `batch_limit`
+        # sized chunks.
+        acq_vals = torch.cat(
+            [
+                acq_function(x_.to(device=device)).cpu()
+                for x_ in X.split(split_size=batch_limit, dim=0)
+            ],
+            dim=0,
+        )
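+    # Boltzmann-sample `num_restarts` candidates, favoring high acquisition
+    # values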
+    idx = boltzmann_sample(
+        function_values=acq_vals,
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+    )
+    # return the selected candidates as the initial conditions
+    return X[idx]
+
+
 def gen_one_shot_kg_initial_conditions(
     acq_function: qKnowledgeGradient,
     bounds: Tensor,
@@ -602,59 +686,59 @@ def gen_one_shot_hvkg_initial_conditions(
 ) -> Tensor | None:
     r"""Generate a batch of smart initializations for qHypervolumeKnowledgeGradient.

-    This function generates initial conditions for optimizing one-shot HVKG using
-    the hypervolume maximizing set (of fixed size) under the posterior mean.
-    Intutively, the hypervolume maximizing set of the fantasized posterior mean
-    will often be close to a hypervolume maximizing set under the current posterior
-    mean. This function uses that fact to generate the initial conditions
-    for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
-    options) of the restarts are generated by learning the hypervolume maximizing sets
-    under the current posterior mean, where each hypervolume maximizing set is
-    obtained from maximizing the hypervolume from a different starting point. Given
-    a hypervolume maximizing set, the `q` candidate points are selected using to the
-    standard initialization strategy in `gen_batch_initial_conditions`, with the fixed
-    hypervolume maximizing set. The remaining `frac_random` restarts fantasy points
-    as well as all `q` candidate points are chosen according to the standard
-    initialization strategy in `gen_batch_initial_conditions`.
-
-    Args:
-        acq_function: The qKnowledgeGradient instance to be optimized.
-        bounds: A `2 x d` tensor of lower and upper bounds for each column of
-            task features.
-        q: The number of candidates to consider.
-        num_restarts: The number of starting points for multistart acquisition
-            function optimization.
-        raw_samples: The number of raw samples to consider in the initialization
-            heuristic.
-        fixed_features: A map `{feature_index: value}` for features that
-            should be fixed to a particular value during generation.
-        options: Options for initial condition generation. These contain all
-            settings for the standard heuristic initialization from
-            `gen_batch_initial_conditions`. In addition, they contain
-            `frac_random` (the fraction of fully random fantasy points),
-            `num_inner_restarts` and `raw_inner_samples` (the number of random
-            restarts and raw samples for solving the posterior objective
-            maximization problem, respectively) and `eta` (temperature parameter
-            for sampling heuristic from posterior objective maximizers).
-        inequality constraints: A list of tuples (indices, coefficients, rhs),
-            with each tuple encoding an inequality constraint of the form
-            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
-        equality constraints: A list of tuples (indices, coefficients, rhs),
-            with each tuple encoding an inequality constraint of the form
-            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
-
-    Returns:
-        A `num_restarts x q' x d` tensor that can be used as initial conditions
-        for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
-        of points (candidate points plus fantasy points).
-
-    Example:
-        >>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
-        >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
-        >>> Xinit = gen_one_shot_hvkg_initial_conditions(
-        >>>     qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
-        >>>     options={"frac_random": 0.25},
-        >>> )
+    This function generates initial conditions for optimizing one-shot HVKG using
+    the hypervolume maximizing set (of fixed size) under the posterior mean.
+    Intuitively, the hypervolume maximizing set of the fantasized posterior mean
+    will often be close to a hypervolume maximizing set under the current posterior
+    mean. This function uses that fact to generate the initial conditions
+    for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
+    options) of the restarts are generated by learning the hypervolume maximizing
+    sets under the current posterior mean, where each hypervolume maximizing set is
+    obtained from maximizing the hypervolume from a different starting point. Given
+    a hypervolume maximizing set, the `q` candidate points are selected using the
+    standard initialization strategy in `gen_batch_initial_conditions`, with the
+    fixed hypervolume maximizing set. For the remaining `frac_random` restarts, the
+    fantasy points as well as all `q` candidate points are chosen according to the
+    standard initialization strategy in `gen_batch_initial_conditions`.
+
+    Args:
+        acq_function: The qHypervolumeKnowledgeGradient instance to be optimized.
+        bounds: A `2 x d` tensor of lower and upper bounds for each column of
+            `X`.
+        q: The number of candidates to consider.
+        num_restarts: The number of starting points for multistart acquisition
+            function optimization.
+        raw_samples: The number of raw samples to consider in the initialization
+            heuristic.
+        fixed_features: A map `{feature_index: value}` for features that
+            should be fixed to a particular value during generation.
+        options: Options for initial condition generation. These contain all
+            settings for the standard heuristic initialization from
+            `gen_batch_initial_conditions`. In addition, they contain
+            `frac_random` (the fraction of fully random fantasy points),
+            `num_inner_restarts` and `raw_inner_samples` (the number of random
+            restarts and raw samples for solving the posterior objective
+            maximization problem, respectively) and `eta` (temperature parameter
+            for sampling heuristic from posterior objective maximizers).
+        inequality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an inequality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
+        equality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an equality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
+
+    Returns:
+        A `num_restarts x q' x d` tensor that can be used as initial conditions
+        for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
+        of points (candidate points plus fantasy points).
+
+    Example:
+        >>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
+        >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
+        >>> Xinit = gen_one_shot_hvkg_initial_conditions(
+        >>>     qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
+        >>>     options={"frac_random": 0.25},
+        >>> )
     """
     from botorch.optim.optimize import optimize_acqf

@@ -1136,6 +1220,7 @@ def sample_points_around_best(
     best_pct: float = 5.0,
     subset_sigma: float = 1e-1,
     prob_perturb: float | None = None,
+    best_X: Tensor | None = None,
 ) -> Tensor | None:
     r"""Find best points and sample nearby points.

@@ -1154,60 +1239,62 @@ def sample_points_around_best(
         An optional `n_discrete_points x d`-dim tensor containing the
         sampled points. This is None if no baseline points are found.
     """
-    X = get_X_baseline(acq_function=acq_function)
-    if X is None:
-        return
-    with torch.no_grad():
-        try:
-            posterior = acq_function.model.posterior(X)
-        except AttributeError:
-            warnings.warn(
-                "Failed to sample around previous best points.",
-                BotorchWarning,
-                stacklevel=3,
-            )
+    if best_X is None:
+        X = get_X_baseline(acq_function=acq_function)
+        if X is None:
             return
-        mean = posterior.mean
-        while mean.ndim > 2:
-            # take average over batch dims
-            mean = mean.mean(dim=0)
-        try:
-            f_pred = acq_function.objective(mean)
-        # Some acquisition functions do not have an objective
-        # and for some acquisition functions the objective is None
-        except (AttributeError, TypeError):
-            f_pred = mean
-        if hasattr(acq_function, "maximize"):
-            # make sure that the optimiztaion direction is set properly
-            if not acq_function.maximize:
-                f_pred = -f_pred
-        try:
-            # handle constraints for EHVI-based acquisition functions
-            constraints = acq_function.constraints
-            if constraints is not None:
-                neg_violation = -torch.stack(
-                    [c(mean).clamp_min(0.0) for c in constraints], dim=-1
-                ).sum(dim=-1)
-                feas = neg_violation == 0
-                if feas.any():
-                    f_pred[~feas] = float("-inf")
-                else:
-                    # set objective equal to negative violation
-                    f_pred = neg_violation
-        except AttributeError:
-            pass
-        if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
-            # multi-objective
-            # find pareto set
-            is_pareto = is_non_dominated(f_pred)
-            best_X = X[is_pareto]
-        else:
-            if f_pred.shape[-1] == 1:
-                f_pred = f_pred.squeeze(-1)
-            n_best = max(1, round(X.shape[0] * best_pct / 100))
-            # the view() is to ensure that best_idcs is not a scalar tensor
-            best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
-            best_X = X[best_idcs]
+        with torch.no_grad():
+            try:
+                posterior = acq_function.model.posterior(X)
+            except AttributeError:
+                warnings.warn(
+                    "Failed to sample around previous best points.",
+                    BotorchWarning,
+                    stacklevel=3,
+                )
+                return
+            mean = posterior.mean
+            while mean.ndim > 2:
+                # take average over batch dims
+                mean = mean.mean(dim=0)
+            try:
+                f_pred = acq_function.objective(mean)
+            # Some acquisition functions do not have an objective
+            # and for some acquisition functions the objective is None
+            except (AttributeError, TypeError):
+                f_pred = mean
+            if hasattr(acq_function, "maximize"):
+                # make sure that the optimization direction is set properly
+                if not acq_function.maximize:
+                    f_pred = -f_pred
+            try:
+                # handle constraints for EHVI-based acquisition functions
+                constraints = acq_function.constraints
+                if constraints is not None:
+                    neg_violation = -torch.stack(
+                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
+                    ).sum(dim=-1)
+                    feas = neg_violation == 0
+                    if feas.any():
+                        f_pred[~feas] = float("-inf")
+                    else:
+                        # set objective equal to negative violation
+                        f_pred = neg_violation
+            except AttributeError:
+                pass
+            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
+                # multi-objective
+                # find pareto set
+                is_pareto = is_non_dominated(f_pred)
+                best_X = X[is_pareto]
+            else:
+                if f_pred.shape[-1] == 1:
+                    f_pred = f_pred.squeeze(-1)
+                n_best = max(1, round(X.shape[0] * best_pct / 100))
+                # the view() is to ensure that best_idcs is not a scalar tensor
+                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
+                best_X = X[best_idcs]
+
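+    # sample points around `best_X` (either provided or computed above)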
     use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
     n_trunc_normal_points = (
         n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points