DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

MovingMNIST in PyTorch

Buy Me a Coffee

*Memos:

MovingMNIST() can use Moving MNIST dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Optional-Default:None-Type:str): *Memos:
    • None, "train" or "test" can be set to it.
    • If it's None, all 20 frames(images) of each video are returned, ignoring split_ratio.
  • The 3rd argument is split_ratio(Optional-Default:10-Type:int): *Memos:
    • If split is "train", data[:, :split_ratio] is returned.
    • If split is "test", data[:, split_ratio:] is returned.
    • If split is None, it's ignored.
  • The 5th argument is transform(Optional-Default:None-Type:callable).
  • The 4th argument is download(Optional-Default:False-Type:bool): *Memos:
    • If it's True, the dataset is downloaded from the internet to root.
    • If it's True and the dataset is already downloaded, it isn't downloaded again, and no extraction is needed because the dataset is a single .npy file.
    • If it's True and the dataset is already downloaded, nothing happens.
    • It should be False if the dataset is already downloaded because it's faster.
    • You can manually download and extract the dataset(mnist_test_seq.npy) from here to data/MovingMNIST/.
from torchvision.datasets import MovingMNIST

# Only `root` is passed here; every other argument keeps its default.
all_data = MovingMNIST(
    root="data"
)

# The same call as above with each default written out explicitly.
all_data = MovingMNIST(
    root="data",
    split=None,
    split_ratio=10,
    download=False,
    transform=None
)

train_data = MovingMNIST(
    root="data",
    split="train"
)

test_data = MovingMNIST(
    root="data",
    split="test"
)

# The split is made along the frames, so the number of videos is the same.
len(all_data), len(train_data), len(test_data)
# (10000, 10000, 10000)

# Number of frames per video for each split.
len(all_data[0]), len(train_data[0]), len(test_data[0])
# (20, 10, 10)

all_data
# Dataset MovingMNIST
#     Number of datapoints: 10000
#     Root location: data

all_data.root
# 'data'

print(all_data.split)
# None

all_data.split_ratio
# 10

# `download` is a bound method on the instance, not the boolean argument.
all_data.download
# <bound method MovingMNIST.download of Dataset MovingMNIST
#     Number of datapoints: 10000
#     Root location: data>

print(all_data.transform)
# None

all_data[0]
# tensor([[[[0, 0, 0, ..., 0, 0, 0],
#           ...,
#           [0, 0, 0, ..., 0, 0, 0]]],
#         ...
#         [[[0, 0, 0, ..., 0, 0, 0],
#           ...,
#           [0, 0, 0, ..., 0, 0, 0]]]], dtype=torch.uint8)

all_data[1]
# tensor([[[[0, 0, 0, ..., 0, 0, 0],
#           ...,
#           [0, 0, 0, ..., 0, 0, 0]]],
#         ...
#         [[[0, 0, 0, ..., 0, 0, 0],
#           ...,
#           [0, 0, 0, ..., 0, 0, 0]]]], dtype=torch.uint8)

all_data[2]
# tensor([[[[0, 0, 0, ..., 0, 0, 0],
#           ...,
#           [0, 0, 0, ..., 0, 0, 0]]],
#         ...
#         [[[0, 0, 0, ..., 0, 0, 0],
#           ...,
#           [0, 0, 0, ..., 0, 0, 0]]]], dtype=torch.uint8)

import matplotlib.pyplot as plt


def show_images(data, labs):
    """Show the first frame of each video in `data` side by side.

    Args:
        data: Iterable of video tensors; one subplot is drawn per video.
        labs: Iterable of subplot titles, paired with `data` by position.
    """
    plt.figure(figsize=(8, 4))
    for i, (vid, lab) in enumerate(iterable=zip(data, labs), start=1):
        # Subplot indices are 1-based, hence `start=1`.
        plt.subplot(1, 3, i)
        # squeeze() drops the size-1 dimensions (presumably the channel
        # axis) so that [0] selects the first frame as a 2D image.
        plt.imshow(X=vid.squeeze()[0])
        plt.title(label=lab)
    plt.tight_layout()
    plt.show()


videos = (all_data[0], train_data[0], test_data[0])
titles = ("all_data", "train_data", "test_data")
show_images(data=videos, labs=titles)
Enter fullscreen mode Exit fullscreen mode

