{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [
  "# Classifier for the MNIST dataset using TensorFlow (tf.keras)\n",
  "This is a \"Hello world!\" example with TensorFlow.\n",
  "It trains a convolutional neural network on the MNIST dataset.\n",
  "\n",
  "Credits:\n",
  "- MNIST dataset, see http://yann.lecun.com/exdb/mnist/\n",
  "- Code adapted from the documentation by the TensorFlow and Keras team, https://www.tensorflow.org/"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import tensorflow as tf\n",
  "tf.__version__"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# This notebook is intended to be run with GPU resources\n",
  "tf.config.list_physical_devices('GPU')"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Seed Python's `random`, NumPy and TensorFlow in one call so that weight initialisation,\n",
  "# dropout masks and the batch shuffling done by model.fit are repeatable across runs.\n",
  "# Some GPU kernels are not deterministic, so small run-to-run differences may remain.\n",
  "SEED = 42\n",
  "tf.keras.utils.set_random_seed(SEED)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Load and preprocess the data" ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
  "# Load MNIST: uint8 images of shape (n, 28, 28) with integer labels\n",
  "mnist = tf.keras.datasets.mnist\n",
  "\n",
  "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
  "\n",
  "# Scale pixels from [0, 255] to [0, 1]. Casting to float32 first avoids the float64\n",
  "# arrays that dividing the uint8 data by 255.0 would produce (twice the memory, and\n",
  "# Keras computes in float32 anyway).\n",
  "x_train = x_train.astype(\"float32\") / 255.0\n",
  "x_test = x_test.astype(\"float32\") / 255.0\n",
  "\n",
  "# Add the trailing channel axis expected by Conv2D; -1 infers the number of samples\n",
  "# instead of hard-coding 60000 / 10000.\n",
  "x_train = x_train.reshape(-1, 28, 28, 1)\n",
  "x_test = x_test.reshape(-1, 28, 28, 1)\n",
  "\n",
  "assert x_train.shape[1:] == x_test.shape[1:] == (28, 28, 1)\n",
  "assert len(x_train) == len(y_train) and len(x_test) == len(y_test)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the model" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# A small convolutional network: two conv + max-pool stages, dropout for regularisation,\n",
  "# then a dense classifier with one softmax output per digit class.\n",
  "# The input shape is declared with an explicit Input object, which newer Keras versions\n",
  "# recommend over passing `input_shape=` to the first layer.\n",
  "\n",
  "model = tf.keras.Sequential([\n",
  "    tf.keras.Input(shape=(28, 28, 1)),\n",
  "    tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
  "    tf.keras.layers.MaxPool2D((2, 2)),\n",
  "    tf.keras.layers.Conv2D(16, 3, activation='relu'),\n",
  "    tf.keras.layers.MaxPool2D((2, 2)),\n",
  "    tf.keras.layers.Dropout(0.25),\n",
  "    tf.keras.layers.Flatten(),\n",
  "    tf.keras.layers.Dense(100, activation='relu'),\n",
  "    tf.keras.layers.Dense(10, activation='softmax'),\n",
  "])\n",
  "\n",
  "model.summary()"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# The labels are integer class ids, hence the *sparse* categorical cross-entropy;\n",
  "# the last layer already applies softmax, so the loss receives probabilities (from_logits=False).\n",
  "optimizer = tf.keras.optimizers.Adam()\n",
  "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)\n",
  "\n",
  "model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model" ] },
 { "cell_type":
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Train on the training data only. Keras holds out the last 10% of x_train/y_train\n",
  "# (taken before shuffling) as the validation set, so the test set stays unseen until\n",
  "# the final evaluation below instead of doubling as validation data.\n",
  "# model.fit continues from the current weights: re-run the cells that build and compile\n",
  "# the model before re-running this one if you want to train from scratch.\n",
  "\n",
  "history = model.fit(x_train, y_train,\n",
  "                    batch_size=64,\n",
  "                    epochs=10,\n",
  "                    validation_split=0.1)\n"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Trained model performance\n" ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name":
"stdout", "output_type": "stream", "text": [ "313/313 [==============================] - 1s 3ms/step - loss: 0.0263 - accuracy: 0.9927\n", "The accuracy on the test dataset is 0.9927\n" ] } ], "source": [
  "# Compute the loss and accuracy on the held-out test data\n",
  "\n",
  "test_loss, test_accuracy = model.evaluate(x_test, y_test)\n",
  "\n",
  "# A fixed-width format always shows 4 decimals (round() would print e.g. 0.99 for 0.9900)\n",
  "print(f\"The accuracy on the test dataset is {test_accuracy:.4f}\")"
 ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Zoom in on a test case" ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/1 [==============================] - 0s 216ms/step\n" ] }, { "data": { "text/plain": [ "array([[5.1579474e-12, 2.2326599e-08, 4.4545445e-08, 1.8066143e-09,\n", " 2.6469557e-10, 9.0394258e-14, 1.3470490e-15, 9.9999988e-01,\n", " 1.0610262e-11, 8.4223997e-09]], dtype=float32)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "# Predict the probability of each digit for the first test sample.\n",
  "# Slicing with [:1] keeps the batch axis, giving the (1, 28, 28, 1) batch the model expects.\n",
  "\n",
  "probabilities = model.predict(x_test[:1])\n",
  "probabilities"
 ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "# The most likely digit is the index of the highest probability; reuse the prediction\n",
  "# above instead of running the model a second time.\n",
  "\n",
  "np.argmax(probabilities)"
 ] },
 { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAANh0lEQVR4nO3df6zddX3H8dfL/sJeYFKwtSuVKqKxOsHlCppuSw3DAYYUo2w0GekSZskGCSxmG2ExkmxxjIiETWdSR2clCFOBQLRzksaNkLHKhZRSKFuRdVh71wvUrUXgtqXv/XG/LJdyz+dezvd7zve07+cjuTnnfN/ne77vfHtf/X7v+XzP+TgiBODY95a2GwDQH4QdSIKwA0kQdiAJwg4kMbufG5vreXGchvq5SSCVV/QLHYhxT1WrFXbb50u6RdIsSX8XETeUnn+chnSOz62zSQAFm2NTx1rXp/G2Z0n6qqQLJC2XtNr28m5fD0Bv1fmb/WxJT0fEMxFxQNKdklY10xaAptUJ+xJJP530eFe17HVsr7U9YnvkoMZrbA5AHXXCPtWbAG+49jYi1kXEcEQMz9G8GpsDUEedsO+StHTS41Ml7a7XDoBeqRP2hyWdYftdtudKulTSfc20BaBpXQ+9RcQh21dJ+idNDL2tj4gnGusMQKNqjbNHxEZJGxvqBUAPcbkskARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlaUzbb3ilpv6RXJR2KiOEmmgLQvFphr3w8Ip5v4HUA9BCn8UASdcMekn5o+xHba6d6gu21tkdsjxzUeM3NAehW3dP4FRGx2/ZCSffbfioiHpj8hIhYJ2mdJJ3oBVFzewC6VOvIHhG7q9sxSfdIOruJpgA0r+uw2x6yfcJr9yV9QtK2phoD0Kw6p/GLJN1j+7XX+VZE/KCRrgA0ruuwR8Qzks5ssBcAPcTQG5AEYQeSIOxAEoQdSIKwA0k08UGYFF747Mc61t552dPFdZ8aW1SsHxifU6wvuaNcn7/rxY61w1ueLK6LPDiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLPP0J/88bc61j499PPyyqfX3PjKcnnnoZc61m557uM1N370+vHYaR1rQzf9UnHd2Zseabqd1nFkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkHNG/SVpO9II4x+f2bXtN+sVnzulYe/5D5f8zT9pe3sc/f7+L9bkf+p9i/cYP3t2xdt5bXy6u+/2Xji/WPzm/82fl63o5DhTrm8eHivWVxx3setvv+f4Vxfp71z7c9Wu3aXNs0r7YO+UvFEd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCz7PP0NB3Nxdq9V77xHqr62/esbJj7S9WLCtv+1/K33l/48r3dNHRzMx++XCxPrR1tFg/+YG7ivVfmdv5+/bn7yx/F/+xaNoju+31tsdsb5u0bIHt+23vqG5P6m2bAOqayWn8NySdf8SyayVtiogzJG2qHgMYYNOGPSIekLT3iMWrJG2o7m+QdHGzbQFoWrdv0C2KiFFJqm4Xdnqi7bW2R2yPHNR4l5sDUFfP342PiHURMRwRw3M0r9ebA9BBt2HfY3uxJFW3Y821BKAXug37fZLWVPfXSLq3mXYA9Mq04+y279DEN5efYnuXpC9IukHSt21fLulZSZf0skmUHfrvPR1rQ3d1rknSq9O89tB3X+iio2bs+f2PFesfmFv+9f3S3vd1rC37+2eK6x4qVo9O04Y9IlZ3KB2d30IBJMXlskAShB1IgrADSRB2IAnCDiTBR1zRmtmnLS3Wv3LdV4r1OZ5VrH/nlt/sWDt59KHiuscijuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7GjNU3+0pFj/yLz
yVNZPHChPR73gyZfedE/HMo7sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+zoqfFPfqRj7dHP3DzN2uUZhP7g6quL9bf+64+nef1cOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs6Onnr2g8/HkeJfH0Vf/53nF+vwfPFasR7Gaz7RHdtvrbY/Z3jZp2fW2f2Z7S/VzYW/bBFDXTE7jvyHp/CmW3xwRZ1U/G5ttC0DTpg17RDwgaW8fegHQQ3XeoLvK9tbqNP+kTk+yvdb2iO2RgxqvsTkAdXQb9q9JOl3SWZJGJd3U6YkRsS4ihiNieM40H2wA0DtdhT0i9kTEqxFxWNLXJZ3dbFsAmtZV2G0vnvTwU5K2dXougMEw7Ti77TskrZR0iu1dkr4gaaXtszQxlLlT0hW9axGD7C0nnFCsX/brD3as7Tv8SnHdsS++u1ifN/5wsY7XmzbsEbF6isW39qAXAD3E5bJAEoQdSIKwA0kQdiAJwg4kwUdcUcuO6z9QrH/vlL/tWFu149PFdedtZGitSRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlR9L+/+9Fifevv/HWx/pNDBzvWXvyrU4vrztNosY43hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHtys5f8crF+zef/oVif5/Kv0KWPXdax9vZ/5PPq/cSRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJz9GOfZ5X/iM7+3q1i/5PgXivXb9y8s1hd9vvPx5HBxTTRt2iO77aW2f2R7u+0nbF9dLV9g+37bO6rbk3rfLoBuzeQ0/pCkz0XE+yV9VNKVtpdLulbSpog4Q9Km6jGAATVt2CNiNCIere7vl7Rd0hJJqyRtqJ62QdLFPeoRQAPe1Bt0tpdJ+rCkzZIWRcSoNPEfgqQp/3izvdb2iO2Rgxqv2S6Abs047LaPl3SXpGsiYt9M14uIdRExHBHDczSvmx4BNGBGYbc9RxNBvz0i7q4W77G9uKovljTWmxYBNGHaoTfblnSrpO0R8eVJpfskrZF0Q3V7b086RD1nvq9Y/vOFt9V6+a9+8ZJi/W2PPVTr9dGcmYyzr5B0maTHbW+pll2niZB/2/blkp6VVP5XB9CqacMeEQ9Kcofyuc22A6BXuFwWSIKwA0kQdiAJwg4kQdiBJPiI6zFg1vL3dqytvbPe5Q/L119ZrC+77d9qvT76hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPsx4Kk/7PzFvhfNn/GXCk3p1H8+UH5CRK3XR/9wZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwq8ctHZxfqmi24qVOc32wyOWhzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJmczPvlTSNyW9Q9JhSesi4hbb10v6rKTnqqdeFxEbe9VoZrtXzCrW3zm7+7H02/cvLNbn7Ct/np1Psx89ZnJRzSFJn4uIR22fIOkR2/dXtZsj4ku9aw9AU2YyP/uopNHq/n7b2yUt6XVjAJr1pv5mt71M0oclba4WXWV7q+31tqf8biTba22P2B45qPF63QLo2ozDbvt4SXdJuiYi9kn6mqTTJZ2liSP/lBdoR8S6iBiOiOE5mle/YwBdmVHYbc/RRNBvj4i7JSki9kTEqxFxWNLXJZU/rQGgVdOG3bYl3Sppe0R8edLyxZOe9ilJ25pvD0BTZvJu/ApJl0l63PaWatl1klbbPksToy87JV3Rg/5Q01++sLxYf+i3lhXrMfp4g92gTTN5N/5BSZ6ixJg6cBThCjogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Trl7ohfEOT63b9sDstkcm7Qv9k41VM6RHciCsANJEHYgCcIOJEHYgSQIO5AEYQeS6Os4u+3nJP3XpEWnSHq+bw28OYPa26D2JdFbt5rs7bSIePtUhb6G/Q0bt0ciYri1BgoGtbdB7Uuit271qzdO44EkCDuQRNthX9f
y9ksGtbdB7Uuit271pbdW/2YH0D9tH9kB9AlhB5JoJey2z7f977aftn1tGz10Ynun7cdtb7E90nIv622P2d42adkC2/fb3lHdTjnHXku9XW/7Z9W+22L7wpZ6W2r7R7a3237C9tXV8lb3XaGvvuy3vv/NbnuWpP+QdJ6kXZIelrQ6Ip7sayMd2N4paTgiWr8Aw/ZvSHpR0jcj4oPVshsl7Y2IG6r/KE+KiD8dkN6ul/Ri29N4V7MVLZ48zbikiyX9nlrcd4W+flt92G9tHNnPlvR0RDwTEQck3SlpVQt9DLyIeEDS3iMWr5K0obq/QRO/LH3XobeBEBGjEfFodX+/pNemGW913xX66os2wr5E0k8nPd6lwZrvPST90PYjtte23cwUFkXEqDTxyyNpYcv9HGnaabz76Yhpxgdm33Uz/XldbYR9qu/HGqTxvxUR8auSLpB0ZXW6ipmZ0TTe/TLFNOMDodvpz+tqI+y7JC2d9PhUSbtb6GNKEbG7uh2TdI8GbyrqPa/NoFvdjrXcz/8bpGm8p5pmXAOw79qc/ryNsD8s6Qzb77I9V9Klku5roY83sD1UvXEi20OSPqHBm4r6PklrqvtrJN3bYi+vMyjTeHeaZlwt77vWpz+PiL7/SLpQE+/I/0TSn7XRQ4e+3i3psernibZ7k3SHJk7rDmrijOhySSdL2iRpR3W7YIB6u03S45K2aiJYi1vq7dc08afhVklbqp8L2953hb76st+4XBZIgivogCQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJ/wNGNvRI2D7VDgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "# Visual inspection of the image, titled with its true label.\n",
  "# The digit is grayscale, so a gray colormap shows it as-is instead of in false colour.\n",
  "\n",
  "fig, ax = plt.subplots()\n",
  "ax.imshow(x_test[0].squeeze(), cmap=\"gray\")\n",
  "ax.set_title(f\"First test image (label: {y_test[0]})\")\n",
  "ax.axis(\"off\");"
 ] },
 { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Image label value\n", "\n", "y_test[0]" ] }
 ],
 "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } },
 "nbformat": 4, "nbformat_minor": 2 }