Can't use gaussian_blur if sigma is a tensor on gpu #8401

@Xact-sniper

🐛 Describe the bug

Admittedly this is perhaps an unconventional use, but I'm using gaussian_blur in my model to blur attention maps, and I want sigma to be a parameter.

It would work, except for this function:

def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:

Inside it, x is created by torch.linspace on the default device (CPU) and is not moved to the device that sigma is on.

I believe it is like this in all torchvision versions.

WORKS:

    import torch
    from torchvision.transforms.functional import gaussian_blur

    k = 15
    s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True)
    blurred = gaussian_blur(torch.randn(1, 3, 256, 256), k, [s])
    blurred.mean().backward()
    print(s.grad)
    >>> tensor(-4.6193e-05)

DOES NOT:

    import torch
    from torchvision.transforms.functional import gaussian_blur

    k = 15
    s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')
    blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
    blurred.mean().backward()
    print(s.grad)

    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    D:\Temp\ipykernel_39000\3525683463.py in <module>
          4 s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')
          5
    ----> 6 blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
          7 blurred.mean().backward()
          8 print(s.grad)

    s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\functional.py in gaussian_blur(img, kernel_size, sigma)
       1361         t_img = pil_to_tensor(img)
       1362
    -> 1363     output = F_t.gaussian_blur(t_img, kernel_size, sigma)
       1364
       1365     if not isinstance(img, torch.Tensor):

    s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in gaussian_blur(img, kernel_size, sigma)
        749
        750     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    --> 751     kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
        752     kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
        753

    s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel2d(kernel_size, sigma, dtype, device)
        736     kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
        737 ) -> Tensor:
    --> 738     kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
        739     kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
        740     kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])

    s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel1d(kernel_size, sigma)
        727
        728     x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    --> 729     pdf = torch.exp(-0.5 * (x / sigma).pow(2))
        730     kernel1d = pdf / pdf.sum()
        731

    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I don't know what the convention is (for example, whether device should be passed in), but I believe the simplest fix would be to change:
728 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
to:
728 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size).to(sigma.device)

Actually, that won't work when sigma is just a float, since a float has no .device attribute. So I guess there could be a check for whether it's a float or a float tensor.
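
A minimal sketch of what that check could look like, written as a standalone function so it can be tried outside torchvision. The name gaussian_kernel1d is hypothetical, and this is only one possible shape for the fix, not what torchvision does or necessarily should do. It assumes that a tensor sigma is a floating-point scalar tensor.

```python
from typing import Union

import torch
from torch import Tensor


def gaussian_kernel1d(kernel_size: int, sigma: Union[float, Tensor]) -> Tensor:
    """Hypothetical variant of _get_gaussian_kernel1d that accepts a tensor sigma."""
    ksize_half = (kernel_size - 1) * 0.5
    if isinstance(sigma, Tensor):
        # Assumption: sigma is a floating-point scalar tensor. Create x on the
        # same device (and dtype) so that x / sigma does not mix devices.
        x = torch.linspace(
            -ksize_half, ksize_half, steps=kernel_size,
            device=sigma.device, dtype=sigma.dtype,
        )
    else:
        # Plain float: keep the existing behaviour (default device and dtype).
        x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


# Usage: runs on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
s = torch.tensor(0.8, requires_grad=True, device=device)
kernel = gaussian_kernel1d(15, s)
# Backpropagate through the centre tap; the kernel always sums to 1,
# so the gradient of kernel.sum() with respect to sigma would be zero.
kernel[7].backward()
print(kernel.device, s.grad)
```

The existing .to(device, dtype=dtype) call in _get_gaussian_kernel2d would still move the result to the image's device afterwards, so a float sigma would behave as before.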

Versions

[pip3] efficientunet-pytorch==0.0.6
[pip3] ema-pytorch==0.4.5
[pip3] flake8==6.0.0
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.3
[pip3] numpydoc==1.4.0
[pip3] pytorch-msssim==1.0.0
[pip3] siren-pytorch==0.1.7
[pip3] torch==2.2.2+cu118
[pip3] torch-cluster==1.6.0+pt113cu116
[pip3] torch_geometric==2.4.0
[pip3] torch-scatter==2.1.0+pt113cu116
[pip3] torch-sparse==0.6.16+pt113cu116
[pip3] torch-spline-conv==1.2.1+pt113cu116
[pip3] torch-tools==0.1.5
[pip3] torchaudio==2.2.2+cu118
[pip3] torchbearer==0.5.3
[pip3] torchmeta==1.8.0
[pip3] torchvision==0.17.2+cu118
[pip3] uformer-pytorch==0.0.8
[pip3] vit-pytorch==1.5.0
[conda] blas 1.0 mkl
[conda] efficientunet-pytorch 0.0.6 pypi_0 pypi
[conda] ema-pytorch 0.4.5 pypi_0 pypi
[conda] mkl 2021.4.0 haa95532_640
[conda] mkl-service 2.4.0 py39h2bbff1b_0
[conda] mkl_fft 1.3.1 py39h277e83a_0
[conda] mkl_random 1.2.2 py39hf11a4ad_0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] numpydoc 1.4.0 py39haa95532_0
[conda] pytorch-cuda 11.6 h867d48c_1 pytorch
[conda] pytorch-msssim 1.0.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] siren-pytorch 0.1.7 pypi_0 pypi
[conda] torch 1.13.0 pypi_0 pypi
[conda] torch-cluster 1.6.0+pt113cu116 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torch-scatter 2.1.0+pt113cu116 pypi_0 pypi
[conda] torch-sparse 0.6.16+pt113cu116 pypi_0 pypi
[conda] torch-spline-conv 1.2.1+pt113cu116 pypi_0 pypi
[conda] torch-tools 0.1.5 pypi_0 pypi
[conda] torchaudio 0.9.1 pypi_0 pypi
[conda] torchbearer 0.5.3 pypi_0 pypi
[conda] torchmeta 1.8.0 pypi_0 pypi
[conda] torchvision 0.17.2+cu118 pypi_0 pypi
[conda] uformer-pytorch 0.0.8 pypi_0 pypi
[conda] vit-pytorch 1.5.0 pypi_0 pypi
