DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

chunk in PyTorch

Buy Me a Coffee

*Memos:

chunk() splits a 1D or higher-dimensional tensor of zero or more elements into one or more view tensors along a given dimension, given the number of chunks, as shown below:

*Memos:

  • chunk() can be used with torch or a tensor.
  • The 1st argument with torch is input; when chunk() is called on a tensor, that tensor is the input (Required-Type:tensor of int, float, complex or bool).
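  • The 2nd argument with torch or the 1st argument with a tensor is chunks(Required-Type:int). It is the maximum number of chunks, so fewer may be returned, e.g. chunks=3 on a 4-element tensor returns 2 chunks.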
  • The 3rd argument with torch or the 2nd argument with a tensor is dim(Optional-Default:0-Type:int).
  • The total number of elements across the returned tensors is the same as in the original tensor.
  • The returned tensors keep the number of dimensions of the original tensor.
# torch.chunk() / Tensor.chunk() examples.
#
# `chunks` is an upper bound, not an exact count: the chosen dim is cut into
# pieces of ceil(size / chunks) elements, so the last piece may be smaller and
# fewer than `chunks` pieces can be returned (e.g. chunks=3 on 4 elements gives
# 2 chunks of 2). Each returned tensor is a view of the input and keeps the
# input's number of dimensions.
#
# Every call in a group below evaluates to the tuple shown in the comment that
# ends the group. The calls are bare expressions (REPL style); wrap them in
# print() to see the results when running this file as a script.
import torch

# ---- 1D tensor: shape (4,) ----
my_tensor = torch.tensor([0, 1, 2, 3])

torch.chunk(input=my_tensor, chunks=1)
my_tensor.chunk(chunks=1)
torch.chunk(input=my_tensor, chunks=1, dim=0)
torch.chunk(input=my_tensor, chunks=1, dim=-1)
# (tensor([0, 1, 2, 3]),)

# chunks=3 still yields 2 chunks: ceil(4 / 3) == 2 elements per chunk.
torch.chunk(input=my_tensor, chunks=2)
torch.chunk(input=my_tensor, chunks=2, dim=0)
torch.chunk(input=my_tensor, chunks=2, dim=-1)
torch.chunk(input=my_tensor, chunks=3)
torch.chunk(input=my_tensor, chunks=3, dim=0)
torch.chunk(input=my_tensor, chunks=3, dim=-1)
# (tensor([0, 1]),
#  tensor([2, 3]))

torch.chunk(input=my_tensor, chunks=4)
torch.chunk(input=my_tensor, chunks=4, dim=0)
torch.chunk(input=my_tensor, chunks=4, dim=-1)
# (tensor([0]), tensor([1]), tensor([2]), tensor([3]))

# ---- 2D tensor: shape (3, 4) ----
my_tensor = torch.tensor([[0, 1, 2, 3],
                          [4, 5, 6, 7],
                          [8, 9, 10, 11]])

torch.chunk(input=my_tensor, chunks=1)
torch.chunk(input=my_tensor, chunks=1, dim=0)
torch.chunk(input=my_tensor, chunks=1, dim=1)
torch.chunk(input=my_tensor, chunks=1, dim=-1)
torch.chunk(input=my_tensor, chunks=1, dim=-2)
# (tensor([[0, 1, 2, 3],
#          [4, 5, 6, 7],
#          [8, 9, 10, 11]]),)

# 3 rows into chunks of ceil(3 / 2) == 2 rows: the last chunk is smaller.
torch.chunk(input=my_tensor, chunks=2)
torch.chunk(input=my_tensor, chunks=2, dim=0)
torch.chunk(input=my_tensor, chunks=2, dim=-2)
# (tensor([[0, 1, 2, 3],
#          [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.chunk(input=my_tensor, chunks=2, dim=1)
torch.chunk(input=my_tensor, chunks=2, dim=-1)
torch.chunk(input=my_tensor, chunks=3, dim=1)
torch.chunk(input=my_tensor, chunks=3, dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

# chunks=4 on 3 rows can only produce 3 chunks.
torch.chunk(input=my_tensor, chunks=3)
torch.chunk(input=my_tensor, chunks=3, dim=0)
torch.chunk(input=my_tensor, chunks=3, dim=-2)
torch.chunk(input=my_tensor, chunks=4)
torch.chunk(input=my_tensor, chunks=4, dim=0)
torch.chunk(input=my_tensor, chunks=4, dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.chunk(input=my_tensor, chunks=4, dim=1)
torch.chunk(input=my_tensor, chunks=4, dim=-1)
# (tensor([[0], [4], [8]]),
#  tensor([[1], [5], [9]]),
#  tensor([[2], [6], [10]]),
#  tensor([[3], [7], [11]]))

# ---- 3D tensor: shape (1, 3, 4) ----
my_tensor = torch.tensor([[[0, 1, 2, 3],
                           [4, 5, 6, 7],
                           [8, 9, 10, 11]]])

# dim 0 has size 1, so any `chunks` along it returns a single chunk.
torch.chunk(input=my_tensor, chunks=1)
torch.chunk(input=my_tensor, chunks=1, dim=0)
torch.chunk(input=my_tensor, chunks=1, dim=1)
torch.chunk(input=my_tensor, chunks=1, dim=2)
torch.chunk(input=my_tensor, chunks=1, dim=-1)
torch.chunk(input=my_tensor, chunks=1, dim=-2)
torch.chunk(input=my_tensor, chunks=1, dim=-3)
torch.chunk(input=my_tensor, chunks=2)
torch.chunk(input=my_tensor, chunks=2, dim=0)
torch.chunk(input=my_tensor, chunks=2, dim=-3)
torch.chunk(input=my_tensor, chunks=3)
torch.chunk(input=my_tensor, chunks=3, dim=0)
torch.chunk(input=my_tensor, chunks=3, dim=-3)
torch.chunk(input=my_tensor, chunks=4)
torch.chunk(input=my_tensor, chunks=4, dim=0)
torch.chunk(input=my_tensor, chunks=4, dim=-3)
# (tensor([[[0, 1, 2, 3],
#           [4, 5, 6, 7],
#           [8, 9, 10, 11]]]),)

torch.chunk(input=my_tensor, chunks=2, dim=1)
torch.chunk(input=my_tensor, chunks=2, dim=-2)
# (tensor([[[0, 1, 2, 3],
#           [4, 5, 6, 7]]]),
#  tensor([[[8, 9, 10, 11]]]))

torch.chunk(input=my_tensor, chunks=2, dim=2)
torch.chunk(input=my_tensor, chunks=2, dim=-1)
torch.chunk(input=my_tensor, chunks=3, dim=2)
torch.chunk(input=my_tensor, chunks=3, dim=-1)
# (tensor([[[0, 1], [4, 5], [8, 9]]]),
#  tensor([[[2, 3], [6, 7], [10, 11]]]))

torch.chunk(input=my_tensor, chunks=3, dim=1)
torch.chunk(input=my_tensor, chunks=3, dim=-2)
torch.chunk(input=my_tensor, chunks=4, dim=1)
torch.chunk(input=my_tensor, chunks=4, dim=-2)
# (tensor([[[0, 1, 2, 3]]]),
#  tensor([[[4, 5, 6, 7]]]),
#  tensor([[[8, 9, 10, 11]]]))

torch.chunk(input=my_tensor, chunks=4, dim=2)
torch.chunk(input=my_tensor, chunks=4, dim=-1)
# (tensor([[[0], [4], [8]]]),
#  tensor([[[1], [5], [9]]]),
#  tensor([[[2], [6], [10]]]),
#  tensor([[[3], [7], [11]]]))

# ---- Other dtypes: float, complex, bool ----
my_tensor = torch.tensor([[[0., 1., 2., 3.],
                           [4., 5., 6., 7.],
                           [8., 9., 10., 11.]]])

torch.chunk(input=my_tensor, chunks=1)
# (tensor([[[0., 1., 2., 3.],
#           [4., 5., 6., 7.],
#           [8., 9., 10., 11.]]]),)

my_tensor = torch.tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
                           [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
                           [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]])

torch.chunk(input=my_tensor, chunks=1)
# (tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
#           [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
#           [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]]),)

my_tensor = torch.tensor([[[True, False, True, False],
                           [False, True, False, True],
                           [True, False, True, False]]])

torch.chunk(input=my_tensor, chunks=1)
# (tensor([[[True, False, True, False],
#           [False, True, False, True],
#           [True, False, True, False]]]),)
Enter fullscreen mode Exit fullscreen mode

Top comments (0)