*Memos:
`Unflatten()` can add zero or more dimensions to a 1D or higher-dimensional tensor of zero or more elements, returning a 1D or higher-dimensional tensor of zero or more elements, as shown below:
*Memos:
- The 1st argument for initialization is `dim` (Required-Type: `int`).
- The 2nd argument for initialization is `unflattened_size` (Required-Type: `tuple` or `list` of `int`).
- The 1st argument is `input` (Required-Type: `tensor` of `int`, `float`, `complex` or `bool`).
*`-1` in `unflattened_size` infers and adjusts the size.
- The difference between `Unflatten()` and `unflatten()` is:
  - `Unflatten()` has the `unflattened_size` argument, which is identical to the `sizes` argument of `unflatten()`.
  - Basically, `Unflatten()` is used to define a model, while `unflatten()` is not used to define a model.
-
"""Examples of ``torch.nn.Unflatten`` on 1D and 2D tensors of several dtypes.

In each group below, every ``unflatten = nn.Unflatten(...)`` line is an
equivalent alternative: any single one of them produces the result shown in
the comment under the ``print`` call. When the file is run as a script, only
the last assignment in each group actually takes effect.
"""
import torch
from torch import nn

# Both `dim` and `unflattened_size` are required constructor arguments:
# calling nn.Unflatten() with no arguments raises TypeError, so pass them.
unflatten = nn.Unflatten(dim=0, unflattened_size=(6,))
print(unflatten)
# Unflatten(dim=0, unflattened_size=(6,))
print(unflatten.dim)
# 0
print(unflatten.unflattened_size)
# (6,)

# ---- 1D input, shape (6,) ----
my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])

unflatten = nn.Unflatten(dim=0, unflattened_size=(6,))
unflatten = nn.Unflatten(dim=0, unflattened_size=(-1,))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(6,))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(-1,))
print(unflatten(input=my_tensor))
# tensor([7, 1, -8, 3, -6, 0])

unflatten = nn.Unflatten(dim=0, unflattened_size=(1, 6))
unflatten = nn.Unflatten(dim=0, unflattened_size=(-1, 6))
unflatten = nn.Unflatten(dim=0, unflattened_size=(1, -1))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(1, 6))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(-1, 6))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(1, -1))
print(unflatten(input=my_tensor))
# tensor([[7, 1, -8, 3, -6, 0]])

unflatten = nn.Unflatten(dim=0, unflattened_size=(2, 3))
unflatten = nn.Unflatten(dim=0, unflattened_size=(2, -1))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(2, 3))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(2, -1))
print(unflatten(input=my_tensor))
# tensor([[7, 1, -8], [3, -6, 0]])

unflatten = nn.Unflatten(dim=0, unflattened_size=(3, 2))
unflatten = nn.Unflatten(dim=0, unflattened_size=(3, -1))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(3, 2))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(3, -1))
print(unflatten(input=my_tensor))
# tensor([[7, 1], [-8, 3], [-6, 0]])

unflatten = nn.Unflatten(dim=0, unflattened_size=(6, 1))
unflatten = nn.Unflatten(dim=0, unflattened_size=(6, -1))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(6, 1))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(6, -1))
print(unflatten(input=my_tensor))
# tensor([[7], [1], [-8], [3], [-6], [0]])

unflatten = nn.Unflatten(dim=0, unflattened_size=(1, 2, 3))
unflatten = nn.Unflatten(dim=0, unflattened_size=(-1, 2, 3))
unflatten = nn.Unflatten(dim=0, unflattened_size=(1, -1, 3))
unflatten = nn.Unflatten(dim=0, unflattened_size=(1, 2, -1))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(1, 2, 3))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(-1, 2, 3))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(1, -1, 3))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(1, 2, -1))
print(unflatten(input=my_tensor))
# tensor([[[7, 1, -8], [3, -6, 0]]])
# etc.

# ---- 2D input, shape (2, 3) ----
my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])

unflatten = nn.Unflatten(dim=0, unflattened_size=(2,))
unflatten = nn.Unflatten(dim=0, unflattened_size=(-1,))
unflatten = nn.Unflatten(dim=1, unflattened_size=(3,))
unflatten = nn.Unflatten(dim=1, unflattened_size=(-1,))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(3,))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(-1,))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(2,))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(-1,))
print(unflatten(input=my_tensor))
# tensor([[7, 1, -8], [3, -6, 0]])

unflatten = nn.Unflatten(dim=0, unflattened_size=(1, 2))
unflatten = nn.Unflatten(dim=0, unflattened_size=(-1, 2))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(1, 2))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(-1, 2))
print(unflatten(input=my_tensor))
# tensor([[[7, 1, -8], [3, -6, 0]]])

unflatten = nn.Unflatten(dim=0, unflattened_size=(2, 1))
unflatten = nn.Unflatten(dim=0, unflattened_size=(2, -1))
unflatten = nn.Unflatten(dim=1, unflattened_size=(1, 3))
unflatten = nn.Unflatten(dim=1, unflattened_size=(-1, 3))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(1, 3))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(-1, 3))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(2, 1))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(2, -1))
print(unflatten(input=my_tensor))
# tensor([[[7, 1, -8]], [[3, -6, 0]]])

unflatten = nn.Unflatten(dim=1, unflattened_size=(3, 1))
unflatten = nn.Unflatten(dim=1, unflattened_size=(3, -1))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(3, 1))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(3, -1))
print(unflatten(input=my_tensor))
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

unflatten = nn.Unflatten(dim=0, unflattened_size=(1, 1, 2))
unflatten = nn.Unflatten(dim=0, unflattened_size=(-1, 1, 2))
unflatten = nn.Unflatten(dim=0, unflattened_size=(1, -1, 2))
unflatten = nn.Unflatten(dim=0, unflattened_size=(1, 1, -1))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(1, 1, 2))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(-1, 1, 2))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(1, -1, 2))
unflatten = nn.Unflatten(dim=-2, unflattened_size=(1, 1, -1))
print(unflatten(input=my_tensor))
# tensor([[[[7, 1, -8], [3, -6, 0]]]])

unflatten = nn.Unflatten(dim=1, unflattened_size=(1, 1, 3))
unflatten = nn.Unflatten(dim=1, unflattened_size=(-1, 1, 3))
unflatten = nn.Unflatten(dim=1, unflattened_size=(1, -1, 3))
unflatten = nn.Unflatten(dim=1, unflattened_size=(1, 1, -1))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(1, 1, 3))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(-1, 1, 3))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(1, -1, 3))
unflatten = nn.Unflatten(dim=-1, unflattened_size=(1, 1, -1))
print(unflatten(input=my_tensor))
# tensor([[[[7, 1, -8]]], [[[3, -6, 0]]]])

# ---- Other element types: float, complex, bool ----
my_tensor = torch.tensor([[7., 1., -8.], [3., -6., 0.]])
unflatten = nn.Unflatten(dim=0, unflattened_size=(2,))
print(unflatten(input=my_tensor))
# tensor([[7., 1., -8.], [3., -6., 0.]])

my_tensor = torch.tensor([[7.+0.j, 1.+0.j, -8.+0.j],
                          [3.+0.j, -6.+0.j, 0.+0.j]])
unflatten = nn.Unflatten(dim=0, unflattened_size=(2,))
print(unflatten(input=my_tensor))
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])

my_tensor = torch.tensor([[True, False, True], [False, True, False]])
unflatten = nn.Unflatten(dim=0, unflattened_size=(2,))
print(unflatten(input=my_tensor))
# tensor([[True, False, True], [False, True, False]])
Top comments (0)