Exercise: Handwritten Digit Classification using ANN (MNIST Dataset)
1. Objective
To build, train, and evaluate a feedforward artificial neural network (ANN) that classifies
handwritten digits from the MNIST dataset, using both a manual training loop (GradientTape)
and Keras's high-level API.
2. Tools Required
Python 3.x
TensorFlow
NumPy
Matplotlib
3. Dataset Description
The MNIST dataset contains 70,000 grayscale images of handwritten digits (0 to 9), each of size
28x28 pixels. It is divided into:
Training Set: 60,000 images
Test Set: 10,000 images
4. Summary of Key Concepts
Concept | Description | Role in Project
MNIST Dataset | Handwritten digit images and labels | Provides input images and expected outputs
ANN | Fully connected neural network | Learns patterns in image data to classify digits
Flatten Layer | Reshapes the 28x28 image into a 784-element vector | Prepares image data for dense layers
Dense Layer | Fully connected neural layer | Learns features through weighted connections
ReLU Activation | Applies ReLU non-linearity | Allows the network to learn complex functions
Loss Function (CrossEntropy) | Measures difference between predicted and actual labels | Guides learning by minimizing classification error
Optimizer (Adam) | Optimizes weights using gradients | Adjusts model weights during training
GradientTape | Manual training method | Records operations for backpropagation
Epoch | One full pass over the training data | Repeated passes help refine learning
Accuracy | Performance metric for classification | Measures correct predictions on the test set
Softmax Layer | Converts logits to probabilities | Used during prediction for interpretation
5. Model Building Steps
Step 1: Import Libraries
Explanation: These libraries are needed for building the ANN, processing data, and
visualization.
Question: Why do we import models from keras?
Answer: To use the Sequential model for stacking layers.
Question: What is the purpose of tensorflow.keras in this code?
Answer: tensorflow.keras is a high-level API that allows us to build, train, and evaluate deep
learning models easily. It includes layers, optimizers, and tools for loading datasets like MNIST.
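A minimal import block consistent with the steps below might look like this (the exact aliases are an assumption):

    import tensorflow as tf
    from tensorflow.keras import layers, models   # Dense/Flatten layers and the Sequential builder
    import numpy as np                            # numerical operations on arrays
    import matplotlib.pyplot as plt               # visualizing images and predictions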
Step 2: Load and Normalize the Data
Explanation: Pixel values are scaled from [0, 255] to [0, 1] for faster learning.
Question: What is normalization?
Answer: Scaling features to a standard range, here [0, 1].
Question: Why do we divide the pixel values by 255?
Answer: Pixel values range from 0 to 255. Dividing by 255 normalizes them to a range of 0 to 1,
which speeds up training and helps the model learn better.
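A sketch of this step, assuming the built-in Keras MNIST loader:

    # Load MNIST: 60,000 training and 10,000 test images with integer labels 0-9
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Scale pixel values from [0, 255] to [0, 1]
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0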
Step 3: Create Training Batches
Explanation: Batches help in efficient training. Shuffling ensures varied input order per epoch.
Question: Why use batching?
Answer: For computational efficiency and stable gradient estimates.
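One possible batching pipeline using tf.data (the batch size of 32 and the shuffle buffer size are assumptions):

    # Wrap the arrays in a tf.data pipeline; reshuffle each epoch, then batch
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = train_dataset.shuffle(buffer_size=60000).batch(32)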
Step 4: Build the Neural Network Model
Explanation: Sequential layers stack transformations on the input to produce logits.
Question: What is the purpose of the Flatten layer?
Answer: The Flatten layer converts the 2D image (28x28) into a 1D vector (784) so that it can be
passed to the Dense layers.
Question: Why is the last Dense layer's output 10?
Answer: Because there are 10 digit classes (0 to 9), we need 10 output neurons, one producing a
score (logit) for each class.
Question: Why no softmax in the last layer?
Answer: We use raw logits with from_logits=True in the loss function; softmax is applied only at
prediction time.
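A sketch of the network described above; the hidden layer width of 128 is an assumption:

    model = models.Sequential([
        layers.Flatten(input_shape=(28, 28)),   # 28x28 image -> 784-element vector
        layers.Dense(128, activation="relu"),   # hidden layer with ReLU non-linearity
        layers.Dense(10)                        # 10 raw logits, one per digit class (no softmax)
    ])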
Step 5: Define Loss and Optimizer
Explanation: Cross-entropy is suitable for classification; Adam adapts learning rates.
Question: Why SparseCategoricalCrossentropy?
Answer: Because labels are integers (not one-hot encoded).
Question: What does an optimizer do during training?
Answer: The optimizer updates the model's weights using the gradients to reduce the loss and
improve accuracy.
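The corresponding definitions, assuming Adam with its default learning rate:

    # from_logits=True because the model outputs raw logits, not probabilities
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()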
Section A: Manual Training Using GradientTape
Step 6: Manual Training Loop
Explanation: Custom loop for educational clarity. Shows step-by-step learning and loss updates.
Question: What does GradientTape() do?
Answer: Records operations to compute gradients.
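A minimal loop matching the summary table below (5 epochs, reporting the average loss per epoch):

    epochs = 5
    for epoch in range(epochs):
        total_loss = 0.0                       # reset so each epoch reports a fresh average
        for x_batch, y_batch in train_dataset:
            with tf.GradientTape() as tape:    # record the forward pass for backpropagation
                logits = model(x_batch, training=True)
                loss = loss_fn(y_batch, logits)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            total_loss += loss.numpy()         # convert to a Python float for accumulation
        print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_dataset):.4f}")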
OUTPUT:
What Do These Values Indicate?
The loss decreases steadily with each epoch:
Epoch 1 (0.2277) → relatively high, as the model starts with random weights.
Epoch 5 (0.0332) → much lower, indicating the model has learned meaningful patterns from the
training data.
This suggests:
The model is training correctly
The optimizer is working
Gradient descent is minimizing the loss
The ANN is learning useful representations of the data
NOTE: After every epoch, the model:
1. Makes predictions.
2. Compares predictions with actual labels.
3. Computes loss (error).
4. Adjusts weights using gradients to minimize the error.
This process, called backpropagation, helps the model improve its accuracy over time.
Summary Table of Questions
Code Line | Question | Answer
epochs = 5 | What is an epoch? | One complete pass over the entire training dataset.
total_loss = 0 | Why reset total_loss each epoch? | To calculate a fresh average loss for the new epoch.
(x_batch, y_batch) | What is a batch? | A subset of training data processed in one step.
GradientTape() | What does it do? | Tracks operations to compute gradients.
training=True | Why set this flag? | To enable training behaviors like dropout.
loss_fn() | What does the loss function do? | Measures prediction error.
tape.gradient() | What are gradients? | Slopes that guide weight updates.
apply_gradients() | Why apply gradients? | To improve model accuracy by updating weights.
loss.numpy() | Why convert loss to NumPy? | To use it in Python arithmetic.
len(train_dataset) | Why divide by this? | To calculate the average loss across all batches.
Step 7: Evaluate the Model Accuracy
Explanation: We compute accuracy by comparing predicted and true labels; argmax picks the class
with the highest logit value.
Question: Why set training=False during evaluation?
Answer: To disable layers like dropout or batch normalization.
Question: How is accuracy calculated?
Answer: Number of correct predictions divided by total test samples.
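A sketch of the accuracy computation over the full test set:

    logits = model(x_test, training=False)           # inference mode: dropout etc. disabled
    preds = tf.argmax(logits, axis=1).numpy()        # predicted class per image
    accuracy = np.mean(preds == y_test)              # correct predictions / total test samples
    print(f"Test accuracy: {accuracy:.4f}")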
Step 8: Make Predictions
Question: Why use argmax?
Answer: To select the predicted class with the highest score.
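Since the model outputs logits, a softmax can be applied at prediction time to obtain interpretable probabilities, e.g.:

    probs = tf.nn.softmax(model(x_test, training=False))   # logits -> class probabilities
    predicted_classes = tf.argmax(probs, axis=1)            # most probable digit per image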
Step 9: Evaluate Individual Predictions
Question: What does argmax(logits, axis=1) do?
Answer: Picks the index of the highest score (predicted class).
Question: Why use plt.imshow(images[i], cmap='gray')?
Answer: To display grayscale MNIST images.
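A sketch that displays one test image with its predicted and true labels (the index 0 is arbitrary):

    i = 0                                             # any test-set index
    logits = model(x_test[i:i + 1], training=False)
    pred = tf.argmax(logits, axis=1).numpy()[0]
    plt.imshow(x_test[i], cmap='gray')                # grayscale MNIST image
    plt.title(f"Predicted: {pred}, Actual: {y_test[i]}")
    plt.show()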
Section B: Training Using High-Level Keras API
Compile and Train the Model
Explanation: A short, readable training method using Keras's high-level API.
Question: When is model.fit() preferred?
Answer: When you don’t need custom training steps.
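The equivalent high-level workflow might look like this (the epoch count and batch size mirror Section A and are assumptions):

    # Assumes a freshly built model as in Step 4, since Section A already trained the other one
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    model.fit(x_train, y_train, epochs=5, batch_size=32)
    model.evaluate(x_test, y_test)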
6. Visualizing Predictions
Explanation: Useful for visual verification of model predictions.
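One possible sketch that plots the first nine test images in a grid with their predicted labels:

    preds = tf.argmax(model(x_test[:9], training=False), axis=1).numpy()
    plt.figure(figsize=(6, 6))
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(x_test[i], cmap='gray')
        plt.title(f"Pred: {preds[i]}")
        plt.axis('off')
    plt.show()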
7. Real-world Applications
Digit recognition on postal codes (OCR)
Bank cheque processing
Touchscreen handwriting input
8. Conclusion
This exercise demonstrated building, training, and evaluating a simple ANN for digit
classification using both a manual GradientTape training loop and the high-level Keras API.