
Commit 60be953

mrlj-hash and kayween authored
Fix the shape bug in RQKernel when using deep GPs
There was a shape mismatch bug when using the rational quadratic kernel in deep GPs; see #2674. The root cause is that the `alpha` parameter was not unsqueezed into the correct shape.

* Unsqueeze `alpha`
* Revise comments for fix to RQKernel

---------

Co-authored-by: Kaiwen Wu <37524685+kayween@users.noreply.github.com>
Co-authored-by: Kaiwen Wu <kwwu2015@gmail.com>
1 parent d6f02cc commit 60be953
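
For context, here is a minimal sketch of the setup this commit fixes (not part of the commit itself; shapes follow the new test below): a kernel with its own `batch_shape` evaluated on inputs that carry an extra leading sample dimension, as deep GPs trained with variational inference produce.

# Illustrative reproduction, assuming gpytorch with this fix applied; shapes
# mirror test_extra_batch_dim below.
import torch
from gpytorch.kernels import RQKernel

# (samples=5, batch=4, n=3, d=2): the leading 5 is the extra deep-GP batch dimension.
x = torch.linspace(0.0, 1.0, 120).view(5, 4, 3, 2)

kernel = RQKernel(batch_shape=torch.Size([4]))
kernel.eval()

# Before this commit, `postprocess_rq` unsqueezed `alpha` once per extra dimension
# of the distance matrix, which misaligned the kernel's batch axis with the data's
# leading sample axis and triggered the shape mismatch reported in #2674.
# With the fix, this evaluates to a (5, 4, 3, 3) covariance matrix.
print(kernel(x).to_dense().shape)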

File tree

2 files changed: +39, −3 lines

gpytorch/kernels/rq_kernel.py

Lines changed: 8 additions & 3 deletions
@@ -60,17 +60,22 @@ def __init__(self, alpha_constraint: Optional[Interval] = None, **kwargs):
 
         self.register_constraint("raw_alpha", alpha_constraint)
 
-    def forward(self, x1, x2, diag=False, **params):
+    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
         def postprocess_rq(dist_mat):
             alpha = self.alpha
-            for _ in range(1, len(dist_mat.shape) - len(self.batch_shape)):
+
+            if not diag:
+                alpha = alpha.unsqueeze(-1)
+
+            if last_dim_is_batch:
                 alpha = alpha.unsqueeze(-1)
+
             return (1 + dist_mat.div(2 * alpha)).pow(-alpha)
 
         x1_ = x1.div(self.lengthscale)
         x2_ = x2.div(self.lengthscale)
         return postprocess_rq(
-            self.covar_dist(x1_, x2_, square_dist=True, diag=diag, **params),
+            self.covar_dist(x1_, x2_, square_dist=True, diag=diag, last_dim_is_batch=last_dim_is_batch, **params),
         )
 
     @property
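
Shape-only illustration of why the two branches unsqueeze `alpha` the way they do (a sketch, not repository code; it assumes `alpha` carries the kernel's batch shape plus one trailing singleton dimension, here `(4, 1)`):

import torch

alpha = torch.rand(4, 1)  # assumed stored shape: batch_shape + (1,)

# diag=True: the diagonal has shape (..., batch, n), so (4, 1) already broadcasts.
dist_diag = torch.zeros(5, 4, 3)
print(((1 + dist_diag / (2 * alpha)) ** (-alpha)).shape)            # torch.Size([5, 4, 3])

# diag=False: the full distance matrix has shape (..., batch, n, m); one extra
# trailing singleton keeps alpha's batch axis aligned with the data's batch axis.
dist_full = torch.rand(5, 4, 3, 2)
alpha_mat = alpha.unsqueeze(-1)                                     # (4, 1, 1)
print(((1 + dist_full / (2 * alpha_mat)) ** (-alpha_mat)).shape)    # torch.Size([5, 4, 3, 2])

# last_dim_is_batch=True appends yet another dimension to the distance matrix,
# which is why that branch unsqueezes `alpha` a second time.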

test/kernels/test_rq_kernel.py

Lines changed: 31 additions & 0 deletions
@@ -222,6 +222,37 @@ def test_initialize_alpha(self):
         actual_value = torch.tensor(3.0).view_as(kernel.alpha)
         self.assertLess(torch.norm(kernel.alpha - actual_value), 1e-5)
 
+    def test_extra_batch_dim(self):
+        a = torch.linspace(0.0, 1.0, 120).view(5, 4, 3, 2)
+        b = torch.linspace(0.0, 1.0, 80).view(5, 4, 2, 2)
+
+        lengthscale = 0.1
+        alpha = torch.linspace(0.5, 1.0, 4)  # different `alpha` for each batch
+
+        # NOTE: Why do we pass down `batch_shape = (4,)` as opposed to `batch_shape = (5, 4)`?
+        # When training deep GPs with variational inference, there is an extra batch dimension added to the data, which
+        # corresponds to the likelihood samples. We are testing if `alpha` broadcasts properly in these cases.
+        kernel = RQKernel(batch_shape=torch.Size([4]))
+        kernel.initialize(lengthscale=lengthscale, alpha=alpha)
+        kernel.eval()
+
+        # First check the diagonal
+        res = kernel(a, diag=True).to_dense()
+        actual = torch.ones(5, 4, 3)
+        self.assertAllClose(res, actual)
+
+        # Now check the kernel matrix on two different inputs
+        res = kernel(a, b).to_dense()
+
+        scaled_a = a.div(lengthscale).unsqueeze(-2)
+        scaled_b = b.div(lengthscale).unsqueeze(-3)
+        dist = (scaled_a - scaled_b).square().sum(-1)
+
+        alpha = alpha.unsqueeze(-1).unsqueeze(-1)
+        actual = dist.div_(2 * alpha).add_(1).pow(-alpha)
+
+        self.assertAllClose(res, actual)
+
 
 if __name__ == "__main__":
     unittest.main()
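
As a reading aid for the expected values in this test (an illustrative scalar check, not part of the commit): the kernel computes (1 + ||x − z||^2 / (2 · alpha · lengthscale^2))^(−alpha), so on the diagonal the squared distance is zero and every entry is exactly one, which is why the diagonal is compared against `torch.ones(5, 4, 3)`.

import torch

lengthscale = 0.1
alpha = torch.tensor(0.75)
x = torch.tensor([0.30, 0.70])
z = torch.tensor([0.10, 0.90])

# Same reference computation as the test, reduced to a single pair of points.
sq_dist = ((x - z) / lengthscale).square().sum()
k_xz = (1 + sq_dist / (2 * alpha)) ** (-alpha)
k_xx = (1 + 0.0 / (2 * alpha)) ** (-alpha)   # zero distance on the diagonal -> 1.0

print(k_xz.item(), k_xx.item())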
