DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

repeat_interleave in PyTorch

Buy Me a Coffee

*Memos:

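repeat_interleave() can get a tensor of zero or more elements, each repeated consecutively (right after itself), from a 0D or more D tensor of zero or more elements as shown below. *Without dim, the input is flattened first, so the result is 1D; with dim, the elements are repeated along that dim and the other dims are kept: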

*Memos:

  • repeat_interleave() can be used with torch or a tensor.
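  • The 1st argument(input) with torch or the tensor itself with a tensor(Optional-Type:tensor of int, float, complex or bool). *If input is omitted with torch, each index i of repeats is repeated repeats[i] times instead, as shown in the 1st example below.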
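  • The 2nd argument with torch or the 1st argument with a tensor is repeats(Required-Type:int or tensor of int). *The tensor must be 0D or 1D, and a 1D tensor must have either 1 element or as many elements as input has along dim (along the flattened input if dim isn't given).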
  • The 3rd argument with torch or the 2nd argument with a tensor is dim(Optional-Default:None-Type:int). *If it's not given, input is flattened to 1D first.
  • There is an output_size argument with torch or a tensor(Optional-Default:None-Type:int): *Memos:
    • It's the total output size along the given dim (e.g. the sum of repeats). If it's given, it avoids the stream synchronization otherwise needed to calculate the output shape of the tensor.
    • output_size= must be used because output_size is keyword-only.
# Examples of torch.repeat_interleave() / Tensor.repeat_interleave().
# Every call in a group returns the same value, which is shown in the
# comment after the group. The bare expressions are meant to be evaluated
# in a REPL or notebook cell, where the result is echoed.
import torch

# ---- 1D input ----
my_tensor = torch.tensor([7, 4, 2])

# Only repeats is given: index i is repeated repeats[i] times.
torch.repeat_interleave(repeats=my_tensor)
# tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2])

torch.repeat_interleave(input=my_tensor, repeats=my_tensor)
my_tensor.repeat_interleave(repeats=my_tensor)
torch.repeat_interleave(input=my_tensor, repeats=my_tensor, dim=0)
torch.repeat_interleave(input=my_tensor, repeats=my_tensor, dim=-1)
# tensor([7, 7, 7, 7, 7, 7, 7, 4, 4, 4, 4, 2, 2])

torch.repeat_interleave(input=my_tensor, repeats=torch.tensor([2, 1, 4]))
torch.repeat_interleave(input=my_tensor, repeats=torch.tensor([2, 1, 4]), dim=0)
torch.repeat_interleave(input=my_tensor, repeats=torch.tensor([2, 1, 4]), dim=-1)
# tensor([7, 7, 4, 2, 2, 2, 2])

# A 0D tensor or a 1-element 1D tensor is broadcast to every element.
torch.repeat_interleave(input=my_tensor, repeats=torch.tensor(2))
torch.repeat_interleave(input=my_tensor, repeats=torch.tensor(2), dim=0)
torch.repeat_interleave(input=my_tensor, repeats=torch.tensor(2), dim=-1)
torch.repeat_interleave(input=my_tensor, repeats=torch.tensor([2]))
torch.repeat_interleave(input=my_tensor, repeats=torch.tensor([2]), dim=0)
torch.repeat_interleave(input=my_tensor, repeats=torch.tensor([2]), dim=-1)
# tensor([7, 7, 4, 4, 2, 2])

torch.repeat_interleave(input=my_tensor, repeats=0)
torch.repeat_interleave(input=my_tensor, repeats=0, dim=0)
torch.repeat_interleave(input=my_tensor, repeats=0, dim=-1)
# tensor([], dtype=torch.int64)

torch.repeat_interleave(input=my_tensor, repeats=1)
torch.repeat_interleave(input=my_tensor, repeats=1, dim=0)
torch.repeat_interleave(input=my_tensor, repeats=1, dim=-1)
# tensor([7, 4, 2])

torch.repeat_interleave(input=my_tensor, repeats=2)
torch.repeat_interleave(input=my_tensor, repeats=2, dim=0)
torch.repeat_interleave(input=my_tensor, repeats=2, dim=-1)
# tensor([7, 7, 4, 4, 2, 2])

torch.repeat_interleave(input=my_tensor, repeats=3)
torch.repeat_interleave(input=my_tensor, repeats=3, dim=0)
torch.repeat_interleave(input=my_tensor, repeats=3, dim=-1)
# tensor([7, 7, 7, 4, 4, 4, 2, 2, 2])

# etc.

# output_size= (keyword-only) gives the result length along dim up front;
# here it is 3 repeats * 3 elements = 9.
torch.repeat_interleave(input=my_tensor, repeats=3, dim=0, output_size=9)
# tensor([7, 7, 7, 4, 4, 4, 2, 2, 2])

# ---- 2D input ----
my_tensor = torch.tensor([[7, 4, 2], [5, 1, 6]])

# Without dim, the input is flattened first, so the result is 1D.
torch.repeat_interleave(input=my_tensor, repeats=1)
# tensor([7, 4, 2, 5, 1, 6])

torch.repeat_interleave(input=my_tensor, repeats=1, dim=0)
torch.repeat_interleave(input=my_tensor, repeats=1, dim=1)
torch.repeat_interleave(input=my_tensor, repeats=1, dim=-1)
torch.repeat_interleave(input=my_tensor, repeats=1, dim=-2)
# tensor([[7, 4, 2],
#         [5, 1, 6]])

torch.repeat_interleave(input=my_tensor, repeats=2)
# tensor([7, 7, 4, 4, 2, 2, 5, 5, 1, 1, 6, 6])

torch.repeat_interleave(input=my_tensor, repeats=2, dim=0)
torch.repeat_interleave(input=my_tensor, repeats=2, dim=-2)
# tensor([[7, 4, 2],
#         [7, 4, 2],
#         [5, 1, 6],
#         [5, 1, 6]])

torch.repeat_interleave(input=my_tensor, repeats=2, dim=1)
torch.repeat_interleave(input=my_tensor, repeats=2, dim=-1)
# tensor([[7, 7, 4, 4, 2, 2],
#         [5, 5, 1, 1, 6, 6]])

torch.repeat_interleave(input=my_tensor, repeats=3)
# tensor([7, 7, 7, 4, 4, 4, 2, 2, 2, 5, 5, 5, 1, 1, 1, 6, 6, 6])

torch.repeat_interleave(input=my_tensor, repeats=3, dim=0)
torch.repeat_interleave(input=my_tensor, repeats=3, dim=-2)
# tensor([[7, 4, 2],
#         [7, 4, 2],
#         [7, 4, 2],
#         [5, 1, 6],
#         [5, 1, 6],
#         [5, 1, 6]])

torch.repeat_interleave(input=my_tensor, repeats=3, dim=1)
torch.repeat_interleave(input=my_tensor, repeats=3, dim=-1)
# tensor([[7, 7, 7, 4, 4, 4, 2, 2, 2],
#         [5, 5, 5, 1, 1, 1, 6, 6, 6]])

# ---- Other input dtypes ----
my_tensor = torch.tensor([[7., 4., 2.], [5., 1., 6.]])
torch.repeat_interleave(input=my_tensor, repeats=1)
# tensor([7., 4., 2., 5., 1., 6.])

my_tensor = torch.tensor([[7.+0.j, 4.+0.j, 2.+0.j],
                          [5.+0.j, 1.+0.j, 6.+0.j]])
torch.repeat_interleave(input=my_tensor, repeats=1)
# tensor([7.+0.j, 4.+0.j, 2.+0.j, 5.+0.j, 1.+0.j, 6.+0.j])

my_tensor = torch.tensor([[True, False, True], [False, True, False]])
torch.repeat_interleave(input=my_tensor, repeats=1)
# tensor([True, False, True, False, True, False])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)