DEV Community

Super Kai (Kazuya Ito)

index_select in PyTorch

index_select() can get the 0D or more D tensor of zero or more elements selected along one dimension with zero or more indices from the 0D or more D tensor of zero or more elements, without removing that dimension, as shown below:
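To make the "without removing that dimension" point concrete, here is a minimal sketch comparing index_select() with plain integer indexing. It assumes a recent PyTorch version that accepts a 0D index (as the examples in this post do); the names row_by_int and row_by_select are just for illustration:

```python
import torch

my_tensor = torch.tensor([[8, -3, 0, 1], [5, -2, -1, 4]])

# Plain integer indexing removes dim 0: shape (2, 4) -> (4,)
row_by_int = my_tensor[1]

# index_select() keeps dim 0, now with size 1: shape (2, 4) -> (1, 4)
row_by_select = torch.index_select(input=my_tensor, dim=0, index=torch.tensor(1))

print(row_by_int.shape)     # torch.Size([4])
print(row_by_select.shape)  # torch.Size([1, 4])

# Squeezing the kept dimension gives the same values as integer indexing
print(torch.equal(row_by_select.squeeze(0), row_by_int))  # True
```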

*Memos:

  • index_select() can be used with torch or a tensor.
  • The 1st argument(input) with torch or using a tensor(Required-Type:tensor of int, float, complex or bool). *It must be a 0D or more D tensor of zero or more elements.
  • The 2nd argument with torch or the 1st argument with a tensor is dim(Required-Type:int).
  • The 3rd argument with torch or the 2nd argument with a tensor is index(Required-Type:tensor of int32 or int64). *It must be a 0D or 1D tensor of zero or more integers.
  • There is the out argument with torch(Optional-Default:None-Type:tensor): *Memos:
    • out= must be used as a keyword argument.
    • My post explains the out argument.
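The out argument and the integer requirement on index can be tried with a short sketch like the one below. The names index, result and float_index_rejected are illustrative, and since the exact error type and message text may vary between PyTorch versions, the error is only printed, not relied on:

```python
import torch

my_tensor = torch.tensor([8, -3, 0, 1, 5, -2, -1, 4])
index = torch.tensor([5, 2, 0, 7])  # int64 by default

# out= is passed as a keyword; the result is written into the preallocated tensor,
# whose dtype must match the input's dtype
result = torch.empty(4, dtype=my_tensor.dtype)
torch.index_select(input=my_tensor, dim=0, index=index, out=result)
print(result.tolist())  # [-2, 0, 8, 4]

# index must hold integers, so a float index tensor is rejected with an error
float_index_rejected = False
try:
    torch.index_select(input=my_tensor, dim=0, index=torch.tensor([5.0, 2.0]))
except (RuntimeError, IndexError) as error:
    float_index_rejected = True
    print(error)
```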
```python
import torch

my_tensor = torch.tensor([8, -3, 0, 1, 5, -2, -1, 4])

torch.index_select(input=my_tensor, dim=0, index=torch.tensor(4))
my_tensor.index_select(dim=0, index=torch.tensor(4))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor(4))
# tensor([5])

torch.index_select(input=my_tensor, dim=0, index=torch.tensor([5, 2, 0, 7]))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor([5, 2, 0, 7]))
# tensor([-2, 0, 8, 4])

my_tensor = torch.tensor([[8, -3, 0, 1], [5, -2, -1, 4]])

torch.index_select(input=my_tensor, dim=0, index=torch.tensor(1))
torch.index_select(input=my_tensor, dim=0, index=torch.tensor([1]))
torch.index_select(input=my_tensor, dim=-2, index=torch.tensor(1))
torch.index_select(input=my_tensor, dim=-2, index=torch.tensor([1]))
# tensor([[5, -2, -1, 4]])

torch.index_select(input=my_tensor, dim=0, index=torch.tensor([1, 0, 0, 1]))
torch.index_select(input=my_tensor, dim=-2, index=torch.tensor([1, 0, 0, 1]))
# tensor([[5, -2, -1, 4],
#         [8, -3, 0, 1],
#         [8, -3, 0, 1],
#         [5, -2, -1, 4]])

torch.index_select(input=my_tensor, dim=1, index=torch.tensor([3, 1, 2]))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor([3, 1, 2]))
# tensor([[1, -3, 0],
#         [4, -2, -1]])

my_tensor = torch.tensor([[[8, -3], [0, 1]],
                          [[5, -2], [-1, 4]]])

torch.index_select(input=my_tensor, dim=2, index=torch.tensor(1))
torch.index_select(input=my_tensor, dim=2, index=torch.tensor([1]))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor(1))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor([1]))
# tensor([[[-3], [1]],
#         [[-2], [4]]])

my_tensor = torch.tensor([[[8., -3.], [0., 1.]],
                          [[5., -2.], [-1., 4.]]])

torch.index_select(input=my_tensor, dim=2, index=torch.tensor(1))
# tensor([[[-3.], [1.]],
#         [[-2.], [4.]]])

my_tensor = torch.tensor([[[8.+0.j, -3.+0.j], [0.+0.j, 1.+0.j]],
                          [[5.+0.j, -2.+0.j], [-1.+0.j, 4.+0.j]]])

torch.index_select(input=my_tensor, dim=2, index=torch.tensor(1))
# tensor([[[-3.+0.j], [1.+0.j]],
#         [[-2.+0.j], [4.+0.j]]])

my_tensor = torch.tensor([[[True, False], [True, False]],
                          [[False, True], [False, True]]])

torch.index_select(input=my_tensor, dim=2, index=torch.tensor(1))
# tensor([[[False], [False]],
#         [[True], [True]]])
```
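Two more cases, as a sketch: an empty index (the "zero indices" case), which still has to be an integer tensor because torch.tensor([]) on its own is float32, and a check that, for a 1D index, index_select() along dim=1 gives the same result as advanced indexing with my_tensor[:, index]. The names empty_index, by_select and by_indexing are illustrative:

```python
import torch

my_tensor = torch.tensor([[8, -3, 0, 1], [5, -2, -1, 4]])

# Zero indices: dim 0 is kept but now has size 0
empty_index = torch.tensor([], dtype=torch.int64)
empty_rows = torch.index_select(input=my_tensor, dim=0, index=empty_index)
print(empty_rows.shape)  # torch.Size([0, 4])

# With a 1D index, selecting along dim=1 matches advanced indexing on that dim
index = torch.tensor([3, 1, 2])
by_select = torch.index_select(input=my_tensor, dim=1, index=index)
by_indexing = my_tensor[:, index]
print(torch.equal(by_select, by_indexing))  # True
print(by_select.tolist())  # [[1, -3, 0], [4, -2, -1]]
```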