*Memos:
- My post explains stack().
- My post explains hstack() and column_stack().
- My post explains vstack() and dstack().
`cat()` concatenates one or more tensors, each with one or more dimensions and zero or more elements, into a single tensor with one or more dimensions and zero or more elements, as shown below:
*Memos:
- `cat()` can be used with `torch` but not with a tensor.
- The 1st argument with `torch` is `tensors` (Required-Type: `tuple` or `list` of `tensor` of `int`, `float`, `complex` or `bool`). *The sizes of the tensors must be the same except in the concatenating dimension (`dim`, which is `0` by default).
- The 2nd argument with `torch` is `dim` (Optional-Default: `0`-Type: `int`).
- There is an `out` argument with `torch` (Optional-Default: `None`-Type: `tensor`): *Memos:
  - `out=` must be used.
  - My post explains the `out` argument.
- `concat()` is the alias of `cat()`.
# Examples of torch.cat(): each group of calls below produces the tensor
# shown in the comment that follows it.
import torch

# 1D tensors of the same size. For 1D tensors, dim=0 and dim=-1 are the
# same dimension.
tensor1 = torch.tensor([2, 7, 4])
tensor2 = torch.tensor([8, 3, 2])
tensor3 = torch.tensor([5, 0, 8])

torch.cat(tensors=(tensor1, tensor2, tensor3))
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=0)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-1)
# tensor([2, 7, 4, 8, 3, 2, 5, 0, 8])

# 1D tensors of different sizes.
tensor1 = torch.tensor([2, 7])
tensor2 = torch.tensor([8, 3, 2])
tensor3 = torch.tensor([5])

torch.cat(tensors=(tensor1, tensor2, tensor3))
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=0)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-1)
# tensor([2, 7, 8, 3, 2, 5])

# 2D tensors of the same shape.
tensor1 = torch.tensor([[2, 7, 4], [8, 3, 2]])
tensor2 = torch.tensor([[5, 0, 8], [3, 6, 1]])
tensor3 = torch.tensor([[9, 4, 7], [1, 0, 5]])

torch.cat(tensors=(tensor1, tensor2, tensor3))
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=0)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-2)
# tensor([[2, 7, 4],
#         [8, 3, 2],
#         [5, 0, 8],
#         [3, 6, 1],
#         [9, 4, 7],
#         [1, 0, 5]])

torch.cat(tensors=(tensor1, tensor2, tensor3), dim=1)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-1)
# tensor([[2, 7, 4, 5, 0, 8, 9, 4, 7],
#         [8, 3, 2, 3, 6, 1, 1, 0, 5]])

# 2D tensors whose sizes differ only in dimension 0, which is the
# dimension being concatenated.
tensor1 = torch.tensor([[2, 7, 4], [8, 3, 2]])
tensor2 = torch.tensor([[5, 0, 8], [3, 6, 1], [9, 4, 7]])
tensor3 = torch.tensor([[1, 0, 5]])

torch.cat(tensors=(tensor1, tensor2, tensor3))
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=0)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-2)
# tensor([[2, 7, 4],
#         [8, 3, 2],
#         [5, 0, 8],
#         [3, 6, 1],
#         [9, 4, 7],
#         [1, 0, 5]])

# 3D int tensors of the same shape.
tensor1 = torch.tensor([[[2, 7, 4], [8, 3, 2]],
                        [[5, 0, 8], [3, 6, 1]]])
tensor2 = torch.tensor([[[9, 4, 7], [1, 0, 5]],
                        [[6, 7, 4], [2, 1, 9]]])
tensor3 = torch.tensor([[[1, 6, 3], [9, 6, 0]],
                        [[0, 8, 7], [3, 5, 2]]])

torch.cat(tensors=(tensor1, tensor2, tensor3))
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=0)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-3)
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]],
#         [[9, 4, 7], [1, 0, 5]],
#         [[6, 7, 4], [2, 1, 9]],
#         [[1, 6, 3], [9, 6, 0]],
#         [[0, 8, 7], [3, 5, 2]]])

torch.cat(tensors=(tensor1, tensor2, tensor3), dim=1)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-2)
# tensor([[[2, 7, 4],
#          [8, 3, 2],
#          [9, 4, 7],
#          [1, 0, 5],
#          [1, 6, 3],
#          [9, 6, 0]],
#         [[5, 0, 8],
#          [3, 6, 1],
#          [6, 7, 4],
#          [2, 1, 9],
#          [0, 8, 7],
#          [3, 5, 2]]])

