Introduction
In this article, we will walk through training a simple neural network on the MNIST dataset using PyTorch and then deploying it with Gradio for interactive predictions. The MNIST dataset is a popular benchmark in machine learning that consists of 70,000 28x28 grayscale images of handwritten digits (60,000 for training and 10,000 for testing).
Training a Neural Network with PyTorch
PyTorch is an open-source deep learning framework developed by Facebook's artificial intelligence research group. It provides a wide range of functionalities for building and training neural networks.
Step 1: Import necessary libraries
First, we need to import PyTorch, torchvision (a package with popular datasets, model architectures, and common image transformations), and some other necessary libraries:
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
Step 2: Load the dataset
Next, we load the MNIST dataset. We'll use torchvision's built-in functionality to do this. We also apply transformations to normalize the data:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
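As an optional sanity check, you can pull one batch from the loader and confirm the tensor shapes before going further (a small diagnostic snippet, not part of the training pipeline):

images, labels = next(iter(trainloader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])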
Step 3: Define the network
We'll define a simple feed-forward neural network with one hidden layer:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # input layer -> hidden layer
        self.fc2 = nn.Linear(128, 10)       # hidden layer -> 10 digit classes

    def forward(self, x):
        x = x.view(-1, 28 * 28)       # flatten each 28x28 image into a vector
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)               # raw logits, no softmax here
        return x

net = Net()
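Before training, it can help to verify that the forward pass works end to end. A minimal check with a random, batch-shaped tensor:

dummy = torch.randn(1, 1, 28, 28)   # fake single-image batch
out = net(dummy)
print(out.shape)  # torch.Size([1, 10]) -- one score per digit class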
Step 4: Define the loss function and optimizer
We'll use CrossEntropyLoss for our loss function and SGD for our optimizer:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
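One detail worth noting: nn.CrossEntropyLoss combines log-softmax and negative log-likelihood internally, which is why our forward method returns raw logits rather than probabilities. A quick illustration with made-up targets:

logits = net(torch.randn(2, 1, 28, 28))  # raw scores, shape (2, 10)
targets = torch.tensor([3, 7])           # hypothetical ground-truth digits
print(criterion(logits, targets))        # scalar loss tensor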
Step 5: Train the network
Now we're ready to train our network:
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()              # reset gradients from the previous step
        outputs = net(inputs)              # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                    # backward pass
        optimizer.step()                   # update weights
        running_loss += loss.item()
    print(f'Epoch {epoch+1}, loss: {running_loss/len(trainloader):.4f}')

print('Finished Training')
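Notice that the testloader we created in Step 2 hasn't been used yet. Here's a minimal sketch of how you could measure accuracy on the held-out test set once training finishes:

correct = 0
total = 0
net.eval()                # switch to evaluation mode
with torch.no_grad():     # gradients aren't needed for evaluation
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy: {100 * correct / total:.2f}%')

Calling net.eval() and torch.no_grad() isn't strictly required for a model this small, but it's the idiomatic pattern and becomes important once you add layers like dropout or batch norm.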
Deploying with Gradio
Gradio is an open-source library for creating customizable UI components around your ML models. It allows us to demonstrate a model’s functionality in an intuitive manner.
Step 1: Install Gradio
!pip install gradio
Step 2: Import Gradio and define the prediction function
import gradio as gr

def predict(image):
    # Reshape to (batch, channels, height, width) and convert to a float tensor
    image = torch.from_numpy(image.reshape(1, 1, 28, 28)).float()
    # Apply the same scaling and normalization used during training
    image = (image / 255.0 - 0.5) / 0.5
    net.eval()
    with torch.no_grad():
        output = net(image)
    _, predicted = torch.max(output, 1)
    return predicted.item()
In the predict function, we reshape the input image to match the model's expected input shape, convert it to a float tensor, apply the same normalization we used during training (the canvas delivers pixel values in the 0-255 range), and run it through the model in evaluation mode to return the predicted digit.
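Before wiring the function into Gradio, you can sanity-check it locally with a random array standing in for a drawn digit (the prediction is meaningless for noise, but this confirms the shapes and dtypes line up):

import numpy as np
fake_digit = np.random.randint(0, 256, (28, 28), dtype=np.uint8)  # stand-in for a canvas sketch
print(predict(fake_digit))  # prints some digit 0-9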
Step 3: Define the Gradio interface
Now, we define the interface for our model. We'll use an 'Image' input interface and a 'Label' output interface (the snippet below uses the Gradio 2.x-style gr.inputs API; newer Gradio releases expose these as components such as gr.Image and gr.Sketchpad):
iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=True, source="canvas"),
    outputs="label",
    interpretation="default"
)
The 'Image' input with source="canvas" lets users draw an image with their mouse, and image_mode='L' delivers it as a single-channel grayscale array, matching the reshape in predict. We set invert_colors=True because the MNIST dataset consists of white digits on a black background, while by default the Gradio canvas has a white background with black strokes.
Step 4: Launch the interface
Finally, we launch the interface:
iface.launch()
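launch() starts a local web server (and renders the interface inline when run in a notebook). If you want a temporary public link to share with others, Gradio supports a share flag:

iface.launch(share=True)  # also generates a temporary public URL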
With this, you should see an interactive interface where you can draw a digit and get a prediction from your PyTorch model.
Conclusion
In this article, we saw how to train a simple neural network using PyTorch and then deploy it with Gradio for interactive predictions. This combination lets us expose the power of a deep learning model through an interface that is easy to use and interpret.