DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

dsplit in PyTorch

Buy Me a Coffee

*Memos:

dsplit() can get one or more depth-wise split view tensors (3D or higher-dimensional, each with zero or more elements) from a 3D or higher-dimensional tensor of zero or more elements, as shown below:

*Memos:

  • dsplit() can be used with torch or a tensor.
  • The 1st argument with torch, or the tensor itself when called as a tensor method, is input(Required-Type:tensor of int, float, complex or bool).
  • The 2nd argument with torch or the 1st argument with a tensor is sections(Required-Type:int).
  • The 2nd argument with torch or the 1st argument with a tensor is indices(Required-Type:tuple of int or list of int).
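  • The total number of elements across the returned tensors can differ from that of the original tensor (e.g. with overlapping indices such as (1, 0), some elements are returned more than once).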
  • Each returned tensor keeps the number of dimensions of the original tensor.
import torch my_tensor = torch.tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]) torch.dsplit(input=my_tensor, sections=1) my_tensor.dsplit(sections=1) # (tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]),)  torch.dsplit(input=my_tensor, sections=2) # (tensor([[[0, 1], [4, 5], [8, 9]]]), # tensor([[[2, 3], [6, 7], [10, 11]]]))  torch.dsplit(input=my_tensor, sections=4) # (tensor([[[0], [4], [8]]]), # tensor([[[1], [5], [9]]]), # tensor([[[2], [6], [10]]]), # tensor([[[3], [7], [11]]]))  torch.dsplit(input=my_tensor, indices=(0,)) torch.dsplit(input=my_tensor, indices=(-4,)) # (tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(1,)) torch.dsplit(input=my_tensor, indices=(-3,)) # (tensor([[[0], [4], [8]]]), # tensor([[[1, 2, 3], [5, 6, 7], [9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(2,)) torch.dsplit(input=my_tensor, indices=(-2,)) # (tensor([[[0, 1], [4, 5], [8, 9]]]), # tensor([[[2, 3], [6, 7], [10, 11]]]))  torch.dsplit(input=my_tensor, indices=(3,)) torch.dsplit(input=my_tensor, indices=(-1,)) # (tensor([[[0, 1, 2], [4, 5, 6], [8, 9, 10]]]), # tensor([[[3], [7], [11]]]))  torch.dsplit(input=my_tensor, indices=(4,)) # (tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64))  torch.dsplit(input=my_tensor, indices=(0, 0)) torch.dsplit(input=my_tensor, indices=(0, -4)) # (tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(0, 1)) torch.dsplit(input=my_tensor, indices=(0, -3)) # (tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0], [4], [8]]]), # tensor([[[1, 2, 3], [5, 6, 7], [9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(0, 2)) torch.dsplit(input=my_tensor, indices=(0, -2)) # (tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1], [4, 
5], [8, 9]]]), # tensor([[[2, 3], [6, 7], [10, 11]]]))  torch.dsplit(input=my_tensor, indices=(0, 3)) torch.dsplit(input=my_tensor, indices=(0, -1)) # (tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1, 2], [4, 5, 6], [8, 9, 10]]]), # tensor([[[3], [7], [11]]]))  torch.dsplit(input=my_tensor, indices=(0, 4)) # (tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64))  torch.dsplit(input=my_tensor, indices=(1, 0)) torch.dsplit(input=my_tensor, indices=(1, -4)) # (tensor([[[0], [4], [8]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(1, 1)) torch.dsplit(input=my_tensor, indices=(1, -3)) # (tensor([[[0], [4], [8]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[1, 2, 3], [5, 6, 7], [9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(1, 2)) torch.dsplit(input=my_tensor, indices=(1, -2)) # (tensor([[[0], [4], [8]]]), # tensor([[[1], [5], [9]]]), # tensor([[[2, 3], [6, 7], [10, 11]]]))  torch.dsplit(input=my_tensor, indices=(1, 3)) torch.dsplit(input=my_tensor, indices=(1, -1)) # (tensor([[[0], [4], [8]]]), # tensor([[[1, 2], [5, 6], [9, 10]]]), # tensor([[[3], [7], [11]]]))  torch.dsplit(input=my_tensor, indices=(1, 4)) # (tensor([[[0], [4], [8]]]), # tensor([[[1, 2, 3], [5, 6, 7], [9, 10, 11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64))  torch.dsplit(input=my_tensor, indices=(2, 0)) torch.dsplit(input=my_tensor, indices=(2, -4)) # (tensor([[[0, 1], [4, 5], [8, 9]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(2, 1)) torch.dsplit(input=my_tensor, indices=(2, -3)) # (tensor([[[0, 1], [4, 5], [8, 9]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[1, 2, 3], [5, 6, 7], [9, 10, 11]]]))  torch.dsplit(input=my_tensor, 
indices=(2, 2)) torch.dsplit(input=my_tensor, indices=(2, -2)) # (tensor([[[0, 1], [4, 5], [8, 9]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[2, 3], [6, 7], [10, 11]]]))  torch.dsplit(input=my_tensor, indices=(2, 3)) torch.dsplit(input=my_tensor, indices=(2, -1)) # (tensor([[[0, 1], [4, 5], [8, 9]]]), # tensor([[[2], [6], [10]]]), # tensor([[[3], [7], [11]]]))  torch.dsplit(input=my_tensor, indices=(2, 4)) # (tensor([[[0, 1], [4, 5], [8, 9]]]), # tensor([[[2, 3], [6, 7], [10, 11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64))  torch.dsplit(input=my_tensor, indices=(3, 0)) torch.dsplit(input=my_tensor, indices=(3, -4)) # (tensor([[[0, 1, 2], [4, 5, 6], [8, 9, 10]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(3, 1)) torch.dsplit(input=my_tensor, indices=(3, -3)) # (tensor([[[0, 1, 2], [4, 5, 6], [8, 9, 10]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[1, 2, 3], [5, 6, 7], [9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(3, 2)) torch.dsplit(input=my_tensor, indices=(3, -2)) # (tensor([[[0, 1, 2], [4, 5, 6], [8, 9, 10]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[2, 3], [6, 7], [10, 11]]]))  torch.dsplit(input=my_tensor, indices=(3, 3)) torch.dsplit(input=my_tensor, indices=(3, -1)) # (tensor([[[0, 1, 2], [4, 5, 6], [8, 9, 10]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[3], [7], [11]]]))  torch.dsplit(input=my_tensor, indices=(3, 4)) # (tensor([[[0, 1, 2], [4, 5, 6], [8, 9, 10]]]), # tensor([[[3], [7], [11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64))  torch.dsplit(input=my_tensor, indices=(4, 0)) torch.dsplit(input=my_tensor, indices=(4, -4)) # (tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(4, 1)) 
torch.dsplit(input=my_tensor, indices=(4, -3)) # (tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[1, 2, 3], [5, 6, 7], [9, 10, 11]]]))  torch.dsplit(input=my_tensor, indices=(4, 2)) torch.dsplit(input=my_tensor, indices=(4, -2)) # (tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[2, 3], [6, 7], [10, 11]]]))  torch.dsplit(input=my_tensor, indices=(4, 3)) torch.dsplit(input=my_tensor, indices=(4, -1)) # (tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[3], [7], [11]]]))  torch.dsplit(input=my_tensor, indices=(4, 4)) # (tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([], size=(1, 3, 0), dtype=torch.int64))  torch.dsplit(input=my_tensor, indices=(0, 0, 0)) torch.dsplit(input=my_tensor, indices=(0, 0, -4)) torch.dsplit(input=my_tensor, indices=(0, -4, 0)) torch.dsplit(input=my_tensor, indices=(0, -4, -4)) # (tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([], size=(1, 3, 0), dtype=torch.int64), # tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]])) etc. 
my_tensor = torch.tensor([[[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]]]) torch.dsplit(input=my_tensor, sections=1) # (tensor([[[0., 1., 2., 3.], # [4., 5., 6., 7.], # [8., 9., 10., 11.]]]),)  my_tensor = torch.tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]]) torch.dsplit(input=my_tensor, sections=1) # (tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], # [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], # [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]]),)  my_tensor = torch.tensor([[[True, False, True, False], [False, True, False, True], [True, False, True, False]]]) torch.dsplit(input=my_tensor, sections=1) # (tensor([[[True, False, True, False], # [False, True, False, True], # [True, False, True, False]]]),) 
Enter fullscreen mode Exit fullscreen mode

Top comments (0)