Description
🐛 Describe the bug
Admittedly this is perhaps an unconventional use, but I'm using `gaussian_blur` in my model to blur attention maps, and I want `sigma` to be a learnable parameter.
It would work, except for this function:
```python
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
```
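Inside it, `x` is created with `torch.linspace` without a `device` argument, so it stays on the CPU and is never moved to the device that `sigma` is on.
I believe this is the case in all torchvision versions.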
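The root cause can be reproduced in isolation. The sketch below copies the relevant lines of `_get_gaussian_kernel1d` (as shown in the traceback below) into a hypothetical helper, `kernel1d_as_written`; it is not the torchvision function itself. It shows that `torch.linspace` called without `device=` allocates on the default device (normally the CPU), regardless of where `sigma` lives.

```python
import torch


def kernel1d_as_written(kernel_size: int, sigma: torch.Tensor) -> torch.Tensor:
    # Hypothetical copy of the torchvision logic shown in the traceback.
    ksize_half = (kernel_size - 1) * 0.5
    # No device= argument: x lands on the default device (normally the CPU).
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


ksize_half = (15 - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=15)
print(x.device)  # cpu, unless a different default device has been configured

if torch.cuda.is_available():
    s = torch.tensor(1.0, device="cuda", requires_grad=True)
    try:
        kernel1d_as_written(15, s)
    except RuntimeError as err:
        # Expected to be the same device-mismatch error as in the report.
        print(err)
```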
WORKS:

```python
import torch
from torchvision.transforms.functional import gaussian_blur
k = 15
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True)

blurred = gaussian_blur(torch.randn(1, 3, 256, 256), k, [s])
blurred.mean().backward()
print(s.grad)
# prints: tensor(-4.6193e-05)
```

DOES NOT:
```python
import torch
from torchvision.transforms.functional import gaussian_blur
k = 15
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')

blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
blurred.mean().backward()
print(s.grad)
```

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
D:\Temp\ipykernel_39000\3525683463.py in <module>
      4 s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')
      5 
----> 6 blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
      7 blurred.mean().backward()
      8 print(s.grad)

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\functional.py in gaussian_blur(img, kernel_size, sigma)
   1361         t_img = pil_to_tensor(img)
   1362 
-> 1363     output = F_t.gaussian_blur(t_img, kernel_size, sigma)
   1364 
   1365     if not isinstance(img, torch.Tensor):

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in gaussian_blur(img, kernel_size, sigma)
    749 
    750     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
--> 751     kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    752     kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
    753 

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel2d(kernel_size, sigma, dtype, device)
    736     kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
    737 ) -> Tensor:
--> 738     kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    739     kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    740     kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel1d(kernel_size, sigma)
    727 
    728     x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
--> 729     pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    730     kernel1d = pdf / pdf.sum()
    731 

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

I don't know the convention here (for example, whether `device` should be passed in), but I believe the simplest fix would be to change:
```python
# torchvision/transforms/_functional_tensor.py, line 728
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
```

to:

```python
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size).to(sigma.device)
```
Actually, that won't work when `sigma` is a plain float, because a float has no `.device` attribute. So I guess there could be a check for whether `sigma` is a float or a float tensor.
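A minimal sketch of that float-or-tensor check, assuming the kernel should be built on `sigma`'s device when `sigma` is a tensor and on the CPU otherwise (which matches the current behavior for plain floats). `get_gaussian_kernel1d` and `separable_gaussian_blur` are hypothetical names for illustration, not torchvision APIs. The blur follows the same outline as the traceback (outer product of 1D kernels, one copy per channel, reflect padding, grouped `conv2d`), so gradients can flow back to a learnable `sigma` on any device.

```python
from typing import Union

import torch
import torch.nn.functional as F
from torch import Tensor


def get_gaussian_kernel1d(kernel_size: int, sigma: Union[float, Tensor]) -> Tensor:
    """Normalized 1D Gaussian kernel, built on sigma's device if sigma is a tensor."""
    ksize_half = (kernel_size - 1) * 0.5
    # Plain floats have no .device, so fall back to the CPU for them.
    device = sigma.device if isinstance(sigma, Tensor) else torch.device("cpu")
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=device)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


def separable_gaussian_blur(img: Tensor, kernel_size: int, sigma: Union[float, Tensor]) -> Tensor:
    """Blur a float (N, C, H, W) tensor with an isotropic Gaussian; differentiable in sigma."""
    k1d = get_gaussian_kernel1d(kernel_size, sigma).to(img.device, dtype=img.dtype)
    # Outer product gives the 2D kernel; expand to one copy per channel (depthwise conv).
    k2d = torch.outer(k1d, k1d)
    channels = img.shape[-3]
    weight = k2d.expand(channels, 1, kernel_size, kernel_size)
    pad = kernel_size // 2
    padded = F.pad(img, [pad, pad, pad, pad], mode="reflect")
    return F.conv2d(padded, weight, groups=channels)


device = "cuda" if torch.cuda.is_available() else "cpu"
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad=True, device=device)
blurred = separable_gaussian_blur(torch.randn(1, 3, 256, 256, device=device), 15, s)
blurred.mean().backward()
print(s.grad)  # a finite gradient living on the same device as s
```

The `isinstance` check keeps the float path unchanged, so existing callers that pass plain floats are unaffected; only tensor `sigma` values change where the sample grid is allocated.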
Versions
[pip3] efficientunet-pytorch==0.0.6
[pip3] ema-pytorch==0.4.5
[pip3] flake8==6.0.0
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.3
[pip3] numpydoc==1.4.0
[pip3] pytorch-msssim==1.0.0
[pip3] siren-pytorch==0.1.7
[pip3] torch==2.2.2+cu118
[pip3] torch-cluster==1.6.0+pt113cu116
[pip3] torch_geometric==2.4.0
[pip3] torch-scatter==2.1.0+pt113cu116
[pip3] torch-sparse==0.6.16+pt113cu116
[pip3] torch-spline-conv==1.2.1+pt113cu116
[pip3] torch-tools==0.1.5
[pip3] torchaudio==2.2.2+cu118
[pip3] torchbearer==0.5.3
[pip3] torchmeta==1.8.0
[pip3] torchvision==0.17.2+cu118
[pip3] uformer-pytorch==0.0.8
[pip3] vit-pytorch==1.5.0
[conda] blas 1.0 mkl
[conda] efficientunet-pytorch 0.0.6 pypi_0 pypi
[conda] ema-pytorch 0.4.5 pypi_0 pypi
[conda] mkl 2021.4.0 haa95532_640
[conda] mkl-service 2.4.0 py39h2bbff1b_0
[conda] mkl_fft 1.3.1 py39h277e83a_0
[conda] mkl_random 1.2.2 py39hf11a4ad_0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] numpydoc 1.4.0 py39haa95532_0
[conda] pytorch-cuda 11.6 h867d48c_1 pytorch
[conda] pytorch-msssim 1.0.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] siren-pytorch 0.1.7 pypi_0 pypi
[conda] torch 1.13.0 pypi_0 pypi
[conda] torch-cluster 1.6.0+pt113cu116 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torch-scatter 2.1.0+pt113cu116 pypi_0 pypi
[conda] torch-sparse 0.6.16+pt113cu116 pypi_0 pypi
[conda] torch-spline-conv 1.2.1+pt113cu116 pypi_0 pypi
[conda] torch-tools 0.1.5 pypi_0 pypi
[conda] torchaudio 0.9.1 pypi_0 pypi
[conda] torchbearer 0.5.3 pypi_0 pypi
[conda] torchmeta 1.8.0 pypi_0 pypi
[conda] torchvision 0.17.2+cu118 pypi_0 pypi
[conda] uformer-pytorch 0.0.8 pypi_0 pypi
[conda] vit-pytorch 1.5.0 pypi_0 pypi