Physics-Inspired Machine Learning: The Gravity MNIST Method

Categories: Machine Learning, MNIST
Author: Pratik Kulkarni
Published: March 10, 2025

In the world of machine learning, the MNIST dataset of handwritten digits has become something of a “Hello World” example. Everyone who’s dabbled in ML has probably trained a model to recognize these digits. But today, I want to share a creative approach that looks at this classic problem through the lens of physics.

The Gravity MNIST Method

The premise is wonderfully simple: what if we treat each pixel in a digit image as a “particle of sand”? The brighter the pixel, the more sand is present. Now, let’s apply gravity and see what happens when all that sand falls to the bottom of the image.

This approach transforms our 2D image into a 1D distribution - essentially a histogram of where the “sand” landed. But we don’t stop there. We apply gravity in all four cardinal directions (down, right, up, left) to get a more complete picture of the digit’s structure.
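
A minimal sketch of the idea (the real apply_gravity helper appears in the full code below): each direction is just a sum along one image axis, optionally reversed.

import numpy as np

img = np.random.rand(8, 8)        # stand-in for an 8x8 digit image
down = img.sum(axis=0)            # gravity down: one pile of sand per column
right = img.sum(axis=1)           # gravity right: one pile of sand per row
up = img.sum(axis=0)[::-1]        # same column totals, read from the opposite edge
left = img.sum(axis=1)[::-1]      # same row totals, read from the opposite edge
print(down.shape, right.shape)    # (8,) (8,) -- four 1D distributions per image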

The real magic happens when we model each of these distributions as a mixture of Gaussian curves. The parameters of these Gaussians - their means, variances, and weights - become our features for classification.

Implementation

Here’s how it works in practice:

  1. Load the digit data: Each image is an 8×8 grayscale grid of pixels (scikit-learn’s digits dataset, a compact MNIST-style set)
  2. Apply “gravity”: Sum the pixel values in each column to simulate sand falling down
  3. Repeat for all directions: Sum along the other axis (reversing where needed) to get distributions for right, up, and left
  4. Fit Gaussian mixtures: Model each distribution as a mixture of Gaussian curves
  5. Extract features: Use the parameters of these curves as features
  6. Train a classifier: Feed these features into a Random Forest classifier

Fitting the Gaussian Mixture model to the data happens like so:

import numpy as np
from sklearn.mixture import GaussianMixture

# Function to fit a Gaussian Mixture Model to the distribution
def fit_gmm(distribution, n_components=3):
    # Prepare the data for the GMM (it expects a 2D array of samples):
    # each "sample" is a (position, sand height) pair
    X = np.column_stack((np.arange(len(distribution)), distribution))

    # Fit the GMM
    gmm = GaussianMixture(n_components=n_components, random_state=42)
    gmm.fit(X)

    return {
        'means': gmm.means_.flatten(),
        'variances': gmm.covariances_.flatten(),
        'weights': gmm.weights_
    }
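
As a quick sanity check on the size of the resulting feature vector (a sketch, assuming the defaults above: n_components=3 and full 2×2 covariances fit to the (position, height) pairs):

# Per direction: 3 means x 2 dims + 3 covariance matrices x 4 entries + 3 weights = 21 values
per_direction = 3 * 2 + 3 * 4 + 3
print(per_direction, 4 * per_direction)  # 21 84 -> 84 features per digit across four directions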

Then, we simply train a Random Forest Classifier on these values. Here’s the entire code:

import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt

# Load the MNIST dataset from scikit-learn
def load_mnist():
    print("Loading MNIST dataset...")
    digits = load_digits()
    X = digits.data / 16.0  # Normalize pixel values (max is 16 in this dataset)
    y = digits.target
    return X, y

# Function to apply gravity in a specific direction
def apply_gravity(image, direction):
    """
    Apply gravity to an image in a specific direction.
    Direction: 0 = down, 1 = right, 2 = up, 3 = left
    Returns the 1D distribution of "sand particles"
    """
    img = image.reshape(8, 8)  # scikit-learn digits are 8x8
    
    if direction == 0:  # Down
        return np.sum(img, axis=0)  # Collapse the rows: one pile of sand per column
    elif direction == 1:  # Right
        return np.sum(img, axis=1)  # Collapse the columns: one pile of sand per row
    elif direction == 2:  # Up
        return np.sum(img, axis=0)[::-1]  # Same column totals, reversed to read from the top
    elif direction == 3:  # Left
        return np.sum(img, axis=1)[::-1]  # Same row totals, reversed to read from the left
    else:
        raise ValueError("Invalid direction. Must be 0, 1, 2, or 3.")

# Function to fit a Gaussian Mixture Model to the distribution
def fit_gmm(distribution, n_components=3):
    """
    Fit a Gaussian Mixture Model to the 1D distribution.
    Returns the parameters (means, variances, weights)
    """
    # Prepare the data for GMM (needs to be 2D array)
    X = np.column_stack((np.arange(len(distribution)), distribution))
    
    # Add small random noise to avoid singular matrices
    X += np.random.normal(0, 1e-5, X.shape)
    
    # Try to fit with requested components
    try:
        gmm = GaussianMixture(
            n_components=n_components, 
            random_state=42,
            covariance_type='full',  # Use full covariance for stability
            reg_covar=1e-3  # Add regularization to prevent singular covariances
        )
        gmm.fit(X)
    except Exception:
        # Fall back to a simpler model if the fit fails
        # (note: 'diag' stores fewer covariance values per component than 'full')
        print("Warning: Failed to fit GMM with full covariance, trying diagonal")
        gmm = GaussianMixture(
            n_components=n_components, 
            random_state=42,
            covariance_type='diag',
            reg_covar=1e-2
        )
        gmm.fit(X)
    
    # Extract the variances and ensure they are positive
    variances = np.maximum(gmm.covariances_.flatten(), 1e-10)
    
    return {
        'means': gmm.means_.flatten(),
        'variances': variances,
        'weights': gmm.weights_
    }

# Extract features from all directions
def extract_features(image, n_components=3):
    """
    Extract features from an image by applying gravity in all four directions
    and fitting GMMs to the resulting distributions.
    """
    features = []
    
    for direction in range(4):
        distribution = apply_gravity(image, direction)
        gmm_params = fit_gmm(distribution, n_components)
        
        # Flatten the parameters into a feature vector
        direction_features = []
        direction_features.extend(gmm_params['means'])
        direction_features.extend(gmm_params['variances'])
        direction_features.extend(gmm_params['weights'])
        
        features.extend(direction_features)
    
    return np.array(features)

# Visualize the gravity effect and GMM fit
def visualize_gravity_and_gmm(image, digit_label):
    """
    Visualize the original digit, the gravity effect in all directions,
    and the fitted GMMs.
    """
    plt.figure(figsize=(15, 10))
    
    # Plot the original digit
    plt.subplot(3, 2, 1)
    plt.imshow(image.reshape(8, 8), cmap='gray')
    plt.title(f"Original Digit: {digit_label}")
    
    directions = ['Down', 'Right', 'Up', 'Left']
    
    for i, direction in enumerate(range(4)):
        distribution = apply_gravity(image, direction)
        x = np.arange(len(distribution))
        
        # Plot the distribution
        plt.subplot(3, 2, i+2)
        plt.bar(x, distribution, alpha=0.5, color='gray')
        plt.title(f"Gravity Direction: {directions[direction]}")
        
        # Fit and plot the GMM
        gmm_params = fit_gmm(distribution)
        
        # Generate points from the fitted GMM to visualize
        n_comp = len(gmm_params['weights'])
        vals_per_comp = len(gmm_params['variances']) // n_comp  # 4 for 'full' covariance, 2 for 'diag'
        for j in range(n_comp):
            mu = gmm_params['means'][j*2]  # mean along the position axis (means are 2D, flattened)
            variance = gmm_params['variances'][j*vals_per_comp]  # variance along the position axis
            sigma = np.sqrt(max(variance, 1e-10))  # Guard against zero or negative values
            weight = gmm_params['weights'][j]
            
            x_values = np.linspace(0, len(distribution), 1000)
            y_values = weight * np.exp(-(x_values - mu)**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))
            
            # Scale to match the height of the distribution
            y_values = y_values * np.max(distribution) / np.max(y_values) if np.max(y_values) > 0 else y_values
            
            plt.plot(x_values, y_values, linewidth=2)
    
    plt.tight_layout()
    plt.show()

# Main function to run the whole process
def main():
    # Load MNIST data
    X, y = load_mnist()
    
    # Split into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    
    # Sample a small subset for visualization
    sample_indices = np.random.choice(len(X_train), 5, replace=False)
    for idx in sample_indices:
        visualize_gravity_and_gmm(X_train[idx], y_train[idx])
    
    # Extract features for training data
    print("Extracting features for training data...")
    X_train_features = np.array([
        extract_features(X_train[i]) for i in range(len(X_train))
    ])
    
    # Train a classifier (Random Forest in this case)
    print("Training classifier...")
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train_features, y_train)
    
    # Extract features for test data
    print("Extracting features for test data...")
    X_test_features = np.array([
        extract_features(X_test[i]) for i in range(len(X_test))
    ])
    
    # Make predictions
    print("Making predictions...")
    y_pred = clf.predict(X_test_features)
    
    # Evaluate the model
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy: {accuracy:.4f}")
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred))

main()
Loading MNIST dataset...

Extracting features for training data...
Training classifier...
Extracting features for test data...
Making predictions...
Accuracy: 0.9028

Classification Report:
              precision    recall  f1-score   support

           0       0.90      0.85      0.88        33
           1       0.96      0.93      0.95        28
           2       0.91      0.88      0.89        33
           3       0.81      0.85      0.83        34
           4       0.98      1.00      0.99        46
           5       0.93      0.83      0.88        47
           6       1.00      1.00      1.00        35
           7       0.84      0.94      0.89        34
           8       0.82      0.90      0.86        30
           9       0.87      0.85      0.86        40

    accuracy                           0.90       360
   macro avg       0.90      0.90      0.90       360
weighted avg       0.90      0.90      0.90       360

Each visualization shows:

  • The original digit
  • The “sand pile” distributions for all four directions
  • The fitted Gaussian mixtures overlaid on each distribution

The surprisingly good news? This approach achieves over 90% accuracy on the test set without any complex neural networks or deep learning. That’s quite impressive for such a simple, physics-inspired method!
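
For context, one quick baseline to run (not part of the original experiment) is a Random Forest trained directly on the raw 64 pixel values, using the same dataset, split, and seed, so you can see how the gravity features compare with the raw representation:

# Hypothetical baseline, not from the original post: Random Forest on raw pixels
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)
X = X / 16.0  # same normalization as the gravity pipeline
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

baseline = RandomForestClassifier(n_estimators=100, random_state=42)
baseline.fit(X_tr, y_tr)
print("Raw-pixel baseline accuracy:", accuracy_score(y_te, baseline.predict(X_te)))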

Why Does This Work?

The gravity method works because it captures structural information about the digits in a compact form. For example:

  • A “1” will form a tall, concentrated pile when gravity pulls down (its single vertical stroke collapses into one or two columns), but a flatter, more spread-out distribution when gravity pulls right
  • An “8” will typically have two peaks when gravity pulls down (from the left and right sides of its two loops)
  • A “7” will have a distinctive, top-heavy distribution when gravity pulls left (most of its ink sits in the horizontal bar at the top)

By capturing these characteristics from all four directions and encoding them as Gaussian parameters, we create a rich feature set that a classifier can use to distinguish between digits.
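
If you want to check these intuitions directly, a small sketch like the following (reusing load_mnist and apply_gravity from the code above) prints the down-gravity profile for the first example of each digit class:

# Inspect the "sand pile" for one example of each digit (direction 0 = down)
X, y = load_mnist()
for digit in range(10):
    idx = np.argmax(y == digit)          # index of the first example of this digit
    profile = apply_gravity(X[idx], 0)   # 8 column totals after gravity pulls down
    print(digit, np.round(profile, 2))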

Beyond MNIST

This “gravitational feature extraction” technique could potentially be applied to other image classification problems. It’s particularly interesting for cases where:

  • The objects have distinctive shapes
  • Computational resources are limited
  • You want a more interpretable model than deep learning often provides (see the sketch below)
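
On the interpretability point, one simple follow-up (a sketch, not from the original post) is to inspect the Random Forest’s feature importances: since every feature is a named GMM parameter for a specific gravity direction, the most important ones can be read back as statements about digit shape. This assumes the pipeline above was run with its defaults (n_components=3, full covariance), so there are 21 features per direction.

# Hypothetical follow-up: name each of the 84 features and rank them by importance
directions = ['down', 'right', 'up', 'left']
names = []
for d in directions:
    names += [f'{d}_mean_{i}' for i in range(6)]     # 3 components x 2 dims
    names += [f'{d}_var_{i}' for i in range(12)]     # 3 components x 2x2 covariance entries
    names += [f'{d}_weight_{i}' for i in range(3)]   # 3 mixture weights

ranked = sorted(zip(names, clf.feature_importances_), key=lambda t: -t[1])
for name, importance in ranked[:10]:
    print(f'{name}: {importance:.3f}')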

The physics-inspired approach also reminds us that sometimes looking at a problem through a different lens - in this case, literally imagining pixels as sand particles - can lead to creative solutions.

Conclusion

Machine learning doesn’t always have to mean complex neural networks and black-box models. Sometimes, a creative approach inspired by physical intuition can yield surprisingly good results.

The Gravity MNIST method is a perfect example of cross-disciplinary thinking: taking a concept from physics (gravity) and applying it to a classic machine learning problem. It’s a reminder that creativity still has an important place in the age of large models and big data.

Would this approach ever beat state-of-the-art deep learning models on MNIST? Probably not. But at 90% accuracy with a beautifully simple approach, it’s a wonderful example of how thinking differently can lead to elegant solutions.