Super Kai (Kazuya Ito)

masked_select in PyTorch

masked_select() returns a 1D tensor of the zero or more elements selected by a boolean mask from a 0D or higher-dimensional tensor of zero or more elements, as shown below:

*Memos:

  • masked_select() can be used with torch or a tensor.
  • The 1st argument with torch, or the tensor itself when called on a tensor, is input (Required-Type: tensor of int, float, complex or bool). *It must be a 0D or higher-dimensional tensor of zero or more elements.
  • The 2nd argument with torch, or the 1st argument with a tensor, is mask (Required-Type: tensor of bool). *It must be a 0D or higher-dimensional tensor of zero or more boolean values whose shape is broadcastable with input.
  • There is an out argument with torch (Optional-Default: None-Type: tensor): *Memos:
    • It is keyword-only, so out= must be used.
    • My post explains the out argument.
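The out argument noted above can be sketched as follows. This is a minimal illustration, assuming a pre-allocated empty tensor whose dtype matches input; the names mask and result are only for this example.

```python
import torch

my_tensor = torch.tensor([8, -3, 0, 1, 5, -2])
mask = torch.tensor([False, True, True, False, True, False])

# Pre-allocate an empty tensor with the same dtype as the input.
result = torch.empty(0, dtype=my_tensor.dtype)

# out is passed by keyword; the selected elements are written into result.
torch.masked_select(input=my_tensor, mask=mask, out=result)

print(result.tolist())
# [-3, 0, 5]
```

The tensor method form, my_tensor.masked_select(mask=mask), has no out argument, so this only applies to the torch.masked_select() form.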
import torch

my_tensor = torch.tensor([8, -3, 0, 1, 5, -2])

torch.masked_select(input=my_tensor,
                    mask=torch.tensor([False, True, True, False, True, False]))
my_tensor.masked_select(
    mask=torch.tensor([False, True, True, False, True, False]))
# tensor([-3, 0, 5])

torch.masked_select(input=my_tensor, mask=torch.tensor(True))
torch.masked_select(input=my_tensor,
                    mask=torch.tensor([True, True, True, True, True, True]))
# tensor([8, -3, 0, 1, 5, -2])

torch.masked_select(input=my_tensor, mask=torch.tensor(False))
torch.masked_select(input=my_tensor,
                    mask=torch.tensor([False, False, False, False, False, False]))
# tensor([], dtype=torch.int64)

my_tensor = torch.tensor([[8, -3, 0], [1, 5, -2]])

torch.masked_select(input=my_tensor,
                    mask=torch.tensor([[False, True, True],
                                       [False, True, False]]))
# tensor([-3, 0, 5])

torch.masked_select(input=my_tensor, mask=torch.tensor(True))
# tensor([8, -3, 0, 1, 5, -2])

torch.masked_select(input=my_tensor, mask=torch.tensor(False))
# tensor([], dtype=torch.int64)

my_tensor = torch.tensor([[[8], [-3], [0]], [[1], [5], [-2]]])

torch.masked_select(input=my_tensor,
                    mask=torch.tensor([[[False], [True], [True]],
                                       [[False], [True], [False]]]))
# tensor([-3, 0, 5])

torch.masked_select(input=my_tensor, mask=torch.tensor(True))
# tensor([8, -3, 0, 1, 5, -2])

torch.masked_select(input=my_tensor, mask=torch.tensor(False))
# tensor([], dtype=torch.int64)

my_tensor = torch.tensor([[[8.], [-3.], [0.]], [[1.], [5.], [-2.]]])

torch.masked_select(input=my_tensor,
                    mask=torch.tensor([[[False], [True], [True]],
                                       [[False], [True], [False]]]))
# tensor([-3., 0., 5.])

my_tensor = torch.tensor([[[8.+0.j], [-3.+0.j], [0.+0.j]],
                          [[1.+0.j], [5.+0.j], [-2.+0.j]]])

torch.masked_select(input=my_tensor,
                    mask=torch.tensor([[[False], [True], [True]],
                                       [[False], [True], [False]]]))
# tensor([-3.+0.j, 0.+0.j, 5.+0.j])

my_tensor = torch.tensor([[[True], [False], [True]],
                          [[False], [True], [False]]])

torch.masked_select(input=my_tensor,
                    mask=torch.tensor([[[False], [True], [True]],
                                       [[False], [True], [False]]]))
# tensor([False, True, True])
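In practice the mask is often computed from the tensor rather than written by hand, and it does not need the same shape as input as long as the two are broadcastable. The sketch below illustrates both cases; the names positive_mask and column_mask are hypothetical and chosen only for this example.

```python
import torch

my_tensor = torch.tensor([[8, -3, 0], [1, 5, -2]])

# Build the bool mask from a comparison instead of listing it manually.
positive_mask = my_tensor > 0
positives = torch.masked_select(input=my_tensor, mask=positive_mask)
print(positives.tolist())
# [8, 1, 5]

# A 1D mask of shape (3,) is broadcast across both rows of the (2, 3) input,
# so it keeps the 1st and 3rd column of every row.
column_mask = torch.tensor([True, False, True])
by_column = torch.masked_select(input=my_tensor, mask=column_mask)
print(by_column.tolist())
# [8, 0, 1, -2]
```

In both cases the result is flattened to 1D in row-major order, so the original positions of the selected elements are not preserved.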