MNIST PyTorch Lightning
This is a "Hello world!" example with PyTorch Lightning.
It trains a convolutional neural network on the MNIST dataset.
Credits:
MNIST dataset, see http://yann.lecun.com/exdb/mnist/
Code adapted from the documentation of the Lightning-AI and PyTorch projects
In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
In [2]:
# Define the model and the training and validation steps.
# The model is a small convolutional neural network: two convolutional
# layers followed by two fully connected layers, with dropout for regularization.
class LitMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        # forward() returns log-probabilities, so use NLLLoss here;
        # CrossEntropyLoss expects raw logits and would apply log_softmax twice
        self.loss_fn = nn.NLLLoss()

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        accuracy = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy)
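As a quick sanity check (not part of the original notebook), an untrained LitMNIST can be run on a random dummy batch to confirm that the forward pass produces one row of 10 log-probabilities per image; the batch size of 4 and the variable names below are arbitrary choices for illustration.
In [ ]:
# Illustrative sanity check: feed a random batch of four 1x28x28 "images"
# through an untrained model and inspect the output shape.
sanity_model = LitMNIST()
sanity_model.eval()
with torch.no_grad():
    dummy = torch.randn(4, 1, 28, 28)   # MNIST images are 1x28x28
    log_probs = sanity_model(dummy)
print(log_probs.shape)                  # expected: torch.Size([4, 10])
print(log_probs.exp().sum(dim=1))       # each row of probabilities sums to ~1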
In [4]:
# Read the data and train the model
num_epochs = 10
num_workers = 4

torch.manual_seed(1)

# Normalize with the standard MNIST mean and standard deviation
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

# Split the 60,000 training images into training and validation sets
dataset = MNIST('../data', train=True, download=True, transform=transform)
train_dataset, val_dataset = random_split(dataset, [55000, 5000])
train_loader = DataLoader(train_dataset, batch_size=64, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=1000, num_workers=num_workers)

model = LitMNIST()
trainer = pl.Trainer(max_epochs=num_epochs)
trainer.fit(model, train_loader, val_loader)
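If you want to keep the trained weights around, Lightning can write them to a checkpoint file and restore them later. The sketch below is not part of the original notebook; the file name mnist_cnn.ckpt is an arbitrary choice.
In [ ]:
# Illustrative sketch: save the trained weights and reload them into a fresh model.
trainer.save_checkpoint("mnist_cnn.ckpt")
restored_model = LitMNIST.load_from_checkpoint("mnist_cnn.ckpt")
restored_model.eval()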
In [5]:
# Evaluate the trained model on the validation dataset
trainer.validate(model, val_loader)
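For completeness, the same validation loop can also be run on the official 10,000-image MNIST test split. This cell is an illustrative addition: it reuses the transform and num_workers defined above, and since the LightningModule only defines a validation_step, trainer.validate is used rather than trainer.test.
In [ ]:
# Illustrative: evaluate on the official MNIST test split (train=False).
test_dataset = MNIST('../data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1000, num_workers=num_workers)
trainer.validate(model, test_loader)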