PyTorch_Lightning_MNIST.ipynb

MNIST with PyTorch Lightning

This is a "Hello world!" example with PyTorch Lightning.
It trains a convolutional neural network on the MNIST dataset.

Credits:
MNIST dataset, see http://yann.lecun.com/exdb/mnist/
Code adapted from the documentation of the Lightning-AI and PyTorch projects

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
In [2]:
# Define the model, the optimizer, and the training and validation steps
# The model is a small convolutional neural network: two conv layers followed by two fully connected layers

class LitMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        # Return raw logits: nn.CrossEntropyLoss applies log_softmax internally
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        accuracy = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy)
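nn.CrossEntropyLoss applies log_softmax to its input and then takes the negative log-likelihood, so it expects raw logits. Feeding it log-probabilities (the output of F.log_softmax) is redundant but leaves the loss unchanged, because log_softmax is idempotent: the exponentials of log-probabilities already sum to 1, so a second normalization subtracts log(1) = 0. A minimal pure-Python sketch of this check, with hypothetical helper names and made-up logits for a 4-class example (no PyTorch involved):

```python
import math

def log_softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def cross_entropy(scores, target):
    # What nn.CrossEntropyLoss computes for a single sample:
    # the negative log-probability of the target class after log_softmax.
    return -log_softmax(scores)[target]

logits = [2.0, -1.0, 0.5, 3.0]  # hypothetical raw scores for 4 classes
target = 3

once = cross_entropy(logits, target)                # loss on raw logits
twice = cross_entropy(log_softmax(logits), target)  # loss on log-probabilities
print(once, twice)
```

Predictions are unaffected either way, since argmax picks the same class from logits and from log-probabilities.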
In [4]:
# Load the data, split it into training and validation sets, and train the model

num_epochs = 10
num_workers = 4

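torch.manual_seed(1)  # seeds weight init, dropout, and the random train/validation split
# Convert images to tensors, then normalize with the MNIST pixel mean (0.1307) and std (0.3081)
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])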
dataset = MNIST('../data', train=True, download=True, transform=transform)
train_dataset, val_dataset = random_split(dataset, [55000, 5000])
# Reshuffle the training set at every epoch
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=1000, num_workers=num_workers)

model = LitMNIST()
trainer = pl.Trainer(max_epochs=num_epochs)
trainer.fit(model, train_loader, val_loader)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | conv1    | Conv2d           | 320   
1 | conv2    | Conv2d           | 18.5 K
2 | dropout1 | Dropout          | 0     
3 | dropout2 | Dropout          | 0     
4 | fc1      | Linear           | 1.2 M 
5 | fc2      | Linear           | 1.3 K 
6 | loss_fn  | CrossEntropyLoss | 0     
----------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.800     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=10` reached.
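The numbers in the model summary can be checked by hand. Each unpadded 3x3 convolution with stride 1 shrinks the image side by 2, and the 2x2 max pooling halves it, so 28 → 26 → 24 → 12; with 64 channels this gives the 9216 inputs of fc1 = nn.Linear(9216, 128). The sketch below is plain Python (no PyTorch), and its helper names are hypothetical. It recomputes the flattened size, the per-layer parameter counts, and the estimated size. The size estimate assumes 4 bytes per float32 parameter and 1 MB = 10^6 bytes, which is consistent with the 4.800 MB reported.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output side length of a square conv or pooling window (no dilation).
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(in_ch, out_ch, kernel):
    # Weights (out_ch * in_ch * k * k) plus one bias per output channel.
    return out_ch * in_ch * kernel * kernel + out_ch

def linear_params(in_features, out_features):
    # Weight matrix plus one bias per output feature.
    return in_features * out_features + out_features

side = 28                            # MNIST images are 28x28, one channel
side = conv_out(side, 3)             # conv1 -> 26
side = conv_out(side, 3)             # conv2 -> 24
side = conv_out(side, 2, stride=2)   # max_pool2d(x, 2) -> 12
flat = 64 * side * side              # 64 channels after conv2 -> fc1 input size

params = {
    "conv1": conv_params(1, 32, 3),
    "conv2": conv_params(32, 64, 3),
    "fc1": linear_params(flat, 128),
    "fc2": linear_params(128, 10),
}
total = sum(params.values())
size_mb = total * 4 / 1e6            # assumption: float32 weights, 1 MB = 10**6 bytes

print(flat)  # 9216
print(params)
print(total, round(size_mb, 3))
```

Almost all of the roughly 1.2 M parameters sit in fc1, which is why the max pooling before flattening matters. Without it, fc1 would need 64 * 24 * 24 = 36864 inputs.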
In [5]:
# Evaluate the trained model on the validation dataset

trainer.validate(model, val_loader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric      ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       val_accuracy        │    0.9901999831199646     │
│         val_loss          │    0.03850765526294708    │
└───────────────────────────┴───────────────────────────┘
Out[5]:
[{'val_loss': 0.03850765526294708, 'val_accuracy': 0.9901999831199646}]
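val_loader splits the 5000 validation images into five batches of 1000. validation_step logs a per-batch accuracy, and self.log reduces values logged during validation over the epoch with a mean by default (reduce_fx="mean"). Because every batch has the same size, that mean equals the accuracy over all 5000 images, so the reported 0.9902 corresponds to 4951 correct predictions. A small pure-Python sketch, using made-up per-batch counts (chosen for illustration, not taken from this run):

```python
# Hypothetical per-batch results: (correct predictions, batch size).
# Illustrative numbers only, not the counts from the run above.
batches = [(991, 1000), (989, 1000), (992, 1000), (990, 1000), (989, 1000)]

# Plain mean of the per-batch accuracies.
mean_of_batches = sum(c / n for c, n in batches) / len(batches)

# Accuracy over the whole validation set.
overall = sum(c for c, _ in batches) / sum(n for _, n in batches)

print(mean_of_batches, overall)

# The reported val_accuracy converted back to a count of correct images.
reported = 0.9901999831199646
print(round(reported * 5000))  # 4951
```

If the last batch were smaller (for example, with a validation set that is not a multiple of the batch size), an unweighted mean of batch accuracies could differ slightly from the overall accuracy.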