DEV Community

Gavi Narra


Training a Simple Neural Network in PyTorch and Integrating with Gradio for MNIST Digit Recognition

Introduction

In this article, we will walk through the steps of training a simple neural network on the MNIST dataset using PyTorch and then deploying it with Gradio for interactive predictions. MNIST is a popular machine learning dataset of 70,000 grayscale images of handwritten digits, each 28x28 pixels, split into 60,000 training images and 10,000 test images.

Training a Neural Network with PyTorch

PyTorch is an open-source deep learning framework originally developed by Facebook's AI research group (now Meta AI). It provides the tensor operations, automatic differentiation, and neural network building blocks we need to define and train a model.

Step 1: Import necessary libraries

First, we need to import PyTorch, torchvision (a package with popular datasets, model architectures, and common image transformations), and some other necessary libraries:

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim

Step 2: Load the dataset

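Next, we load the MNIST dataset using torchvision's built-in dataset class. We also apply two transformations: ToTensor converts each image to a tensor with values in [0, 1], and Normalize((0.5,), (0.5,)) subtracts 0.5 and divides by 0.5, which maps the values to [-1, 1]: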

# Convert images to tensors in [0, 1], then normalize to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Shuffling is unnecessary for evaluation data
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
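To see what the two transforms do to pixel values, the arithmetic can be reproduced by hand. The sketch below uses plain tensor operations instead of torchvision, and the helper name `normalize_like_training` is made up for illustration; it assumes 8-bit pixel values in the 0 to 255 range.

```python
import torch

def normalize_like_training(pixels: torch.Tensor) -> torch.Tensor:
    """Mimic ToTensor() followed by Normalize((0.5,), (0.5,)) on 8-bit pixel values."""
    scaled = pixels.float() / 255.0   # ToTensor: 0..255 -> 0.0..1.0
    return (scaled - 0.5) / 0.5       # Normalize: (x - mean) / std -> -1.0..1.0

pixels = torch.tensor([0, 255], dtype=torch.uint8)
print(normalize_like_training(pixels))  # black maps to -1, white maps to 1
```

Any image fed to the trained model later needs the same scaling, otherwise the inputs fall outside the range the network saw during training.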

Step 3: Define the network

We'll define a simple feed-forward neural network with one hidden layer:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net()
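A quick way to sanity-check a network like this is to push a dummy batch through it and count its parameters. The sketch below repeats the class so it runs on its own; the random batch is a stand-in for real MNIST tensors.

```python
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)          # flatten each image to 784 values
        x = torch.relu(self.fc1(x))
        return self.fc2(x)               # raw scores (logits), one per digit

net = Net()
dummy_batch = torch.randn(4, 1, 28, 28)  # 4 fake images shaped like MNIST tensors
logits = net(dummy_batch)
num_params = sum(p.numel() for p in net.parameters())

print(logits.shape)  # torch.Size([4, 10])
print(num_params)    # 101770 = (784*128 + 128) + (128*10 + 10)
```

Each image produces ten scores, one per digit class, and the network has roughly 100,000 trainable parameters.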

Step 4: Define the loss function and optimizer

We'll use CrossEntropyLoss for our loss function and SGD with momentum for our optimizer:

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
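Note that the network returns raw scores with no softmax layer. That works because CrossEntropyLoss applies log-softmax internally and expects class indices as targets. The small check below, with made-up scores for three classes, illustrates the equivalence.

```python
import torch
from torch import nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 0.1]])  # raw scores for 3 made-up classes
target = torch.tensor([0])                # index of the correct class

loss = nn.CrossEntropyLoss()(logits, target)
manual = -F.log_softmax(logits, dim=1)[0, 0]  # negative log-probability of class 0

print(loss.item(), manual.item())  # the two values agree
```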

Step 5: Train the network

Now we're ready to train our network:

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch+1}, loss: {running_loss/len(trainloader)}')

print('Finished Training')
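The training loop only reports the training loss, so it can be useful to also measure accuracy on held-out data. Here is a minimal sketch of such a check; the `evaluate` helper and the `AlwaysZero` stand-in model are hypothetical names introduced for illustration, and the toy loader replaces the real test set so the snippet runs without downloading anything.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Return the fraction of samples in `loader` that `model` classifies correctly."""
    model.eval()                      # switch off training-only behaviour
    correct, total = 0, 0
    with torch.no_grad():             # no gradients needed for evaluation
        for inputs, labels in loader:
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

class AlwaysZero(nn.Module):
    """Stand-in model that predicts class 0 for every input (illustration only)."""
    def forward(self, x):
        scores = torch.zeros(x.size(0), 10)
        scores[:, 0] = 1.0
        return scores

toy_loader = DataLoader(
    TensorDataset(torch.randn(4, 1, 28, 28), torch.tensor([0, 0, 1, 0])),
    batch_size=2,
)
accuracy = evaluate(AlwaysZero(), toy_loader)
print(accuracy)  # 0.75: three of the four toy labels are 0
```

With the objects from this article, the call would be `evaluate(net, testloader)`.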

Deploying with Gradio

Gradio is an open-source library for creating customizable UI components around your ML models. It allows us to demonstrate a model’s functionality in an intuitive manner.

Step 1: Install Gradio

!pip install gradio 

Step 2: Import Gradio and define the prediction function

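import gradio as gr

def predict(image):
    # Gradio passes the drawing as a NumPy array (assumed 8-bit, values 0-255)
    image = torch.from_numpy(image).float().reshape(1, 1, 28, 28)
    # Apply the same scaling used in training: [0, 255] -> [0, 1] -> [-1, 1]
    image = (image / 255.0 - 0.5) / 0.5
    net.eval()
    with torch.no_grad():
        output = net(image)
    _, predicted = torch.max(output, 1)
    return predicted.item()

In the predict function, we convert the input image to a float tensor, reshape it to the (batch, channel, height, width) shape our model expects, and scale the pixel values the same way the training transform did (this assumes Gradio delivers 8-bit values from 0 to 255). Without that scaling, the model would see inputs in a very different range from its training data. We then run the model without gradient tracking and return the predicted digit.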
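Gradio's label output can also display per-class confidences when the function returns a dictionary mapping each label to a probability. The sketch below shows one way to build that dictionary. The function name `predict_with_confidences` is hypothetical, the model is an untrained stand-in with the same layer sizes so the snippet runs on its own, and the input is assumed to be a 28x28 grayscale array with values from 0 to 255.

```python
import numpy as np
import torch
from torch import nn

# Untrained stand-in with the article's layer sizes (assumption: swap in the trained `net`)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))

def predict_with_confidences(image: np.ndarray) -> dict:
    """Map a 28x28 grayscale array (values 0-255) to {digit: probability}."""
    x = torch.from_numpy(image).float().reshape(1, 1, 28, 28)
    x = (x / 255.0 - 0.5) / 0.5                 # same scaling as the training transform
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
    return {str(digit): float(p) for digit, p in enumerate(probs)}

fake_drawing = np.zeros((28, 28), dtype=np.uint8)   # blank canvas as placeholder input
confidences = predict_with_confidences(fake_drawing)
print(max(confidences, key=confidences.get))        # most likely digit according to the stand-in
```

Returning probabilities instead of a single digit makes it easier to see when the model is unsure between, say, a 4 and a 9.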

Step 3: Define the Gradio interface

Now, we define the interface for our model. We'll use an 'Image' input interface and a 'Label' output interface:

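iface = gr.Interface(
    fn=predict,
    # image_mode="L" requests a single-channel (grayscale) array so the reshape in predict works
    inputs=gr.inputs.Image(shape=(28, 28), image_mode="L", invert_colors=True, source="canvas"),
    outputs="label",
    interpretation="default",
)

The 'Image' input interface lets users draw an image with their mouse. We set invert_colors=True because the MNIST dataset consists of white digits on a black background, whereas the Gradio canvas has a white background by default. We also set image_mode="L" so the drawing arrives as a single grayscale channel rather than RGB, which would not fit the 28x28 reshape in predict. Note that gr.inputs and the interpretation argument belong to older Gradio releases; newer versions have changed this API, so you may need to install an older version or adapt the component arguments.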
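The effect of inverting colors is easy to reproduce with NumPy. This sketch assumes inversion means the usual `255 - pixel` for 8-bit images; it is an illustration of the idea, not Gradio's internal code.

```python
import numpy as np

canvas = np.full((28, 28), 255, dtype=np.uint8)  # white background, as drawn by the user
canvas[10:18, 13:15] = 0                         # a short dark vertical stroke

inverted = 255 - canvas                          # flip light and dark

print(canvas[0, 0], canvas[10, 13])      # 255 0  (white background, dark stroke)
print(inverted[0, 0], inverted[10, 13])  # 0 255  (black background, light stroke, like MNIST)
```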

Step 4: Launch the interface

Finally, we launch the interface:

iface.launch() 

With this, you should see an interactive interface where you can draw a digit and see the prediction from your PyTorch model.

Conclusion

In this article, we saw how to train a simple neural network using PyTorch and then deploy it with Gradio for interactive predictions. This combination lets us put a deep learning model behind an interface that is easy to use and whose predictions are easy to inspect.
