DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

take and take_along_dim in PyTorch

Buy Me a Coffee

*My post explains gather().

take() can get the 0D or more D tensor of zero or more elements from the 0D or more D tensor of zero or more elements, treating the input as if it were flattened to 1D and picking the elements at the given 0D or more D tensor of zero or more indices, as shown below:

*Memos:

  • take() can be used with torch or a tensor.
  • The 1st argument(input) with torch or using a tensor(Required-Type:tensor of int, float, complex or bool).
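  • The 2nd argument with torch or the 1st argument with a tensor is index(Required-Type:tensor of int). *The returned tensor has the same shape as index. *Indices refer to the flattened input, and negative indices count from its end (e.g. -1 is the last element).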
import torch

# take() indexes into input as if it were flattened to 1D:
#   [9, 5, 0, 6, 2, 7, 1, 3, 4, 8]
# Negative indices count from the end, and the result always has the
# same shape as index.
my_tensor = torch.tensor([[9, 5, 0, 6, 2],
                          [7, 1, 3, 4, 8]])

# 0D index -> 0D result
torch.take(input=my_tensor, index=torch.tensor(3))
my_tensor.take(index=torch.tensor(3))
torch.take(input=my_tensor, index=torch.tensor(-7))
# tensor(6)

# 1D index -> 1D result
torch.take(input=my_tensor, index=torch.tensor([3, 0, 7, 4]))
torch.take(input=my_tensor, index=torch.tensor([-7, -10, -3, -6]))
# tensor([6, 9, 3, 2])

# 2D index -> 2D result
torch.take(input=my_tensor, index=torch.tensor([[3, 0], [7, 4]]))
torch.take(input=my_tensor, index=torch.tensor([[-7, -10], [-3, -6]]))
# tensor([[6, 9], [3, 2]])

# 3D index -> 3D result (the same index is reused for every dtype below)
_index_3d = torch.tensor([[[3, 0], [7, 4]], [[8, 2], [3, 5]]])
torch.take(input=my_tensor, index=_index_3d)
torch.take(input=my_tensor,
           index=torch.tensor([[[-7, -10], [-3, -6]], [[-2, -8], [-7, -5]]]))
# tensor([[[6, 9], [3, 2]], [[4, 0], [6, 7]]])

# float input
my_tensor = torch.tensor([[9., 5., 0., 6., 2.],
                          [7., 1., 3., 4., 8.]])
torch.take(input=my_tensor, index=_index_3d)
# tensor([[[6., 9.], [3., 2.]], [[4., 0.], [6., 7.]]])

# complex input
my_tensor = torch.tensor([[9.+0.j, 5.+0.j, 0.+0.j, 6.+0.j, 2.+0.j],
                          [7.+0.j, 1.+0.j, 3.+0.j, 4.+0.j, 8.+0.j]])
torch.take(input=my_tensor, index=_index_3d)
# tensor([[[6.+0.j, 9.+0.j], [3.+0.j, 2.+0.j]],
#         [[4.+0.j, 0.+0.j], [6.+0.j, 7.+0.j]]])

# bool input
my_tensor = torch.tensor([[True, False, True, False, True],
                          [False, True, False, True, False]])
torch.take(input=my_tensor, index=_index_3d)
# tensor([[[False, True], [False, True]],
#         [[True, True], [False, False]]])
Enter fullscreen mode Exit fullscreen mode

take_along_dim() can get the 1D or more D tensor of zero or more elements using the 0D or more D tensor of zero or more indices from the 0D or more D tensor of zero or more elements as shown below:

*Memos:

  • take_along_dim() can be used with torch or a tensor.
  • The 1st argument(input) with torch or using a tensor(Required-Type:tensor of int, float, complex or bool).
  • The 2nd argument with torch or the 1st argument with a tensor is indices(Required-Type:tensor of int).
  • The 3rd argument with torch or the 2nd argument with a tensor is dim(Optional-Type:int): *Memos:
    • If dim is not set, both input and indices are treated as flattened, so a 1D tensor is returned.
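    • If dim is set, input and indices must have the same number of dimensions (D), and the returned tensor has that same D. *Dimensions other than dim are broadcast, e.g. indices with size 1 in a dimension are expanded to the input's size.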
  • There is out argument with torch(Optional-Default:None-Type:tensor): *Memos:
    • out= must be used.
    • My post explains out argument.
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2],
                          [7, 1, 3, 4, 8]])

# Without dim, input and indices are both flattened, so the result is 1D.
torch.take_along_dim(input=my_tensor, indices=torch.tensor(3))
my_tensor.take_along_dim(indices=torch.tensor(3))
# Equivalent gather() call. gather() requires dim and names its argument
# `index` (not `indices`), so flatten the input and gather along dim 0.
torch.gather(input=my_tensor.flatten(), dim=0, index=torch.tensor([3]))
# tensor([6])

torch.take_along_dim(input=my_tensor, indices=torch.tensor([3, 0, 7, 4]))
torch.take_along_dim(input=my_tensor, indices=torch.tensor([[3, 0], [7, 4]]))
# tensor([6, 9, 3, 2])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[[3, 0], [7, 4]], [[8, 2], [3, 5]]]))
# tensor([6, 9, 3, 2, 4, 0, 6, 7])

# With dim, indices pick along that dim; a size-1 dimension of indices
# (here the columns) is broadcast to the input's size.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0], [1], [0], [1]]), dim=0)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0], [1], [0], [1]]), dim=-2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8]])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0, 0, 0, 0, 0],
                                           [1, 1, 1, 1, 1],
                                           [0, 1, 0, 1, 0],
                                           [1, 0, 1, 0, 1]]), dim=0)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0, 0, 0, 0, 0],
                                           [1, 1, 1, 1, 1],
                                           [0, 1, 0, 1, 0],
                                           [1, 0, 1, 0, 1]]), dim=-2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 1, 0, 4, 2],
#         [7, 5, 3, 6, 8]])

# A single row of indices is broadcast to both input rows.
torch.take_along_dim(input=my_tensor, indices=torch.tensor([[3, 0, 4]]), dim=1)
torch.take_along_dim(input=my_tensor, indices=torch.tensor([[3, 0, 4]]), dim=-1)
# tensor([[6, 9, 2], [4, 7, 8]])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=1)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=-1)
# tensor([[6, 9, 2], [8, 1, 3]])

# float input
my_tensor = torch.tensor([[9., 5., 0., 6., 2.],
                          [7., 1., 3., 4., 8.]])
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=1)
# tensor([[6., 9., 2.], [8., 1., 3.]])

# complex input
my_tensor = torch.tensor([[9.+0.j, 5.+0.j, 0.+0.j, 6.+0.j, 2.+0.j],
                          [7.+0.j, 1.+0.j, 3.+0.j, 4.+0.j, 8.+0.j]])
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=1)
# tensor([[6.+0.j, 9.+0.j, 2.+0.j],
#         [8.+0.j, 1.+0.j, 3.+0.j]])

# bool input
my_tensor = torch.tensor([[True, False, True, False, True],
                          [False, True, False, True, False]])
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=1)
# tensor([[False, True, True],
#         [False, True, False]])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)