Image description

from torchvision.datasets import MovingMNIST

all_data = MovingMNIST(
    root="data",
    split=None
)

train_data = MovingMNIST(
    root="data",
    split="train"
)

test_data = MovingMNIST(
    root="data",
    split="test"
)

import matplotlib.pyplot as plt


def show_images(data, main_title=None):
    """Show every frame of one video on a 4x5 grid of subplots.

    Args:
        data: A single video tensor; each frame gets its own subplot.
        main_title: Optional title drawn above the whole figure.
    """
    plt.figure(figsize=(12, 10))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # squeeze() drops the size-1 dimensions (presumably the channel axis)
    # so that iterating yields one 2D frame at a time.
    for i, im in enumerate(iterable=data.squeeze(), start=1):
        # Subplot indices are 1-based; the frame number doubles as the title.
        plt.subplot(4, 5, i)
        plt.title(label=i)
        plt.imshow(X=im)
    plt.tight_layout()
    plt.show()


show_images(data=all_data[0], main_title="all_data")
show_images(data=train_data[0], main_title="train_data")
show_images(data=test_data[0], main_title="test_data")
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image description

from torchvision.datasets import MovingMNIST

all_data = MovingMNIST(
    root="data",
    split=None
)

train_data = MovingMNIST(
    root="data",
    split="train"
)

test_data = MovingMNIST(
    root="data",
    split="test"
)

import matplotlib.pyplot as plt


def show_images(data, main_title=None):
    """Show the first frame of the first five videos of a dataset.

    Args:
        data: A dataset of videos; only its first five items are read.
        main_title: Optional title drawn above the whole figure.
    """
    plt.figure(figsize=(10, 8))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # zip() stops at the shorter input, so range(1, 6) limits the loop to
    # five videos and also supplies the 1-based subplot index.
    for i, vid in zip(range(1, 6), data):
        # Only the first row of the 4x5 grid is filled.
        plt.subplot(4, 5, i)
        plt.title(label=i)
        # squeeze() drops the size-1 dimensions (presumably the channel
        # axis) so that [0] selects the first frame as a 2D image.
        plt.imshow(X=vid.squeeze()[0])
    plt.tight_layout()
    plt.show()


show_images(data=all_data, main_title="all_data")
show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")
Enter fullscreen mode Exit fullscreen mode

Image description

from torchvision.datasets import MovingMNIST
import matplotlib.animation as animation

all_data = MovingMNIST(
    root="data"
)

import matplotlib.pyplot as plt
from IPython.display import HTML

figure, axis = plt.subplots()

# ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ `ArtistAnimation()` ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
# One single-artist list per frame of the first video; `artists` takes a
# list of such per-frame artist lists.
ims = [[axis.imshow(X=im)] for im in all_data[0].squeeze()]
ani = animation.ArtistAnimation(fig=figure, artists=ims, interval=100)
# ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ `ArtistAnimation()` ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑

# ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ `FuncAnimation()` ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
# def animate(i):
#     axis.imshow(X=all_data[0].squeeze()[i])
#
# ani = animation.FuncAnimation(fig=figure, func=animate,
#                               frames=20, interval=100)
# ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ `FuncAnimation()` ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑

# ani.save('result.gif') # Save the animation as a `.gif` file

plt.ioff() # Hide a useless image

# ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ Show animation ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
HTML(ani.to_jshtml()) # Animation operator
# HTML(ani.to_html5_video()) # Animation video
# ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ Show animation ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑

# ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ Show animation ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
# plt.rcParams["animation.html"] = "jshtml" # Animation operator
# plt.rcParams["animation.html"] = "html5" # Animation video
# ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ Show animation ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

from torchvision.datasets import MovingMNIST
from ipywidgets import interact, IntSlider

all_data = MovingMNIST(
    root="data"
)

import matplotlib.pyplot as plt
from IPython.display import HTML


def func(i):
    """Draw frame `i` of the first video; called by `interact` with the slider value."""
    plt.imshow(X=all_data[0].squeeze()[i])


# A (min, max, step) tuple makes `interact` build an integer slider for `i`.
interact(func, i=(0, 19, 1))
# interact(func, i=IntSlider(min=0, max=19, step=1, value=0))
# ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ Set the start value ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
plt.show()
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Top comments (0)