torch.cat(tensors=(tensor1, tensor2, tensor3), dim=2)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-1)
# tensor([[[2, 7, 4, 9, 4, 7, 1, 6, 3],
#          [8, 3, 2, 1, 0, 5, 9, 6, 0]],
#         [[5, 0, 8, 6, 7, 4, 0, 8, 7],
#          [3, 6, 1, 2, 1, 9, 3, 5, 2]]])

# 3D float tensors.
tensor1 = torch.tensor([[[2., 7., 4.], [8., 3., 2.]],
                        [[5., 0., 8.], [3., 6., 1.]]])
tensor2 = torch.tensor([[[9., 4., 7.], [1., 0., 5.]],
                        [[6., 7., 4.], [2., 1., 9.]]])
tensor3 = torch.tensor([[[1., 6., 3.], [9., 6., 0.]],
                        [[0., 8., 7.], [3., 5., 2.]]])

torch.cat(tensors=(tensor1, tensor2, tensor3))
# tensor([[[2., 7., 4.], [8., 3., 2.]],
#         [[5., 0., 8.], [3., 6., 1.]],
#         [[9., 4., 7.], [1., 0., 5.]],
#         [[6., 7., 4.], [2., 1., 9.]],
#         [[1., 6., 3.], [9., 6., 0.]],
#         [[0., 8., 7.], [3., 5., 2.]]])

# 3D complex tensors.
tensor1 = torch.tensor([[[2.+0.j, 7.+0.j, 4.+0.j],
                         [8.+0.j, 3.+0.j, 2.+0.j]],
                        [[5.+0.j, 0.+0.j, 8.+0.j],
                         [3.+0.j, 6.+0.j, 1.+0.j]]])
tensor2 = torch.tensor([[[9.+0.j, 4.+0.j, 7.+0.j],
                         [1.+0.j, 0.+0.j, 5.+0.j]],
                        [[6.+0.j, 7.+0.j, 4.+0.j],
                         [2.+0.j, 1.+0.j, 9.+0.j]]])
tensor3 = torch.tensor([[[1.+0.j, 6.+0.j, 3.+0.j],
                         [9.+0.j, 6.+0.j, 0.+0.j]],
                        [[0.+0.j, 8.+0.j, 7.+0.j],
                         [3.+0.j, 5.+0.j, 2.+0.j]]])

torch.cat(tensors=(tensor1, tensor2, tensor3))
# tensor([[[2.+0.j, 7.+0.j, 4.+0.j],
#          [8.+0.j, 3.+0.j, 2.+0.j]],
#         [[5.+0.j, 0.+0.j, 8.+0.j],
#          [3.+0.j, 6.+0.j, 1.+0.j]],
#         [[9.+0.j, 4.+0.j, 7.+0.j],
#          [1.+0.j, 0.+0.j, 5.+0.j]],
#         [[6.+0.j, 7.+0.j, 4.+0.j],
#          [2.+0.j, 1.+0.j, 9.+0.j]],
#         [[1.+0.j, 6.+0.j, 3.+0.j],
#          [9.+0.j, 6.+0.j, 0.+0.j]],
#         [[0.+0.j, 8.+0.j, 7.+0.j],
#          [3.+0.j, 5.+0.j, 2.+0.j]]])

# 3D bool tensors.
tensor1 = torch.tensor([[[True, False, True], [True, False, True]],
                        [[False, True, False], [False, True, False]]])
tensor2 = torch.tensor([[[False, True, False], [False, True, False]],
                        [[True, False, True], [True, False, True]]])
tensor3 = torch.tensor([[[True, False, True], [True, False, True]],
                        [[False, True, False], [False, True, False]]])

torch.cat(tensors=(tensor1, tensor2, tensor3))
# tensor([[[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]],
#         [[False, True, False], [False, True, False]],
#         [[True, False, True], [True, False, True]],
#         [[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]]])

# 3D int tensors concatenated together with an empty 1D tensor: the
# documented result below contains only the non-empty data, shown as floats.
tensor1 = torch.tensor([[[0, 1, 2]]])
tensor2 = torch.tensor([])
tensor3 = torch.tensor([[[0, 1, 2]]])

torch.cat(tensors=(tensor1, tensor2, tensor3))
# tensor([[[0., 1., 2.]],
#         [[0., 1., 2.]]])
Top comments (0)