Commit c6a4a2a

saitcakmak authored and meta-codesync[bot] committed
Move HeterogeneousMTGP to OSS BoTorch (#3073)
Summary:
Pull Request resolved: #3073

Heterogeneous MTGP model as introduced in

.. [Deshwal2024Heterogeneous]
    A. Deshwal, S. Cakmak, Y. Xia, and D. Eriksson.
    Sample-Efficient Bayesian Optimization with Transfer Learning for
    Heterogeneous Search Spaces. AutoML Conference, 2024.

Moving to OSS, so that we can dispatch to it in Ax modular BoTorch generator.

Reviewed By: sdaulton

Differential Revision: D86422636

fbshipit-source-id: 75095f0e2bb353533f239e49d143f59ae6bf3740
1 parent 40bb460 commit c6a4a2a

File tree

5 files changed: +1004 −0 lines changed

Lines changed: 332 additions & 0 deletions
@@ -0,0 +1,332 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

r"""
Multi-Task GP model designed to operate on tasks from different search spaces.

References:

.. [Deshwal2024Heterogeneous]
    A. Deshwal, S. Cakmak, Y. Xia, and D. Eriksson.
    Sample-Efficient Bayesian Optimization with Transfer Learning for
    Heterogeneous Search Spaces. AutoML Conference, 2024.
"""

from itertools import chain
from typing import Any

import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import UnsupportedError
from botorch.models.kernels.heterogeneous_multitask import MultiTaskConditionalKernel
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils.gpytorch_modules import (
    get_gaussian_likelihood_with_gamma_prior,
)
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.transformed import TransformedPosterior
from botorch.utils.datasets import MultiTaskDataset
from torch import Tensor


class HeterogeneousMTGP(MultiTaskGP):
    """A multi-task GP model designed to operate on tasks from
    different search spaces. This model uses `MultiTaskConditionalKernel`.

    This model was introduced in [Deshwal2024Heterogeneous]_.

    * The model is designed to work with a `MultiTaskDataset` that contains
      datasets with different features.
    * It uses a helper to embed the `X` coming from the sub-spaces into the
      full-feature space (+ task feature) before passing them down to the
      base `MultiTaskGP`.
    * The same helper is used in the `posterior` method to embed the `X` from
      the target task into the full dimensional space before evaluating the
      `posterior` method of the base class.
    * This model also overwrites the `_split_inputs` method. Instead of
      `x_basic`, we return the `X` with the task feature included, since this
      is used by the `MultiTaskConditionalKernel` to identify the active
      dimensions of, and the kernels to evaluate for, the given input.
    """

    def __init__(
        self,
        train_Xs: list[Tensor],
        train_Ys: list[Tensor],
        train_Yvars: list[Tensor] | None,
        feature_indices: list[list[int]],
        full_feature_dim: int,
        rank: int | None = None,
        use_saas_prior: bool = True,
        use_combinatorial_kernel: bool = True,
        all_tasks: list[int] | None = None,
        input_transform: InputTransform | None = None,
        outcome_transform: OutcomeTransform | None = None,
        validate_task_values: bool = True,
    ) -> None:
        """Construct a heterogeneous multi-task GP model from lists of inputs
        corresponding to each task.

        NOTE: This model assumes that task 0 is the output / target task.
        It will only produce predictions for task 0.

        Args:
            train_Xs: A list of tensors of shape `(n_i x d_i)` where `d_i` is the
                dimensionality of the input features for task i.
                NOTE: These should not include the task feature!
            train_Ys: A list of tensors of shape `(n_i x 1)` containing the
                observations for the corresponding task.
            train_Yvars: An optional list of tensors of shape `(n_i x 1)` containing
                the observation variances for the corresponding task.
            feature_indices: A list of lists of integers specifying the indices
                mapping the features from a given task to the full tensor of
                features. The `i`th element of the list should contain `d_i`
                integers.
            full_feature_dim: The total number of features across all tasks. This
                does not include the task feature dimension.
            rank: The rank of the cross-task covariance matrix.
            use_saas_prior: Whether to use the SAAS prior for base kernels of the
                `MultiTaskConditionalKernel`.
            use_combinatorial_kernel: Whether to use a combinatorial kernel over the
                binary embedding of task features in `MultiTaskConditionalKernel`.
            all_tasks: By default, multi-task GPs infer the list of all tasks from
                the task features in `train_X`. This is an experimental feature that
                enables creation of multi-task GPs with tasks that don't appear in
                the training data. Note that when a task is not observed, the
                corresponding task covariance will heavily depend on random
                initialization and may behave unexpectedly.
            input_transform: An input transform that is applied in the model's
                forward pass. The transform should be compatible with the inputs
                from the full feature space with the task feature appended.
            outcome_transform: An outcome transform that is applied to the
                training data during instantiation and to the posterior during
                inference (that is, the `Posterior` obtained by calling
                `.posterior` on the model will be on the original scale).
            validate_task_values: If True, validate that the task values supplied
                in the input are expected task values. If False, unexpected task
                values will be mapped to the first output task, if supplied.
        """
        self.full_feature_dim = full_feature_dim
        self.feature_indices = feature_indices
        full_X = torch.cat(
            [self.map_to_full_tensor(X=X, task_index=i) for i, X in enumerate(train_Xs)]
        )
        full_Y = torch.cat(train_Ys)
        full_Yvar = None if train_Yvars is None else torch.cat(train_Yvars)
        covar_module = MultiTaskConditionalKernel(
            feature_indices=feature_indices,
            use_saas_prior=use_saas_prior,
            use_combinatorial_kernel=use_combinatorial_kernel,
        )
        # The features that are forward passed through the kernel should include
        # the task dim.
        covar_module.active_dims = torch.arange(full_feature_dim + 1)
        likelihood = (
            None  # Constructed in MultiTaskGP.
            if full_Yvar is not None
            else get_gaussian_likelihood_with_gamma_prior()
        )
        super().__init__(
            train_X=full_X,
            train_Y=full_Y,
            task_feature=-1,
            train_Yvar=full_Yvar,
            mean_module=None,
            covar_module=covar_module,
            likelihood=likelihood,
            output_tasks=[0],
            rank=rank,
            all_tasks=all_tasks,
            input_transform=input_transform,
            outcome_transform=outcome_transform,
            validate_task_values=validate_task_values,
        )

    @classmethod
    def get_all_tasks(
        cls,
        train_X: Tensor,
        task_feature: int,
        output_tasks: list[int] | None = None,
    ) -> tuple[list[int], int, int]:
        (
            all_tasks_inferred,
            task_feature,
            num_non_task_features,
        ) = super().get_all_tasks(
            train_X=train_X, task_feature=task_feature, output_tasks=output_tasks
        )
        if 0 not in all_tasks_inferred:
            all_tasks_inferred = [0] + all_tasks_inferred
        return all_tasks_inferred, task_feature, num_non_task_features

    def map_to_full_tensor(self, X: Tensor, task_index: int) -> Tensor:
        """Map a tensor of task-specific features to the full tensor of features,
        utilizing the feature indices to map each feature to its corresponding
        position in the full tensor. Also append the task index as the last
        column. The columns of the full tensor that are not used by the given
        task will be filled with zeros.

        Args:
            X: A tensor of shape `(n x d_i)` where `d_i` is the number of features
                in the original task dataset.
            task_index: The index of the task whose features are being mapped.

        Returns:
            A tensor of shape `(n x (self.full_feature_dim + 1))` containing the
            mapped features.

        Example:
            >>> # Suppose full feature dim is 3 and the feature indices for
            >>> # task 5 are [2, 0].
            >>> X = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
            >>> X_full = self.map_to_full_tensor(X=X, task_index=5)
            >>> # X_full = torch.tensor([[2.0, 0.0, 1.0, 5.0], [4.0, 0.0, 3.0, 5.0]])
        """
        X_full = torch.zeros(
            *X.shape[:-1], self.full_feature_dim + 1, dtype=X.dtype, device=X.device
        )
        X_full[..., self.feature_indices[task_index]] = X
        X_full[..., -1] = task_index
        return X_full

    def posterior(
        self,
        X: Tensor,
        output_indices: list[int] | None = None,
        observation_noise: bool | Tensor = False,
        posterior_transform: PosteriorTransform | None = None,
        **kwargs: Any,
    ) -> GPyTorchPosterior | TransformedPosterior:
        r"""Computes the posterior for the target task at the provided points.

        Args:
            X: A tensor of shape `batch_shape x q x d_0(+1)`, where `d_0` is the
                dimension of the feature space for task 0 (not including task
                indices) and `q` is the number of points considered jointly.
            output_indices: Not supported. Must be `None` or `[0]`. The model will
                only produce predictions for the target task regardless of
                the value of `output_indices`.
            observation_noise: If True, add observation noise from the respective
                likelihoods. If a Tensor, specifies the observation noise levels
                to add.
            posterior_transform: An optional PosteriorTransform.

        Returns:
            A `GPyTorchPosterior` object, representing `batch_shape` joint
            distributions over `q` points.
        """
        if output_indices is not None and output_indices != [0]:
            raise UnsupportedError(
                "Heterogeneous MTGP does not support `output_indices`."
            )
        if X.shape[-1] == len(self.feature_indices[0]) + 1:
            # X contains the task feature.
            if (X[..., -1] != 0).any():
                raise UnsupportedError(
                    "Posterior can only be called for the target task."
                )
            X = X[..., :-1]
        X_full = self.map_to_full_tensor(X=X, task_index=0)
        return super().posterior(
            X=X_full,
            observation_noise=observation_noise,
            posterior_transform=posterior_transform,
            **kwargs,
        )

    def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        r"""Returns x itself along with a tensor containing the task indices only.

        NOTE: This differs from the base class implementation because it returns
        the full tensor in place of `x_basic`. This is because the multi-task
        conditional kernel utilizes the task feature for conditioning.

        Args:
            x: The full input tensor with trailing dimension of size
                `self.full_feature_dim + 1 + 1`.

        Returns:
            3-element tuple containing

            - The original tensor `x`.
            - A tensor of long data type containing the task indices.
            - An empty tensor with d=0. `_split_inputs` by default returns
              `X_before_index`, `task_indices`, and `X_after_index`, and thus
              has to return a 3-tuple.
        """
        task_idcs = x[..., self._task_feature : self._task_feature + 1].to(
            dtype=torch.long
        )
        return x, task_idcs, torch.zeros(x.shape[:-1] + (0,)).to(x)

    @classmethod
    # pyre-ignore [14] Inconsistent override is expected.
    def construct_inputs(
        cls,
        training_data: MultiTaskDataset,
        task_feature: int = -1,
        output_tasks: list[int] | None = None,
        rank: int | None = None,
        use_saas_prior: bool = True,
        use_combinatorial_kernel: bool = True,
    ) -> dict[str, Any]:
        r"""Construct `Model` keyword arguments from a given `MultiTaskDataset`.

        Args:
            training_data: A `MultiTaskDataset`.
            task_feature: Column index of embedded task indicator features.
                The only supported value is `-1`.
            output_tasks: A list of task indices for which to compute model
                outputs. The only supported value is `[0]`.
            rank: The rank of the cross-task covariance matrix.
            use_saas_prior: Whether to use the SAAS prior for base kernels of the
                `MultiTaskConditionalKernel`.
            use_combinatorial_kernel: Whether to use a combinatorial kernel over the
                binary embedding of task features in `MultiTaskConditionalKernel`.
        """
        if training_data.task_feature_index != -1:
            raise NotImplementedError(
                "Heterogeneous MTGP requires `task_feature_index` to be -1."
            )
        if task_feature != -1:
            raise NotImplementedError("Heterogeneous MTGP requires `task_feature=-1`.")
        if output_tasks is not None and output_tasks != [0]:
            raise NotImplementedError(
                "Heterogeneous MTGP currently only supports output_tasks=[0]. "
                "The target task will be given the task value of 0."
            )
        child_datasets = training_data.datasets.copy()
        target_dataset = child_datasets.pop(training_data.target_outcome_name)
        all_datasets = [target_dataset] + list(child_datasets.values())
        # We want all parameters to be in the same order, and include the full X.
        # Remove the task feature (last column) from the feature names.
        all_features = sorted(
            set(chain(*(ds.feature_names[:-1] for ds in all_datasets)))
        )
        # Get indices mapping the features from a given dataset to all features.
        feature_indices = [
            [all_features.index(fn) for fn in ds.feature_names[:-1]]
            for ds in all_datasets
        ]
        Xs = [ds.X[..., :-1] for ds in all_datasets]
        Ys = [ds.Y for ds in all_datasets]
        Yvars = (
            None if target_dataset.Yvar is None else [ds.Yvar for ds in all_datasets]
        )
        all_tasks = list(range(len(all_datasets)))
        return {
            "train_Xs": Xs,
            "train_Ys": Ys,
            "train_Yvars": Yvars,
            "feature_indices": feature_indices,
            "full_feature_dim": len(all_features),
            "rank": rank,
            "use_saas_prior": use_saas_prior,
            "use_combinatorial_kernel": use_combinatorial_kernel,
            "all_tasks": all_tasks,
        }
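
The sketch below shows one way this model might be wired up end to end. It is illustrative only and not part of the commit: the import path for `HeterogeneousMTGP`, the exact `SupervisedDataset` / `MultiTaskDataset` constructor signatures, the assumption that `MultiTaskDataset` accepts member datasets with different feature sets, and the MAP fit via `fit_gpytorch_mll` are all assumptions based on common BoTorch patterns rather than statements from this diff.

# Usage sketch (illustrative only; see the caveats above).
# from botorch.models.heterogeneous_mtgp import HeterogeneousMTGP
#   ^ hypothetical import path; the file path is not shown in this diff.
import torch

from botorch.fit import fit_gpytorch_mll
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from gpytorch.mlls import ExactMarginalLogLikelihood

# Target task (task 0) lives on ["x1", "x2"]; the auxiliary task lives on
# ["x2", "x3"]. Each dataset carries the task feature as its last column,
# matching the `task_feature_index=-1` requirement of `construct_inputs`.
target_ds = SupervisedDataset(
    X=torch.cat([torch.rand(10, 2), torch.zeros(10, 1)], dim=-1),
    Y=torch.randn(10, 1),
    feature_names=["x1", "x2", "task"],
    outcome_names=["target"],
)
aux_ds = SupervisedDataset(
    X=torch.cat([torch.rand(20, 2), torch.ones(20, 1)], dim=-1),
    Y=torch.randn(20, 1),
    feature_names=["x2", "x3", "task"],
    outcome_names=["aux"],
)
dataset = MultiTaskDataset(
    datasets=[target_ds, aux_ds],
    target_outcome_name="target",
    task_feature_index=-1,
)

# construct_inputs maps each task's features into the shared 3-dimensional
# space ("x1", "x2", "x3") and returns the constructor kwargs.
model_kwargs = HeterogeneousMTGP.construct_inputs(training_data=dataset)
model = HeterogeneousMTGP(**model_kwargs)

# MAP fitting; the SAAS prior on the base kernels enters as regularization.
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)

# The posterior is only available for the target task; X uses task 0's
# features (here, 2-dimensional points over ["x1", "x2"]).
posterior = model.posterior(torch.rand(5, 2))

In this sketch, `construct_inputs` would infer `feature_indices=[[0, 1], [1, 2]]` and `full_feature_dim=3` from the sorted union of feature names, which is exactly the embedding described in the class docstring.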

0 commit comments
