MNIST Digit Recognition from Scratch: A Pure NumPy Implementation

Introduction

This project presents a from-scratch implementation of a neural network for the MNIST digit recognition task, deliberately avoiding mainstream deep learning frameworks. The implementation relies solely on NumPy for numerical computations, emphasizing fundamental neural network concepts and providing insights into the internal mechanisms of deep learning systems.

Motivation

The motivation behind this project stems from several key objectives:

  1. Deep Understanding: By implementing a neural network without high-level abstractions provided by frameworks like PyTorch or TensorFlow, we gain intimate knowledge of the underlying mathematical operations and architectural components.

  2. Educational Value: This implementation serves as an educational tool, demonstrating how modern deep learning systems work at their core, from forward propagation to backpropagation and optimization techniques.

  3. Performance Optimization: Working directly with NumPy forces us to consider computational efficiency and memory management, leading to a deeper appreciation of optimization techniques in deep learning.

Key Features

Advanced Features and Development Workflow

Experiment-Based Architecture

Our implementation adopts a sophisticated experiment-based workflow that prioritizes reproducibility, traceability, and thorough analysis of results. Each training run is treated as a distinct experiment, ensuring comprehensive documentation of configurations, metrics, and outcomes.

Configuration Management System

The system implements a hierarchical configuration management approach that captures all aspects of model training and evaluation. Each experiment stores its configuration in a structured JSON format:

```json
{
  "model": {
    "input_size": 784,
    "hidden_size": 512,
    "output_size": 10,
    "dropout_rate": 0.3,
    "l2_lambda": 0.0001
  },
  "training": {
    "epochs": 20,
    "initial_lr": 0.1,
    "batch_size": 128
  }
}
```
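A minimal sketch of how such a file might be consumed is shown below; the path and the way the values are passed on are assumptions based on the JSON above, not the repository's exact code.

```python
import json

# Path is illustrative; each run writes its own timestamped directory.
with open("experiments/2025-01-01_12-00-00/config.json") as f:
    config = json.load(f)

model_cfg = config["model"]      # architecture and regularization settings
train_cfg = config["training"]   # optimization settings

print(model_cfg["hidden_size"], train_cfg["batch_size"])  # -> 512 128
```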

Experiment Directory Structure

Each training run automatically generates a timestamped experiment directory that serves as a comprehensive record of the training process. The directory structure includes:

  • Configuration files capturing all parameters
  • Training history and performance metrics
  • Model checkpoints at various stages
  • Evaluation reports
  • Generated visualizations and plots
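A directory of this kind can be produced with the standard library alone. The sketch below mirrors the layout above, using illustrative sub-folder names rather than necessarily the exact ones in main.py.

```python
import json
from datetime import datetime
from pathlib import Path

def create_experiment_dir(config: dict, root: str = "experiments") -> Path:
    """Create a timestamped experiment directory and persist the config."""
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    exp_dir = Path(root) / stamp
    # Sub-folders for checkpoints, evaluation reports, and plots.
    for sub in ("checkpoints", "reports", "plots"):
        (exp_dir / sub).mkdir(parents=True, exist_ok=True)
    # Capture the full configuration for reproducibility.
    with open(exp_dir / "config.json", "w") as f:
        json.dump(config, f, indent=2)
    return exp_dir
```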

Interactive Testing

For ease of use, we implemented an interactive prediction visualization system that enables detailed examination of the model's behavior on individual test cases. Users can select specific test images and receive comprehensive feedback including:

  • The original image display
  • Predicted digit with confidence score
  • Probability distribution across all digits
  • Clear visualization of correct vs incorrect predictions
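As a rough sketch of this per-image feedback, assuming the model exposes a forward method that returns softmax probabilities (the real interface in nn.py may differ):

```python
import numpy as np

def show_prediction(model, image: np.ndarray, label: int) -> None:
    """Print prediction, confidence, and the full probability distribution."""
    probs = model.forward(image.reshape(1, 784)).ravel()
    pred = int(np.argmax(probs))
    verdict = "correct" if pred == label else "incorrect"
    print(f"Predicted: {pred} (confidence {probs[pred]:.2%})")
    print(f"Actual:    {label} -> {verdict}")
    for digit, p in enumerate(probs):
        bar = "#" * int(p * 40)  # simple text bar per digit
        print(f"  {digit}: {bar:<40} {p:.3f}")
```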

Architecture Overview

Neural Network Structure

Our neural network implements a carefully designed feed-forward architecture optimized for the MNIST classification task. The network consists of three main layers: an input layer of 784 neurons (corresponding to the flattened 28x28 pixel images), a hidden layer of 512 neurons, and an output layer of 10 neurons (one for each digit class).

The input layer normalizes the pixel values to the [0,1] range, improving gradient flow and training stability. The hidden layer employs ReLU activation functions, chosen for their ability to mitigate vanishing gradient problems while promoting sparse activations and computational efficiency. The output layer utilizes softmax activation to produce probability distributions over the possible digit classes.

The network implements a feed-forward architecture with the following key components:

  • Input Layer: 784 neurons (28x28 flattened MNIST images)
  • Hidden Layer: 512 neurons with ReLU activation
  • Output Layer: 10 neurons with softmax activation for digit classification
  • Dropout regularization (rate: 0.3)
  • L2 regularization (lambda: 0.0001)
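Concretely, the forward pass of this architecture reduces to two affine transforms with ReLU and softmax activations. The sketch below is a minimal NumPy rendering of that computation (dropout and batch normalization omitted for clarity); the variable names are illustrative rather than taken from nn.py.

```python
import numpy as np

def relu(z: np.ndarray) -> np.ndarray:
    return np.maximum(0.0, z)

def softmax(z: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max for numerical stability.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def forward(x, W1, b1, W2, b2):
    """x: (batch, 784) pixel values already scaled to [0, 1]."""
    h = relu(x @ W1 + b1)        # hidden activations: (batch, 512)
    return softmax(h @ W2 + b2)  # class probabilities: (batch, 10)
```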

Regularization Framework

Our implementation incorporates a comprehensive regularization strategy to prevent overfitting and improve generalization. The primary components include:

  1. Dropout regularization with a rate of 0.3, dynamically masking different neuron combinations during training
  2. L2 weight regularization (lambda = 0.0001) to prevent excessive weight growth
  3. Batch normalization between layers to reduce internal covariate shift and accelerate training
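As a generic illustration of the first two components (not the exact code in nn.py), inverted dropout and the L2 penalty can each be expressed in a few lines of NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)

def dropout(h: np.ndarray, rate: float = 0.3, training: bool = True) -> np.ndarray:
    """Inverted dropout: scale at train time so inference needs no change."""
    if not training:
        return h
    mask = rng.random(h.shape) >= rate
    return h * mask / (1.0 - rate)

def l2_penalty(weights, lam: float = 0.0001) -> float:
    """L2 term added to the loss: (lam / 2) * sum of squared weights."""
    return 0.5 * lam * sum(float(np.sum(W ** 2)) for W in weights)
```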

Project Structure

The implementation follows a modular architecture with clear separation of concerns:

Main Module (main.py)

The main module orchestrates the entire training and evaluation process. It handles configuration management, experiment tracking, and coordinates the interaction between other components. The module creates timestamped experiment directories for each training run, ensuring reproducibility and proper organization of results.

Neural Network Module (nn.py)

The core neural network implementation resides in this module. It contains the complete feed-forward neural network architecture, including the sophisticated backpropagation mechanism and various optimization techniques. The module implements batch normalization, dropout regularization, and advanced weight initialization strategies.

Data Processing Module (data.py)

This module contains the MNISTDataLoader class, which handles data preprocessing, augmentation, and batch generation. It implements sophisticated data normalization techniques and provides utilities for dataset splitting and validation set creation.
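As an illustration of the normalization and batch generation described here, a minimal sketch under assumed array shapes might look as follows; the actual MNISTDataLoader interface likely differs.

```python
import numpy as np

def normalize(images: np.ndarray) -> np.ndarray:
    """Flatten 28x28 images and scale pixel values to [0, 1]."""
    return images.reshape(len(images), 784).astype(np.float32) / 255.0

def iterate_batches(x: np.ndarray, y: np.ndarray, batch_size: int = 128, seed: int = 0):
    """Yield shuffled (inputs, labels) mini-batches for one epoch."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]
```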

Metrics Tracking Module (metrics.py)

The metrics module maintains comprehensive performance statistics throughout training. It implements various evaluation metrics, including accuracy calculation, loss tracking, and confusion matrix generation.
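For reference, accuracy and confusion-matrix computation of this kind can be sketched in plain NumPy; the function names below are illustrative, not the module's actual API.

```python
import numpy as np

def accuracy(y_true, y_pred) -> float:
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

def confusion_matrix(y_true, y_pred, num_classes: int = 10) -> np.ndarray:
    """Rows index true digits, columns index predicted digits."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm
```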

Visualization Module (visualizer.py)

This component handles all visualization aspects, from training progress plots to individual prediction visualization. It provides interactive visualization capabilities for model inspection and result analysis.
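A minimal sketch of the training-progress plotting described here, using matplotlib; the layout of the history dictionary is an assumption, not the module's actual data structure.

```python
import matplotlib.pyplot as plt

def plot_history(history: dict, out_path: str = "training_curves.png") -> None:
    """Plot per-epoch loss and accuracy from a dict of per-epoch lists."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    ax1.plot(history["loss"], label="train loss")
    ax1.set_xlabel("epoch")
    ax1.set_ylabel("loss")
    ax1.legend()
    ax2.plot(history["accuracy"], label="train accuracy")
    ax2.plot(history["val_accuracy"], label="val accuracy")
    ax2.set_xlabel("epoch")
    ax2.set_ylabel("accuracy")
    ax2.legend()
    fig.tight_layout()
    fig.savefig(out_path)
```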

Training Progression Analysis

The network exhibited a clear three-phase learning pattern during its 20-epoch training period, with each phase showing distinct characteristics. Performance across all dataset splits is summarized below:

Dataset Performance Overview

| Dataset Type | Sample Size | Accuracy | Loss   | Key Characteristics |
|--------------|-------------|----------|--------|---------------------|
| Training     | 40,800      | 98.27%   | 0.7871 | Smooth learning curve; consistent improvements; optimal convergence |
| Validation   | 10,200      | 97.50%   | –      | Stable epoch performance; minimal accuracy gap; strong generalization |
| Test         | 9,000       | 97.00%   | –      | Consistent metrics; robust performance; strong class-wise results |

Performance Progression by Training Phase

| Training Phase   | Epochs | Starting Accuracy | Final Accuracy | Loss Reduction  |
|------------------|--------|-------------------|----------------|-----------------|
| Initial Learning | 1-5    | 86.00%            | 95.35%         | 5.9105 → 2.0673 |
| Refinement       | 6-13   | 95.73%            | 97.45%         | 1.8587 → 1.1127 |
| Convergence      | 14-20  | 97.65%            | 98.27%         | 1.0470 → 0.7871 |

Performance Analysis

Top Performing Digits

| Digit | Precision | Recall | F1-Score | Key Characteristics | Common Confusions |
|-------|-----------|--------|----------|----------------------|-------------------|
| 0     | 98%       | 100%   | 0.99     | Nearly perfect recognition; strongest overall performer; highly distinct features | Minimal confusion with any digit |
| 1     | 98%       | 98%    | 0.98     | Exceptional consistency; strong feature isolation; reliable detection | Very rare confusion with 7 |

Most Challenging Digits

| Digit | Precision | Recall | F1-Score | Primary Challenges | Main Confusions |
|-------|-----------|--------|----------|--------------------|-----------------|
| 9     | 95%       | 96%    | 0.95     | Curved segment variations; complex upper loop; variable writing styles | Digit 4 (upper loop); digit 7 (angled stroke) |
| 8     | 96%       | 97%    | 0.96     | Double loop structure; style variations; complex geometry | Digit 3 (curved segments); digit 5 (lower loop) |

Error Pattern Analysis

| Error Type        | Primary Digits Affected | Root Cause                  | Impact on Accuracy     |
|-------------------|-------------------------|-----------------------------|------------------------|
| Visual Similarity | 3↔8, 4↔9                | Similar structural elements | 2-3% accuracy reduction |
| Loop Structure    | 6, 8, 9                 | Complex curved segments     | 1-2% accuracy reduction |
| Linear Structure  | 1, 7                    | Angle variations            | <1% accuracy reduction  |

These tables show the clear performance distinctions between easily recognized digits (0,1) and more challenging ones (8,9). The analysis reveals that structural complexity and visual similarity are the main factors affecting recognition accuracy. The exceptional performance on digits 0 and 1 can be attributed to their distinct and less variable structural features.
