Digit Classification using CNN on the MNIST dataset

Sharing my approach to digit classification

Image classification is one of the most prominent applications of deep learning. It powers systems such as facial recognition and self-driving cars, among many others. In this blog, I will share an example of digit image classification, leveraging the power of Convolutional Neural Networks (CNNs). Our journey revolves around a classic dataset, MNIST, which serves as the perfect playground for understanding how CNNs work in the realm of computer vision.

Aim

The MNIST dataset is a collection of 28x28 grayscale images of handwritten digits (0-9), making it an ideal playground for beginners entering the world of computer vision and deep learning. Our goal is to build a CNN that can accurately identify and classify these digits.

Key steps I will cover in this blog:

  • Dataset Preparation

  • Model Architecture

  • Compilation

  • Training and Evaluation

To start, we will import the libraries we need:

import tensorflow as tf
import numpy as np
import random
import matplotlib.pyplot as plt

Dataset Preparation

# Load the MNIST dataset
(trainX, trainY), (testX, testY) = tf.keras.datasets.mnist.load_data()

# Scale the pixel values to the [0, 1] range
trainX = trainX / 255.0
testX = testX / 255.0

# Reshape the data to 4D tensor format (batch_size, height, width, channels)
trainX = trainX.reshape(trainX.shape[0], 28, 28, 1).astype('float32')
testX = testX.reshape(testX.shape[0], 28, 28, 1).astype('float32')

# Convert labels to one-hot encoding
trainY = tf.keras.utils.to_categorical(trainY, 10)
testY = tf.keras.utils.to_categorical(testY, 10)
  • Load the dataset (X holds the images and Y holds the image labels)

  • Normalize the pixel values

  • Reshape the data into the required 4D tensor format (see the TensorFlow docs to learn more about tensors)

  • Convert the labels to one-hot encoding (a short sketch of what this looks like follows this list)
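
To make one-hot encoding concrete, here is a minimal sketch (the label values are made up for illustration) showing what to_categorical does to a single label:

labels = np.array([3, 0, 7])  # a few raw digit labels
one_hot = tf.keras.utils.to_categorical(labels, 10)

print(one_hot[0])
# [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]  -> a 1 at index 3, zeros everywhere else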

Model Architecture

model = tf.keras.models.Sequential()

# Convolution and pooling layers
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.Flatten())

# Dense layer for classification
model.add(tf.keras.layers.Dense(128, activation='relu'))

# Output layer: one unit per digit class
model.add(tf.keras.layers.Dense(10, activation='softmax'))
  • Sequential lets us build the model layer by layer, stacking each layer on top of the previous one.

  • Conv2D layers apply learnable filters that slide over the image, extracting local features such as edges and strokes.

  • Max pooling layers reduce the spatial dimensions of the representation, helping the network focus on the most important features.

  • The Flatten layer converts the 2D feature maps into a vector, preparing them for the fully connected layers (you can verify each layer's output shape with the summary sketch below).
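
To sanity-check the architecture, you can print the model summary. The shape comments below are what I would expect from the layer arithmetic: each 3x3 convolution (with default 'valid' padding) trims 2 pixels in each dimension, and each 2x2 pooling halves the spatial size.

model.summary()
# Expected output shape per layer:
# Conv2D(32)   -> (26, 26, 32)   28 - 3 + 1 = 26
# MaxPooling2D -> (13, 13, 32)   26 / 2 = 13
# Conv2D(64)   -> (11, 11, 64)   13 - 3 + 1 = 11
# MaxPooling2D -> (5, 5, 64)     floor(11 / 2) = 5
# Conv2D(64)   -> (3, 3, 64)     5 - 3 + 1 = 3
# Flatten      -> (576,)         3 * 3 * 64 = 576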

Compilation

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
  • categorical_crossentropy as the loss function, because this is a multiclass classification problem and the labels are one-hot encoded (a worked example follows this list)

  • adam as the optimizer, because it is one of the most popular optimization algorithms, known for its adaptive learning rates and efficient handling of sparse gradients. (The optimizer determines how the model's weights are updated during training to minimize the loss function.)
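
To see what categorical crossentropy actually computes, here is a small numpy sketch (the probabilities are made up for illustration). For a one-hot label, the loss reduces to the negative log of the probability the model assigned to the true class:

y_true = np.array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])  # one-hot label for digit 3
y_pred = np.array([0.01, 0.01, 0.02, 0.85, 0.02, 0.02, 0.02, 0.02, 0.02, 0.01])  # softmax output

loss = -np.sum(y_true * np.log(y_pred))
print(loss)  # ~0.1625, i.e. -log(0.85); a confident correct prediction gives a small loss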

Training

model.fit(trainX, trainY, epochs=5, batch_size=64, validation_split=0.2)
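
model.fit returns a History object that records the loss and metrics after every epoch, and validation_split=0.2 holds out 20% of the training images to measure them on unseen data. A minimal sketch for plotting the learning curves (this variant simply captures the return value of the fit call above):

# Capture the metrics recorded during training
history = model.fit(trainX, trainY, epochs=5, batch_size=64, validation_split=0.2)

plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()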

Evaluation

test_loss, test_acc = model.evaluate(testX, testY)
print("loss -->", test_loss, "\naccuracy -->", test_acc)

Plotting Predictions

# Generate predictions for the test set (used by both plotting functions)
pred = model.predict(testX)

# Function for plotting sample predictions
def graph_plot(indices):
    fig, axes = plt.subplots(1, len(indices), figsize=(15, 5))
    for i, index in enumerate(indices):
        reshaped_image = testX[index].reshape(28, 28)  # Back to (28, 28) for display
        axes[i].imshow(reshaped_image, cmap='gray')
        axes[i].axis('off')
        # Set titles with true and predicted labels
        axes[i].set_title(f'True: {np.argmax(testY[index])}\nPred: {np.argmax(pred[index])}')
    plt.show()

# Function for the scatter plot
def scatter_plot(indices):
    # Get true labels and predicted labels for the selected indices
    true_labels = np.argmax(testY[indices], axis=1)
    predicted_labels = np.argmax(pred[indices], axis=1)

    # Create a scatter plot
    plt.figure(figsize=(10, 6))
    plt.scatter(range(len(indices)), true_labels, color='blue', label='True Labels', marker='o')
    plt.scatter(range(len(indices)), predicted_labels, color='red', label='Predicted Labels', marker='x')
    plt.xlabel('Sample Index')
    plt.ylabel('Label')
    plt.legend()
    plt.title('True vs. Predicted Labels (Randomly Selected Samples)')
    plt.show()

Graph Plot

  • graph_plot takes a list of indices as input.

  • It iterates through the provided indices, retrieves the corresponding images from the testX dataset, and displays each one with its true and predicted labels.

Scatter Plot

  • scatter_plot takes a list of indices as input.

  • It retrieves the true and predicted labels for the provided indices and creates a scatter plot.

  • True labels are shown in blue ('o' markers) and predicted labels in red ('x' markers); wherever the two markers overlap, the prediction is correct.

  • The x-axis represents the sample indices, and the y-axis represents the label values. (A usage sketch for both functions follows.)
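
Putting it together, a short usage sketch, assuming the model and pred from the snippets above:

# Plot five random test images with their true and predicted labels
graph_plot(random.sample(range(len(testX)), 5))

# Compare true vs. predicted labels for 50 random samples
scatter_plot(random.sample(range(len(testX)), 50))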

Conclusion

Try out the code, and if you have any doubts, check out this Colab notebook: Colab

The exploration of digit classification on MNIST serves as a stepping stone toward tackling more complex image recognition tasks and understanding the inner workings of neural networks.