{ "cells": [ { "cell_type": "markdown", "id": "f3c939e6", "metadata": {}, "source": [ "# MNIST Pytorch Lightning\n", "This is a \"Hello world!\" example with PyTorch Lightning. \n", "It trains a convolutional neural network on the MNIST dataset. \n", "\n", "Credits: \n", "MNIST dataset, see http://yann.lecun.com/exdb/mnist/ \n", "Code adapted from the documentation of the Lightning-AI and PyTorch projects\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "bf6f202b", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, random_split\n", "from torchvision.datasets import MNIST\n", "from torchvision import transforms\n", "import pytorch_lightning as pl" ] }, { "cell_type": "code", "execution_count": 2, "id": "d623966a", "metadata": {}, "outputs": [], "source": [ "# Define the model and the training and test steps\n", "# The model uses convolutional neural network layers\n", "\n", "class LitMNIST(pl.LightningModule):\n", " def __init__(self):\n", " super().__init__()\n", " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", " self.dropout1 = nn.Dropout(0.25)\n", " self.dropout2 = nn.Dropout(0.5)\n", " self.fc1 = nn.Linear(9216, 128)\n", " self.fc2 = nn.Linear(128, 10)\n", " self.loss_fn = nn.CrossEntropyLoss()\n", "\n", " def forward(self, x):\n", " x = self.conv1(x)\n", " x = F.relu(x)\n", " x = self.conv2(x)\n", " x = F.relu(x)\n", " x = F.max_pool2d(x, 2)\n", " x = self.dropout1(x)\n", " x = torch.flatten(x, 1)\n", " x = self.fc1(x)\n", " x = F.relu(x)\n", " x = self.dropout2(x)\n", " x = self.fc2(x)\n", " output = F.log_softmax(x, dim=1)\n", " return output\n", " \n", " def configure_optimizers(self):\n", " return torch.optim.Adam(self.parameters(), lr=1e-3)\n", "\n", " def training_step(self, batch, batch_idx):\n", " x, y = batch\n", " y_hat = self(x)\n", " loss = self.loss_fn(y_hat, y)\n", " self.log('train_loss', loss)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " x, y = batch\n", " y_hat = self(x)\n", " loss = self.loss_fn(y_hat, y)\n", " preds = torch.argmax(y_hat, dim=1)\n", " accuracy = (preds == y).float().mean()\n", " self.log('val_loss', loss)\n", " self.log('val_accuracy', accuracy)\n", " " ] }, { "cell_type": "code", "execution_count": 4, "id": "4159c154", "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "----------------------------------------------\n", "0 | conv1 | Conv2d | 320 \n", "1 | conv2 | Conv2d | 18.5 K\n", "2 | dropout1 | Dropout | 0 \n", "3 | dropout2 | Dropout | 0 \n", "4 | fc1 | Linear | 1.2 M \n", "5 | fc2 | Linear | 1.3 K \n", "6 | loss_fn | CrossEntropyLoss | 0 \n", "----------------------------------------------\n", "1.2 M Trainable params\n", "0 Non-trainable params\n", "1.2 M Total params\n", "4.800 Total estimated model params size (MB)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3c01ff29167e4831aad4aa8768d06776", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=10` reached.\n" ] } ], "source": [ "# Code for reading the data, training the model\n", "\n", "num_epochs = 10\n", "num_workers = 4\n", "\n", "torch.manual_seed(1)\n", "transform = transforms.Compose([transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,))])\n", "dataset = MNIST('../data', train=True, download=True, transform=transform)\n", "train_dataset, val_dataset = random_split(dataset, [55000, 5000])\n", "train_loader = DataLoader(train_dataset, batch_size=64, num_workers=num_workers)\n", "val_loader = DataLoader(val_dataset, batch_size=1000, num_workers=num_workers)\n", " \n", "model = LitMNIST()\n", "trainer = pl.Trainer(max_epochs=num_epochs)\n", "trainer.fit(model, train_loader, val_loader)\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "7fb9d12c", "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8d18a358a2564b52b71eb0d6eab1444e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃      Validate metric             DataLoader 0        ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│       val_accuracy            0.9901999831199646     │\n",
       "│         val_loss              0.03850765526294708    │\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│\u001b[36m \u001b[0m\u001b[36m val_accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9901999831199646 \u001b[0m\u001b[35m \u001b[0m│\n", "│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.03850765526294708 \u001b[0m\u001b[35m \u001b[0m│\n", "└───────────────────────────┴───────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[{'val_loss': 0.03850765526294708, 'val_accuracy': 0.9901999831199646}]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate the trained model on the validation dataset\n", "\n", "trainer.validate(model, val_loader)" ] }, { "cell_type": "code", "execution_count": null, "id": "7e4d3a1b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }