`take()` can get a 0D or more D tensor of zero or more elements from a 0D or more D tensor of zero or more elements, using a 0D or more D tensor of zero or more indices into the flattened (1D) input, as shown below:
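*Memos:
- `take()` can be used with `torch` or a tensor.
- The 1st argument (`input`) with `torch`, or the tensor itself when used as a method (Required-Type: `tensor` of `int`, `float`, `complex` or `bool`).
- The 2nd argument with `torch`, or the 1st argument with a tensor, is `index` (Required-Type: `tensor` of `int`):
  *It decides the size (shape) of the returned tensor.
  *Negative indices count back from the end of the flattened input (e.g. `-7` is `10 - 7 = 3` for a 10-element input).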
# torch.take() reads `input` as if it were flattened to 1D in row-major
# order ([9, 5, 0, 6, 2, 7, 1, 3, 4, 8] for the tensor below), so every
# index is a position in that flat view, and the result always has the
# shape of `index`. Negative indices count back from the end of the flat
# view: -7 is 10 - 7 = 3, so it picks the same element as 3.
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

torch.take(input=my_tensor, index=torch.tensor(3))
my_tensor.take(index=torch.tensor(3))
torch.take(input=my_tensor, index=torch.tensor(-7))
# tensor(6)

torch.take(input=my_tensor, index=torch.tensor([3, 0, 7, 4]))
torch.take(input=my_tensor, index=torch.tensor([-7, -10, -3, -6]))
# tensor([6, 9, 3, 2])

torch.take(input=my_tensor, index=torch.tensor([[3, 0], [7, 4]]))
torch.take(input=my_tensor, index=torch.tensor([[-7, -10], [-3, -6]]))
# tensor([[6, 9], [3, 2]])

torch.take(input=my_tensor,
           index=torch.tensor([[[3, 0], [7, 4]], [[8, 2], [3, 5]]]))
torch.take(input=my_tensor,
           index=torch.tensor([[[-7, -10], [-3, -6]], [[-2, -8], [-7, -5]]]))
# tensor([[[6, 9], [3, 2]],
#         [[4, 0], [6, 7]]])

my_tensor = torch.tensor([[9., 5., 0., 6., 2.], [7., 1., 3., 4., 8.]])

torch.take(input=my_tensor,
           index=torch.tensor([[[3, 0], [7, 4]], [[8, 2], [3, 5]]]))
# tensor([[[6., 9.], [3., 2.]],
#         [[4., 0.], [6., 7.]]])

my_tensor = torch.tensor([[9.+0.j, 5.+0.j, 0.+0.j, 6.+0.j, 2.+0.j],
                          [7.+0.j, 1.+0.j, 3.+0.j, 4.+0.j, 8.+0.j]])

torch.take(input=my_tensor,
           index=torch.tensor([[[3, 0], [7, 4]], [[8, 2], [3, 5]]]))
# tensor([[[6.+0.j, 9.+0.j], [3.+0.j, 2.+0.j]],
#         [[4.+0.j, 0.+0.j], [6.+0.j, 7.+0.j]]])

my_tensor = torch.tensor([[True, False, True, False, True],
                          [False, True, False, True, False]])

torch.take(input=my_tensor,
           index=torch.tensor([[[3, 0], [7, 4]], [[8, 2], [3, 5]]]))
# tensor([[[False, True], [False, True]],
#         [[True, True], [False, False]]])
take_along_dim() can get the 1D or more D tensor of zero or more elements using the 0D or more D tensor of zero or more indices from the 0D or more D tensor of zero or more elements as shown below:
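*Memos:
- `take_along_dim()` can be used with `torch` or a tensor.
- The 1st argument (`input`) with `torch`, or the tensor itself when used as a method (Required-Type: `tensor` of `int`, `float`, `complex` or `bool`).
- The 2nd argument with `torch`, or the 1st argument with a tensor, is `indices` (Required-Type: `tensor` of `int`).
- The 3rd argument with `torch`, or the 2nd argument with a tensor, is `dim` (Optional-Type: `int`):
  *Memos:
  - Not setting `dim` flattens both tensors and returns a 1D tensor.
  - If `dim` is set, both tensors must have the same number of dimensions, and the returned tensor has that number of dimensions too. Along the other dimensions, `indices` may be broadcast (e.g. a `(4, 1)` `indices` with a `(2, 5)` input gives a `(4, 5)` result).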
- There is `out` argument with `torch` (Optional-Default: `None`-Type: `tensor`):
  *Memos:
  - `out=` must be used.
  - My post explains `out` argument.
# torch.take_along_dim() picks elements of `input` at `indices`.
# - Without `dim`, both tensors are flattened to 1D first (row-major order),
#   so the result is always 1D.
# - With `dim`, `indices` must have as many dimensions as `input`; its
#   size-1 dimensions other than `dim` are broadcast, and each index
#   selects along `dim`.
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

torch.take_along_dim(input=my_tensor, indices=torch.tensor(3))
my_tensor.take_along_dim(indices=torch.tensor(3))
# torch.gather() has no flattening mode. It takes `dim` and `index` (not
# `indices`), and `index` must have as many dimensions as `input`, so the
# equivalent call flattens `input` explicitly.
torch.gather(input=my_tensor.flatten(), dim=0, index=torch.tensor([3]))
# tensor([6])

torch.take_along_dim(input=my_tensor, indices=torch.tensor([3, 0, 7, 4]))
torch.take_along_dim(input=my_tensor, indices=torch.tensor([[3, 0], [7, 4]]))
# tensor([6, 9, 3, 2])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[[3, 0], [7, 4]], [[8, 2], [3, 5]]]))
# tensor([6, 9, 3, 2, 4, 0, 6, 7])

# `indices` has shape (4, 1): its size-1 column dimension is broadcast to
# the input's 5 columns, so each index along dim 0 picks a whole row.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0], [1], [0], [1]]), dim=0)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0], [1], [0], [1]]), dim=-2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8]])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0, 0, 0, 0, 0],
                                           [1, 1, 1, 1, 1],
                                           [0, 1, 0, 1, 0],
                                           [1, 0, 1, 0, 1]]), dim=0)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0, 0, 0, 0, 0],
                                           [1, 1, 1, 1, 1],
                                           [0, 1, 0, 1, 0],
                                           [1, 0, 1, 0, 1]]), dim=-2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 1, 0, 4, 2],
#         [7, 5, 3, 6, 8]])

# `indices` has shape (1, 3): its size-1 row dimension is broadcast to the
# input's 2 rows, so both rows use the column indices 3, 0, 4.
torch.take_along_dim(input=my_tensor, indices=torch.tensor([[3, 0, 4]]), dim=1)
torch.take_along_dim(input=my_tensor, indices=torch.tensor([[3, 0, 4]]), dim=-1)
# tensor([[6, 9, 2], [4, 7, 8]])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=1)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=-1)
# tensor([[6, 9, 2], [8, 1, 3]])

my_tensor = torch.tensor([[9., 5., 0., 6., 2.], [7., 1., 3., 4., 8.]])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=1)
# tensor([[6., 9., 2.], [8., 1., 3.]])

my_tensor = torch.tensor([[9.+0.j, 5.+0.j, 0.+0.j, 6.+0.j, 2.+0.j],
                          [7.+0.j, 1.+0.j, 3.+0.j, 4.+0.j, 8.+0.j]])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=1)
# tensor([[6.+0.j, 9.+0.j, 2.+0.j],
#         [8.+0.j, 1.+0.j, 3.+0.j]])

my_tensor = torch.tensor([[True, False, True, False, True],
                          [False, True, False, True, False]])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]), dim=1)
# tensor([[False, True, True],
#         [False, True, False]])
Top comments (0)