DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

tensor_split in PyTorch

Buy Me a Coffee

*Memos:

tensor_split() splits a 1D or more D tensor of zero or more elements into one or more 1D or more D tensors, each of zero or more elements, as shown below:

*Memos:

  • tensor_split() can be used with torch or a tensor.
  • The 1st argument(input) with torch, or the tensor itself when tensor_split() is called as a tensor method(Required-Type:tensor of int, float, complex or bool). *It must be a 1D or more D tensor.
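  • The 2nd argument with torch or the 1st argument with a tensor is sections(Required-Type:int). *input is split along dim into sections parts whose sizes differ by at most one (larger parts first).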
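  • Alternatively, the 2nd argument with torch or the 1st argument with a tensor is indices(Required-Type:tuple of int or list of int). *input is split along dim at each index like Python slicing, e.g. indices=(i, j) gives input[:i], input[i:j] and input[j:]; negative indices count from the end.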
  • Alternatively, the 2nd argument with torch or the 1st argument with a tensor is tensor_indices_or_sections(Required-Type:tensor of int). *It must be a 0D or 1D tensor: a 0D tensor works like sections and a 1D tensor works like indices.
  • The 3rd argument with torch or the 2nd argument with a tensor is dim(Optional-Default:0-Type:int).
  • Each returned tensor can have a different number of elements from the input tensor.
  • The total number of elements across the returned tensors can also differ from that of the input tensor, e.g. indices=(3, 0) returns some elements twice.
  • Every returned tensor keeps the number of dimensions of the input tensor.
import torch

# torch.tensor_split() examples on a 3x4 tensor.
#
# Each group of calls below produces the same result, which is shown in the
# comment right after the group (the calls are REPL-style expressions; their
# return values are not stored).
#
# - sections=n splits along `dim` into n parts whose sizes differ by at most
#   one, larger parts first (e.g. 4 columns -> sizes 2, 1, 1).
# - indices=(i, j, ...) slices along `dim` like Python slicing:
#   input[:i], input[i:j], ..., input[last:]. Negative indices count from the
#   end, out-of-range indices are clamped, and a slice whose end comes before
#   its start is empty (so elements can appear in several returned tensors).
# - tensor_indices_or_sections behaves like sections for a 0D tensor and like
#   indices for a 1D tensor.

my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

torch.tensor_split(input=my_tensor, sections=1)
my_tensor.tensor_split(sections=1)
torch.tensor_split(input=my_tensor, sections=1, dim=0)
torch.tensor_split(input=my_tensor, sections=1, dim=1)
torch.tensor_split(input=my_tensor, sections=1, dim=-1)
torch.tensor_split(input=my_tensor, sections=1, dim=-2)
torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor(1), dim=0)
torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor(1), dim=1)
torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor(1), dim=-1)
torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor(1), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

torch.tensor_split(input=my_tensor, indices=(1,))
torch.tensor_split(input=my_tensor, indices=(1,), dim=0)
torch.tensor_split(input=my_tensor, indices=(1,), dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(1,), dim=1)
torch.tensor_split(input=my_tensor, indices=(1,), dim=-1)
# (tensor([[0], [4], [8]]),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.tensor_split(input=my_tensor, sections=2)
torch.tensor_split(input=my_tensor, indices=(2,))
torch.tensor_split(input=my_tensor, sections=2, dim=0)
torch.tensor_split(input=my_tensor, indices=(2,), dim=0)
torch.tensor_split(input=my_tensor, sections=2, dim=-2)
torch.tensor_split(input=my_tensor, indices=(2,), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-1,))
torch.tensor_split(input=my_tensor, indices=(-1,), dim=0)
torch.tensor_split(input=my_tensor, indices=(-1,), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, sections=2, dim=1)
torch.tensor_split(input=my_tensor, indices=(2,), dim=1)
torch.tensor_split(input=my_tensor, sections=2, dim=-1)
torch.tensor_split(input=my_tensor, indices=(2,), dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.tensor_split(input=my_tensor, sections=3)
torch.tensor_split(input=my_tensor, sections=3, dim=0)
torch.tensor_split(input=my_tensor, sections=3, dim=-2)
torch.tensor_split(input=my_tensor, indices=(1, 2))
torch.tensor_split(input=my_tensor, indices=(1, 2), dim=0)
torch.tensor_split(input=my_tensor, indices=(1, 2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(1, -1))
torch.tensor_split(input=my_tensor, indices=(1, -1), dim=0)
torch.tensor_split(input=my_tensor, indices=(1, -1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-2, 2))
torch.tensor_split(input=my_tensor, indices=(-2, 2), dim=0)
torch.tensor_split(input=my_tensor, indices=(-2, 2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-2, -1))
torch.tensor_split(input=my_tensor, indices=(-2, -1), dim=0)
torch.tensor_split(input=my_tensor, indices=(-2, -1), dim=-2)
torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor([1, 2]), dim=0)
torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor([1, 2]), dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(3,), dim=0)
torch.tensor_split(input=my_tensor, indices=(3,), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(3,), dim=1)
torch.tensor_split(input=my_tensor, indices=(3,), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-1,), dim=1)
torch.tensor_split(input=my_tensor, indices=(-1,), dim=-1)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([[3], [7], [11]]))

torch.tensor_split(input=my_tensor, sections=3, dim=1)
torch.tensor_split(input=my_tensor, sections=3, dim=-1)
torch.tensor_split(input=my_tensor, indices=(2, 3), dim=1)
torch.tensor_split(input=my_tensor, indices=(2, 3), dim=-1)
torch.tensor_split(input=my_tensor, indices=(2, -1), dim=1)
torch.tensor_split(input=my_tensor, indices=(2, -1), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-2, -1), dim=1)
torch.tensor_split(input=my_tensor, indices=(-2, -1), dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2], [6], [10]]),
#  tensor([[3], [7], [11]]))

torch.tensor_split(input=my_tensor, indices=(0, 0))
torch.tensor_split(input=my_tensor, indices=(0, 0), dim=0)
torch.tensor_split(input=my_tensor, indices=(0, 0), dim=-2)
torch.tensor_split(input=my_tensor, indices=(0, -3))
torch.tensor_split(input=my_tensor, indices=(0, -3), dim=0)
torch.tensor_split(input=my_tensor, indices=(0, -3), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-3, 0))
torch.tensor_split(input=my_tensor, indices=(-3, 0), dim=0)
torch.tensor_split(input=my_tensor, indices=(-3, 0), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-3, -3))
torch.tensor_split(input=my_tensor, indices=(-3, -3), dim=0)
torch.tensor_split(input=my_tensor, indices=(-3, -3), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-4, -4))
torch.tensor_split(input=my_tensor, indices=(-4, -4), dim=0)
torch.tensor_split(input=my_tensor, indices=(-4, -4), dim=-2)
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(0, 0), dim=1)
torch.tensor_split(input=my_tensor, indices=(0, 0), dim=-1)
torch.tensor_split(input=my_tensor, indices=(0, -4), dim=1)
torch.tensor_split(input=my_tensor, indices=(0, -4), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-4, -4), dim=1)
torch.tensor_split(input=my_tensor, indices=(-4, -4), dim=-1)
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(0, 1))
torch.tensor_split(input=my_tensor, indices=(0, 1), dim=0)
torch.tensor_split(input=my_tensor, indices=(0, 1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(0, -2))
torch.tensor_split(input=my_tensor, indices=(0, -2), dim=0)
torch.tensor_split(input=my_tensor, indices=(0, -2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-3, 1))
torch.tensor_split(input=my_tensor, indices=(-3, 1), dim=0)
torch.tensor_split(input=my_tensor, indices=(-3, 1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-3, -2))
torch.tensor_split(input=my_tensor, indices=(-3, -2), dim=0)
torch.tensor_split(input=my_tensor, indices=(-3, -2), dim=-2)
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(0, 1), dim=1)
torch.tensor_split(input=my_tensor, indices=(0, 1), dim=-1)
torch.tensor_split(input=my_tensor, indices=(0, -3), dim=1)
torch.tensor_split(input=my_tensor, indices=(0, -3), dim=-1)
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0], [4], [8]]),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(0, 2))
torch.tensor_split(input=my_tensor, indices=(0, 2), dim=0)
torch.tensor_split(input=my_tensor, indices=(0, 2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(0, -1))
torch.tensor_split(input=my_tensor, indices=(0, -1), dim=0)
torch.tensor_split(input=my_tensor, indices=(0, -1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-3, 2))
torch.tensor_split(input=my_tensor, indices=(-3, 2), dim=0)
torch.tensor_split(input=my_tensor, indices=(-3, 2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-3, -1))
torch.tensor_split(input=my_tensor, indices=(-3, -1), dim=0)
torch.tensor_split(input=my_tensor, indices=(-3, -1), dim=-2)
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(0, 2), dim=1)
torch.tensor_split(input=my_tensor, indices=(0, 2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(0, -2), dim=1)
torch.tensor_split(input=my_tensor, indices=(0, -2), dim=-1)
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.tensor_split(input=my_tensor, indices=(0, 3))
torch.tensor_split(input=my_tensor, indices=(0, 3), dim=0)
torch.tensor_split(input=my_tensor, indices=(0, 3), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-3, 3))
torch.tensor_split(input=my_tensor, indices=(-3, 3), dim=0)
torch.tensor_split(input=my_tensor, indices=(-3, 3), dim=-2)
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(0, 3), dim=1)
torch.tensor_split(input=my_tensor, indices=(0, 3), dim=-1)
torch.tensor_split(input=my_tensor, indices=(0, -1), dim=1)
torch.tensor_split(input=my_tensor, indices=(0, -1), dim=-1)
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([[3], [7], [11]]))

torch.tensor_split(input=my_tensor, indices=(0, 4), dim=1)
torch.tensor_split(input=my_tensor, indices=(0, 4), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-4, 4), dim=1)
torch.tensor_split(input=my_tensor, indices=(-4, 4), dim=-1)
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(1, 0))
torch.tensor_split(input=my_tensor, indices=(1, 0), dim=0)
torch.tensor_split(input=my_tensor, indices=(1, 0), dim=-2)
torch.tensor_split(input=my_tensor, indices=(1, -3))
torch.tensor_split(input=my_tensor, indices=(1, -3), dim=0)
torch.tensor_split(input=my_tensor, indices=(1, -3), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-2, 0))
torch.tensor_split(input=my_tensor, indices=(-2, 0), dim=0)
torch.tensor_split(input=my_tensor, indices=(-2, 0), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-2, -3))
torch.tensor_split(input=my_tensor, indices=(-2, -3), dim=0)
torch.tensor_split(input=my_tensor, indices=(-2, -3), dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(1, 0), dim=1)
torch.tensor_split(input=my_tensor, indices=(1, 0), dim=-1)
torch.tensor_split(input=my_tensor, indices=(1, -4), dim=1)
torch.tensor_split(input=my_tensor, indices=(1, -4), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-3, 0), dim=1)
torch.tensor_split(input=my_tensor, indices=(-3, 0), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-3, -4), dim=1)
torch.tensor_split(input=my_tensor, indices=(-3, -4), dim=-1)
# (tensor([[0], [4], [8]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(1, 1))
torch.tensor_split(input=my_tensor, indices=(1, 1), dim=0)
torch.tensor_split(input=my_tensor, indices=(1, 1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(1, -2))
torch.tensor_split(input=my_tensor, indices=(1, -2), dim=0)
torch.tensor_split(input=my_tensor, indices=(1, -2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-2, 1))
torch.tensor_split(input=my_tensor, indices=(-2, 1), dim=0)
torch.tensor_split(input=my_tensor, indices=(-2, 1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-2, -2))
torch.tensor_split(input=my_tensor, indices=(-2, -2), dim=0)
torch.tensor_split(input=my_tensor, indices=(-2, -2), dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(1, 1), dim=1)
torch.tensor_split(input=my_tensor, indices=(1, 1), dim=-1)
torch.tensor_split(input=my_tensor, indices=(1, -3), dim=1)
torch.tensor_split(input=my_tensor, indices=(1, -3), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-3, 1), dim=1)
torch.tensor_split(input=my_tensor, indices=(-3, 1), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-3, -3), dim=1)
torch.tensor_split(input=my_tensor, indices=(-3, -3), dim=-1)
# (tensor([[0], [4], [8]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(1, 2), dim=1)
torch.tensor_split(input=my_tensor, indices=(1, 2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(1, -2), dim=1)
torch.tensor_split(input=my_tensor, indices=(1, -2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-3, 2), dim=1)
torch.tensor_split(input=my_tensor, indices=(-3, 2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-3, -2), dim=1)
torch.tensor_split(input=my_tensor, indices=(-3, -2), dim=-1)
torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor([1, 2]), dim=1)
torch.tensor_split(input=my_tensor, tensor_indices_or_sections=torch.tensor([1, 2]), dim=-1)
# (tensor([[0], [4], [8]]),
#  tensor([[1], [5], [9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.tensor_split(input=my_tensor, indices=(1, 3))
torch.tensor_split(input=my_tensor, indices=(1, 3), dim=0)
torch.tensor_split(input=my_tensor, indices=(1, 3), dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(1, 3), dim=1)
torch.tensor_split(input=my_tensor, indices=(1, 3), dim=-1)
torch.tensor_split(input=my_tensor, indices=(1, -1), dim=1)
torch.tensor_split(input=my_tensor, indices=(1, -1), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-3, 3), dim=1)
torch.tensor_split(input=my_tensor, indices=(-3, 3), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-3, -1), dim=1)
torch.tensor_split(input=my_tensor, indices=(-3, -1), dim=-1)
# (tensor([[0], [4], [8]]),
#  tensor([[1, 2], [5, 6], [9, 10]]),
#  tensor([[3], [7], [11]]))

torch.tensor_split(input=my_tensor, indices=(1, 4), dim=1)
torch.tensor_split(input=my_tensor, indices=(1, 4), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-3, 4), dim=1)
torch.tensor_split(input=my_tensor, indices=(-3, 4), dim=-1)
# (tensor([[0], [4], [8]]),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(2, 0))
torch.tensor_split(input=my_tensor, indices=(2, 0), dim=0)
torch.tensor_split(input=my_tensor, indices=(2, 0), dim=-2)
torch.tensor_split(input=my_tensor, indices=(2, -3))
torch.tensor_split(input=my_tensor, indices=(2, -3), dim=0)
torch.tensor_split(input=my_tensor, indices=(2, -3), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-1, 0))
torch.tensor_split(input=my_tensor, indices=(-1, 0), dim=0)
torch.tensor_split(input=my_tensor, indices=(-1, 0), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(2, 0), dim=1)
torch.tensor_split(input=my_tensor, indices=(2, 0), dim=-1)
torch.tensor_split(input=my_tensor, indices=(2, -4), dim=1)
torch.tensor_split(input=my_tensor, indices=(2, -4), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-2, 0), dim=1)
torch.tensor_split(input=my_tensor, indices=(-2, 0), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-2, -4), dim=1)
torch.tensor_split(input=my_tensor, indices=(-2, -4), dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(2, 1))
torch.tensor_split(input=my_tensor, indices=(2, 1), dim=0)
torch.tensor_split(input=my_tensor, indices=(2, 1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(2, -2))
torch.tensor_split(input=my_tensor, indices=(2, -2), dim=0)
torch.tensor_split(input=my_tensor, indices=(2, -2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-1, 1))
torch.tensor_split(input=my_tensor, indices=(-1, 1), dim=0)
torch.tensor_split(input=my_tensor, indices=(-1, 1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-1, -2))
torch.tensor_split(input=my_tensor, indices=(-1, -2), dim=0)
torch.tensor_split(input=my_tensor, indices=(-1, -2), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(2, 2))
torch.tensor_split(input=my_tensor, indices=(2, 2), dim=0)
torch.tensor_split(input=my_tensor, indices=(2, 2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(2, -1))
torch.tensor_split(input=my_tensor, indices=(2, -1), dim=0)
torch.tensor_split(input=my_tensor, indices=(2, -1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-1, 2))
torch.tensor_split(input=my_tensor, indices=(-1, 2), dim=0)
torch.tensor_split(input=my_tensor, indices=(-1, 2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(-1, -1))
torch.tensor_split(input=my_tensor, indices=(-1, -1), dim=0)
torch.tensor_split(input=my_tensor, indices=(-1, -1), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(2, 2), dim=1)
torch.tensor_split(input=my_tensor, indices=(2, 2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(2, -2), dim=1)
torch.tensor_split(input=my_tensor, indices=(2, -2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-2, 2), dim=1)
torch.tensor_split(input=my_tensor, indices=(-2, 2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-2, -2), dim=1)
torch.tensor_split(input=my_tensor, indices=(-2, -2), dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.tensor_split(input=my_tensor, indices=(2, 3))
torch.tensor_split(input=my_tensor, indices=(2, 3), dim=0)
torch.tensor_split(input=my_tensor, indices=(2, 3), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(2, 4), dim=1)
torch.tensor_split(input=my_tensor, indices=(2, 4), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-2, 4), dim=1)
torch.tensor_split(input=my_tensor, indices=(-2, 4), dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(3, 0))
torch.tensor_split(input=my_tensor, indices=(3, 0), dim=0)
torch.tensor_split(input=my_tensor, indices=(3, 0), dim=-2)
torch.tensor_split(input=my_tensor, indices=(3, -3))
torch.tensor_split(input=my_tensor, indices=(3, -3), dim=0)
torch.tensor_split(input=my_tensor, indices=(3, -3), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(3, 0), dim=1)
torch.tensor_split(input=my_tensor, indices=(3, 0), dim=-1)
torch.tensor_split(input=my_tensor, indices=(3, -4), dim=1)
torch.tensor_split(input=my_tensor, indices=(3, -4), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-1, 0), dim=1)
torch.tensor_split(input=my_tensor, indices=(-1, 0), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-1, -4), dim=1)
torch.tensor_split(input=my_tensor, indices=(-1, -4), dim=-1)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(3, 1))
torch.tensor_split(input=my_tensor, indices=(3, 1), dim=0)
torch.tensor_split(input=my_tensor, indices=(3, 1), dim=-2)
torch.tensor_split(input=my_tensor, indices=(3, -2))
torch.tensor_split(input=my_tensor, indices=(3, -2), dim=0)
torch.tensor_split(input=my_tensor, indices=(3, -2), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(3, 1), dim=1)
torch.tensor_split(input=my_tensor, indices=(3, 1), dim=-1)
torch.tensor_split(input=my_tensor, indices=(3, -3), dim=1)
torch.tensor_split(input=my_tensor, indices=(3, -3), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-1, 1), dim=1)
torch.tensor_split(input=my_tensor, indices=(-1, 1), dim=-1)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(3, 2))
torch.tensor_split(input=my_tensor, indices=(3, 2), dim=0)
torch.tensor_split(input=my_tensor, indices=(3, 2), dim=-2)
torch.tensor_split(input=my_tensor, indices=(3, -1))
torch.tensor_split(input=my_tensor, indices=(3, -1), dim=0)
torch.tensor_split(input=my_tensor, indices=(3, -1), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(3, 2), dim=1)
torch.tensor_split(input=my_tensor, indices=(3, 2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(3, -2), dim=1)
torch.tensor_split(input=my_tensor, indices=(3, -2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-1, 2), dim=1)
torch.tensor_split(input=my_tensor, indices=(-1, 2), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-1, -2), dim=1)
torch.tensor_split(input=my_tensor, indices=(-1, -2), dim=-1)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.tensor_split(input=my_tensor, indices=(3, 3))
torch.tensor_split(input=my_tensor, indices=(3, 3), dim=0)
torch.tensor_split(input=my_tensor, indices=(3, 3), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(3, 3), dim=1)
torch.tensor_split(input=my_tensor, indices=(3, 3), dim=-1)
torch.tensor_split(input=my_tensor, indices=(3, -1), dim=1)
torch.tensor_split(input=my_tensor, indices=(3, -1), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-1, -1), dim=1)
torch.tensor_split(input=my_tensor, indices=(-1, -1), dim=-1)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[3], [7], [11]]))

torch.tensor_split(input=my_tensor, indices=(3, 4), dim=1)
torch.tensor_split(input=my_tensor, indices=(3, 4), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-1, 4), dim=1)
torch.tensor_split(input=my_tensor, indices=(-1, 4), dim=-1)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([[3], [7], [11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(4, 4), dim=1)
torch.tensor_split(input=my_tensor, indices=(4, 4), dim=-1)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([], size=(3, 0), dtype=torch.int64))

torch.tensor_split(input=my_tensor, indices=(4, -4), dim=1)
torch.tensor_split(input=my_tensor, indices=(4, -4), dim=-1)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(2, -3), dim=1)
torch.tensor_split(input=my_tensor, indices=(2, -3), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-2, 1), dim=1)
torch.tensor_split(input=my_tensor, indices=(-2, 1), dim=-1)
torch.tensor_split(input=my_tensor, indices=(-2, -3), dim=1)
torch.tensor_split(input=my_tensor, indices=(-2, -3), dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.tensor_split(input=my_tensor, indices=(0, 0, 0))
torch.tensor_split(input=my_tensor, indices=(0, 0, 0), dim=0)
torch.tensor_split(input=my_tensor, indices=(0, 0, 0), dim=-2)
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

# etc.

# float, complex and bool tensors are split the same way.
my_tensor = torch.tensor([[0., 1., 2., 3.],
                          [4., 5., 6., 7.],
                          [8., 9., 10., 11.]])

torch.tensor_split(input=my_tensor, sections=1)
# (tensor([[0., 1., 2., 3.],
#          [4., 5., 6., 7.],
#          [8., 9., 10., 11.]]),)

my_tensor = torch.tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
                          [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
                          [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]])

torch.tensor_split(input=my_tensor, sections=1)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
#          [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
#          [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]),)

my_tensor = torch.tensor([[True, False, True, False],
                          [False, True, False, True],
                          [True, False, True, False]])

torch.tensor_split(input=my_tensor, sections=1)
# (tensor([[True, False, True, False],
#          [False, True, False, True],
#          [True, False, True, False]]),)
Enter fullscreen mode Exit fullscreen mode

Top comments (0)