DEV Community

Super Kai (Kazuya Ito)

unsqueeze in PyTorch


*My post explains squeeze().

unsqueeze() returns a 1D or higher tensor of zero or more elements with an additional dimension of size 1, taken from a 0D or higher tensor of zero or more elements, as shown below:
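A minimal sketch of the "0D or more D" and "zero or more elements" cases, assuming a recent PyTorch install. It only prints shapes: a 0D tensor becomes 1D, an empty tensor stays empty but gains a size-1 axis, and the number of elements never changes.

```python
import torch

# 0D tensor (a scalar): shape () -> (1,)
scalar = torch.tensor(5)
print(scalar.shape, scalar.unsqueeze(dim=0).shape)
# torch.Size([]) torch.Size([1])

# 1D tensor with zero elements: shape (0,) -> (1, 0) or (0, 1)
empty = torch.tensor([])
print(empty.unsqueeze(dim=0).shape, empty.unsqueeze(dim=1).shape)
# torch.Size([1, 0]) torch.Size([0, 1])

# The element count stays the same; only a size-1 axis is inserted.
x = torch.tensor([[0, 1, 2], [3, 4, 5]])
y = x.unsqueeze(dim=1)
print(x.shape, y.shape, x.numel() == y.numel())
# torch.Size([2, 3]) torch.Size([2, 1, 3]) True
```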

*Memos:

  • unsqueeze() can be used with torch or a tensor.
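  • The 1st argument (input) with torch, or the tensor itself when called as a tensor method (Required-Type: tensor of int, float, complex or bool).
  • The 2nd argument with torch or the 1st argument with a tensor is dim (Required-Type: int). *It adds a dimension of size 1 at the specified position. *Its valid range is [-input.dim() - 1, input.dim()], and a negative dim is treated as dim + input.dim() + 1.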
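To make the dim rule concrete, here is a small sketch using the 4x3 tensor from this post. The helper normalize_unsqueeze_dim() is hypothetical (it is not part of PyTorch); it just mirrors the documented mapping of a negative dim to dim + input.dim() + 1 so the function form and the method form can be compared for every valid dim.

```python
import torch

def normalize_unsqueeze_dim(dim: int, ndim: int) -> int:
    """Hypothetical helper (not part of PyTorch): map a possibly negative
    unsqueeze dim to its non-negative position, or raise if out of range."""
    if not -ndim - 1 <= dim <= ndim:
        raise IndexError(f"dim {dim} out of range [{-ndim - 1}, {ndim}]")
    return dim + ndim + 1 if dim < 0 else dim

my_tensor = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [10, 11, 12]])
ndim = my_tensor.dim()  # 2, so the valid dims are -3, -2, -1, 0, 1, 2

for dim in range(-ndim - 1, ndim + 1):
    pos = normalize_unsqueeze_dim(dim, ndim)
    a = torch.unsqueeze(input=my_tensor, dim=dim)  # function form
    b = my_tensor.unsqueeze(dim=pos)               # method form, non-negative dim
    assert torch.equal(a, b)
    print(dim, "->", pos, tuple(a.shape))

# Anything outside [-3, 2] is rejected by PyTorch itself.
try:
    torch.unsqueeze(input=my_tensor, dim=3)
except IndexError as err:
    print("IndexError:", err)
```

Each negative dim lands on the same position as its non-negative counterpart, which is why dim=0 and dim=-3, dim=1 and dim=-2, and dim=2 and dim=-1 give identical results for a 2D tensor.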
```python
import torch

my_tensor = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [10, 11, 12]])

torch.unsqueeze(input=my_tensor, dim=0)
my_tensor.unsqueeze(dim=0)
torch.unsqueeze(input=my_tensor, dim=-3)
# tensor([[[0, 1, 2],
#          [3, 4, 5],
#          [6, 7, 8],
#          [10, 11, 12]]])

torch.unsqueeze(input=my_tensor, dim=1)
torch.unsqueeze(input=my_tensor, dim=-2)
# tensor([[[0, 1, 2]],
#         [[3, 4, 5]],
#         [[6, 7, 8]],
#         [[10, 11, 12]]])

torch.unsqueeze(input=my_tensor, dim=2)
torch.unsqueeze(input=my_tensor, dim=-1)
# tensor([[[0], [1], [2]],
#         [[3], [4], [5]],
#         [[6], [7], [8]],
#         [[10], [11], [12]]])

# dim=3 is out of range for this 2D tensor (valid range: [-3, 2]):
# torch.unsqueeze(input=my_tensor, dim=3)
# IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

my_tensor = torch.tensor([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.], [10., 11., 12.]])

torch.unsqueeze(input=my_tensor, dim=0)
# tensor([[[0., 1., 2.],
#          [3., 4., 5.],
#          [6., 7., 8.],
#          [10., 11., 12.]]])

my_tensor = torch.tensor([[0.+0.j, 1.+0.j, 2.+0.j],
                          [3.+0.j, 4.+0.j, 5.+0.j],
                          [6.+0.j, 7.+0.j, 8.+0.j],
                          [10.+0.j, 11.+0.j, 12.+0.j]])

torch.unsqueeze(input=my_tensor, dim=0)
# tensor([[[0.+0.j, 1.+0.j, 2.+0.j],
#          [3.+0.j, 4.+0.j, 5.+0.j],
#          [6.+0.j, 7.+0.j, 8.+0.j],
#          [10.+0.j, 11.+0.j, 12.+0.j]]])

my_tensor = torch.tensor([[True, False, True],
                          [False, True, False],
                          [True, False, True],
                          [False, True, False]])

torch.unsqueeze(input=my_tensor, dim=0)
# tensor([[[True, False, True],
#          [False, True, False],
#          [True, False, True],
#          [False, True, False]]])
```
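As a quick check on the int, float, complex and bool examples above, the sketch below rebuilds the same four tensors and prints dtype and shape before and after unsqueeze(dim=0). It also checks one fact stated in the PyTorch docs: the returned tensor shares its data with the input (it is a view), so writing through the result changes the original.

```python
import torch

rows = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [10, 11, 12]]
tensors = {
    "int": torch.tensor(rows),
    "float": torch.tensor(rows, dtype=torch.float32),
    "complex": torch.tensor(rows, dtype=torch.complex64),
    "bool": torch.tensor([[True, False, True], [False, True, False],
                          [True, False, True], [False, True, False]]),
}

# dtype is preserved; only the shape changes from (4, 3) to (1, 4, 3).
for name, t in tensors.items():
    u = torch.unsqueeze(input=t, dim=0)
    print(name, t.dtype, tuple(t.shape), "->", u.dtype, tuple(u.shape))

# The result is a view: same storage, so a write through u is visible in t.
t = tensors["float"]
u = t.unsqueeze(dim=0)
print(u.data_ptr() == t.data_ptr())  # True
u[0, 0, 0] = 100.0
print(t[0, 0].item())  # 100.0
```

If an independent copy is needed, calling .clone() on the result avoids the shared storage.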