Classifier for the MNIST dataset using TensorFlow (TF.Keras)
This is a "Hello world!" example with TensorFlow.
It trains a convolutional neural network on the MNIST dataset.
Credits:
MNIST dataset, see http://yann.lecun.com/exdb/mnist/
Code adapted from the documentation by the TensorFlow and Keras team https://www.tensorflow.org/
In [ ]:
import tensorflow as tf
tf.__version__
In [ ]:
# This notebook is intended to be run with GPU resources.
# An empty list means that TensorFlow does not see any GPU.
tf.config.list_physical_devices('GPU')
In [3]:
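# Load the MNIST digits: 60,000 training and 10,000 test images of 28x28 pixels
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0
# Add a trailing channel axis, as expected by the Conv2D layers
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)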
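The preprocessing above does two things: it rescales pixel intensities and it adds a channel axis. The following is a minimal pure-Python sketch of both steps on a tiny made-up image. The values in `raw_image` and the helper `nested_shape` are hypothetical and only serve the illustration; the notebook itself works on NumPy arrays.

```python
# Hypothetical 2x3 "image" of raw 8-bit pixel intensities (not real MNIST data).
raw_image = [[0, 51, 102], [153, 204, 255]]

# Scale every pixel from the integer range [0, 255] to the float range [0, 1].
scaled_image = [[pixel / 255.0 for pixel in row] for row in raw_image]

# Add a trailing channel axis: each pixel becomes a length-1 list,
# mirroring the (height, width, 1) layout used for grayscale images.
with_channel = [[[pixel] for pixel in row] for row in scaled_image]


def nested_shape(nested):
    """Return the shape of a regular nested list as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


print(nested_shape(raw_image))     # (2, 3)
print(nested_shape(with_channel))  # (2, 3, 1)
print(scaled_image[1][2])          # 1.0
```

The same logic applied to the full dataset turns an array of shape (60000, 28, 28) into one of shape (60000, 28, 28, 1).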
In [ ]:
# This model uses convolutional neural network layers
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28,28,1)),
    tf.keras.layers.MaxPool2D((2, 2)),
    tf.keras.layers.Conv2D(16, 3, activation='relu'),
    tf.keras.layers.MaxPool2D((2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.summary()
optimizer = tf.keras.optimizers.Adam()
loss = tf.keras.losses.sparse_categorical_crossentropy
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
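To help read the output of `model.summary()`, the sketch below recomputes the output shape and the number of trainable parameters of each layer by hand. It assumes the Keras defaults used in the model above: stride 1 and 'valid' (no) padding for `Conv2D`, and a pooling stride equal to the pool size for `MaxPool2D`. The helper functions `conv2d`, `max_pool` and `dense` are illustrative names, not part of TensorFlow.

```python
def conv2d(shape, filters, kernel):
    """Output shape and parameter count of a stride-1 convolution without padding."""
    height, width, channels = shape
    out_shape = (height - kernel + 1, width - kernel + 1, filters)
    params = kernel * kernel * channels * filters + filters  # weights + biases
    return out_shape, params


def max_pool(shape, pool):
    """Output shape of non-overlapping pooling; incomplete windows are dropped."""
    height, width, channels = shape
    return (height // pool, width // pool, channels)


def dense(units_in, units_out):
    """Parameter count of a fully connected layer (weights + biases)."""
    return units_in * units_out + units_out


shape = (28, 28, 1)                                         # input image
shape, conv1_params = conv2d(shape, filters=32, kernel=3)   # (26, 26, 32)
shape = max_pool(shape, pool=2)                             # (13, 13, 32)
shape, conv2_params = conv2d(shape, filters=16, kernel=3)   # (11, 11, 16)
shape = max_pool(shape, pool=2)                             # (5, 5, 16)
flattened = shape[0] * shape[1] * shape[2]                  # 400
dense1_params = dense(flattened, 100)
dense2_params = dense(100, 10)

total_params = conv1_params + conv2_params + dense1_params + dense2_params
print(conv1_params, conv2_params, dense1_params, dense2_params)  # 320 4624 40100 1010
print(total_params)                                              # 46054
```

Dropout and Flatten have no trainable parameters, so they do not contribute to the total. Most of the parameters sit in the first dense layer.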
In [6]:
# Train the model on the training data.
# The test set is passed as validation data to monitor loss and accuracy after each epoch.
history = model.fit(x_train, y_train, batch_size=64,
                    validation_data=(x_test, y_test),
                    epochs=10)
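The progress bar printed during training counts batches, not samples. As a quick sanity check, the sketch below works out how many weight updates the settings above imply, assuming the 60,000 MNIST training images and that the final, smaller batch of each epoch is kept rather than dropped.

```python
import math

num_train_samples = 60000  # size of the MNIST training set
batch_size = 64
epochs = 10

# Number of batches (optimizer steps) per epoch; the last batch is partial.
steps_per_epoch = math.ceil(num_train_samples / batch_size)
full_batches, last_batch_size = divmod(num_train_samples, batch_size)

total_updates = steps_per_epoch * epochs

print(steps_per_epoch)                # 938
print(full_batches, last_batch_size)  # 937 32
print(total_updates)                  # 9380
```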
Trained model performance
In [7]:
# Compute the loss and accuracy on the test data
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"The accuracy on the test dataset is {round(test_accuracy,4)}")
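To make the two numbers returned by `model.evaluate` more concrete, here is a minimal sketch of how a sparse categorical cross-entropy loss and an accuracy can be computed by hand. The predicted probabilities and labels are made up (3 samples, 3 classes), and the functions are simplified illustrations: the Keras implementation additionally clips probabilities to avoid taking the logarithm of zero.

```python
import math


def sparse_categorical_crossentropy(probabilities, label):
    """Negative log of the probability assigned to the true class."""
    return -math.log(probabilities[label])


def accuracy(batch_probabilities, labels):
    """Fraction of samples whose most probable class equals the label."""
    correct = 0
    for probabilities, label in zip(batch_probabilities, labels):
        predicted = max(range(len(probabilities)), key=probabilities.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


# Hypothetical model outputs for 3 samples and their integer labels.
batch_probabilities = [
    [0.7, 0.2, 0.1],  # predicted class 0, label 0: correct
    [0.1, 0.8, 0.1],  # predicted class 1, label 1: correct
    [0.3, 0.4, 0.3],  # predicted class 1, label 2: wrong
]
labels = [0, 1, 2]

losses = [sparse_categorical_crossentropy(p, y)
          for p, y in zip(batch_probabilities, labels)]
mean_loss = sum(losses) / len(losses)
batch_accuracy = accuracy(batch_probabilities, labels)

print(f"mean loss = {mean_loss:.4f}, accuracy = {batch_accuracy:.4f}")
```

The labels are plain integers rather than one-hot vectors, which is what the "sparse" in the loss name refers to.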
In [8]:
import matplotlib.pyplot as plt
import numpy as np
Zoom in on a test case
In [9]:
# Use the model to predict the probabilities of each digit from the first test sample
model.predict(x_test[0].reshape(1,28,28,1))
Out[9]:
In [10]:
# Find the most likely digit from the model prediction
np.argmax(model.predict(x_test[0].reshape(1,28,28,1)))
Out[10]:
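The softmax output layer returns one probability per digit, and taking the argmax selects the digit with the highest probability. The sketch below illustrates this on a made-up probability vector; the values in `probabilities` are hypothetical and are not the actual output of the trained model.

```python
# Hypothetical softmax output for one image: one probability per digit 0..9.
probabilities = [0.01, 0.01, 0.02, 0.02, 0.01, 0.01, 0.01, 0.88, 0.02, 0.01]

# The index of the largest probability is the predicted digit.
predicted_digit = max(range(len(probabilities)), key=probabilities.__getitem__)
confidence = probabilities[predicted_digit]

# Rank the digits from most to least likely and keep the top three.
ranking = sorted(range(len(probabilities)),
                 key=probabilities.__getitem__, reverse=True)
top_three = ranking[:3]

print(predicted_digit, confidence)  # 7 0.88
print(top_three[0])                 # 7
```

Looking at the runner-up probabilities, not only the argmax, gives a sense of how confident the model is about a given image.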
In [11]:
# Visually inspect the first test image
plt.figure()
plt.imshow(x_test[0].reshape(28,28));
In [12]:
# Ground-truth label of the first test image, to compare with the prediction
y_test[0]
Out[12]: