In the world of machine learning, the MNIST dataset of handwritten digits has become something of a “Hello World” example. Everyone who’s dabbled in ML has probably trained a model to recognize these digits. But today, I want to share a creative approach that looks at this classic problem through the lens of physics.
The Gravity MNIST Method
The premise is wonderfully simple: what if we treat each pixel in a digit image as a “particle of sand”? The brighter the pixel, the more sand is present. Now, let’s apply gravity and see what happens when all that sand falls to the bottom of the image.
This approach transforms our 2D image into a 1D distribution - essentially a histogram of where the “sand” landed. But we don’t stop there. We apply gravity in all four cardinal directions (down, right, up, left) to get a more complete picture of the digit’s structure.
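To make this concrete, here is a tiny sketch (not part of the pipeline below, just a toy 3×4 array for illustration) showing that "gravity pulling down" is nothing more than a column-wise sum of pixel intensities, and "gravity pulling right" a row-wise sum:

import numpy as np

# A toy 3x4 "image": each value is the amount of "sand" in that pixel
toy = np.array([
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 1.0, 0.5, 0.0],
    [0.0, 1.0, 0.5, 0.0],
])

pile_down = toy.sum(axis=0)   # sand falls down: column sums -> [0., 3., 1., 0.]
pile_right = toy.sum(axis=1)  # sand falls right: row sums   -> [1., 1.5, 1.5]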
The real magic happens when we model each of these distributions as a mixture of Gaussian curves. The parameters of these Gaussians - their means, variances, and weights - become our features for classification.
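To get a feel for the size of that feature set: with three components fitted to the 2D (position, mass) points and full covariance matrices, as in the code below, each direction contributes 6 means, 12 flattened covariance entries, and 3 weights. A quick back-of-the-envelope check (assuming the full-covariance fit succeeds; the diagonal fallback used later would give 15 values per direction instead):

n_components = 3
per_direction = n_components * 2 + n_components * 4 + n_components  # means + covariances + weights
total_features = 4 * per_direction  # four gravity directions
print(per_direction, total_features)  # 21 84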
Implementation
Here’s how it works in practice:
- Load the data: Each image is an 8×8 grayscale grid of pixels (we use scikit-learn's built-in digits dataset)
- Apply “gravity”: Sum pixel values along columns to simulate sand falling down
- Repeat for all directions: Rotate and repeat to get distributions for right, up, and left
- Fit Gaussian mixtures: Model each distribution as a mixture of Gaussian curves
- Extract features: Use the parameters of these curves as features
- Train a classifier: Feed these features into a Random Forest classifier
Fitting the Gaussian Mixture model to the data happens like so:
import numpy as np
from sklearn.mixture import GaussianMixture

# Function to fit a Gaussian Mixture Model to the distribution
def fit_gmm(distribution, n_components=3):
    # Prepare the data for GMM (needs to be a 2D array of (position, mass) points)
    X = np.column_stack((np.arange(len(distribution)), distribution))

    # Fit the GMM
    gmm = GaussianMixture(n_components=n_components, random_state=42)
    gmm.fit(X)

    return {
        'means': gmm.means_.flatten(),
        'variances': gmm.covariances_.flatten(),
        'weights': gmm.weights_
    }
Then, we simply train a Random Forest Classifier on these values. Here’s the entire code:
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
from scipy.ndimage import rotate
# Load the MNIST dataset from scikit-learn
def load_mnist():
    print("Loading MNIST dataset...")
    digits = load_digits()
    X = digits.data / 16.0  # Normalize pixel values (max is 16 in this dataset)
    y = digits.target
    return X, y
# Function to apply gravity in a specific direction
def apply_gravity(image, direction):
    """
    Apply gravity to an image in a specific direction.
    Direction: 0 = down, 1 = right, 2 = up, 3 = left
    Returns the 1D distribution of "sand particles"
    """
    img = image.reshape(8, 8)  # scikit-learn digits are 8x8

    if direction == 0:  # Down
        return np.sum(img, axis=0)  # Sum each column to get the pile at the bottom
    elif direction == 1:  # Right
        return np.sum(img, axis=1)  # Sum each row to get the pile at the right
    elif direction == 2:  # Up
        return np.sum(img, axis=0)[::-1]  # Reverse to represent particles at the top
    elif direction == 3:  # Left
        return np.sum(img, axis=1)[::-1]  # Reverse to represent particles at the left
    else:
        raise ValueError("Invalid direction. Must be 0, 1, 2, or 3.")
# Function to fit a Gaussian Mixture Model to the distribution
def fit_gmm(distribution, n_components=3):
    """
    Fit a Gaussian Mixture Model to the 1D distribution.
    Returns the parameters (means, variances, weights)
    """
    # Prepare the data for GMM (needs to be 2D array)
    X = np.column_stack((np.arange(len(distribution)), distribution))

    # Add small random noise to avoid singular matrices
    X += np.random.normal(0, 1e-5, X.shape)

    # Try to fit with the requested components
    try:
        gmm = GaussianMixture(
            n_components=n_components,
            random_state=42,
            covariance_type='full',  # Use full covariance for stability
            reg_covar=1e-3  # Add regularization to prevent singular covariances
        )
        gmm.fit(X)
    except Exception:
        # Fallback to a simpler model if it fails
        print("Warning: Failed to fit GMM with full covariance, trying diagonal")
        gmm = GaussianMixture(
            n_components=n_components,
            random_state=42,
            covariance_type='diag',
            reg_covar=1e-2
        )
        gmm.fit(X)

    # Extract the covariances and ensure they are positive (the flattened values
    # work for both 'full' and 'diag' covariance types)
    variances = np.maximum(gmm.covariances_.flatten(), 1e-10)

    return {
        'means': gmm.means_.flatten(),
        'variances': variances,
        'weights': gmm.weights_
    }
# Extract features from all directions
def extract_features(image, n_components=3):
    """
    Extract features from an image by applying gravity in all four directions
    and fitting GMMs to the resulting distributions.
    """
    features = []

    for direction in range(4):
        distribution = apply_gravity(image, direction)
        gmm_params = fit_gmm(distribution, n_components)

        # Flatten the parameters into a feature vector
        direction_features = []
        direction_features.extend(gmm_params['means'])
        direction_features.extend(gmm_params['variances'])
        direction_features.extend(gmm_params['weights'])

        features.extend(direction_features)

    return np.array(features)
# Visualize the gravity effect and GMM fit
def visualize_gravity_and_gmm(image, digit_label):
    """
    Visualize the original digit, the gravity effect in all directions,
    and the fitted GMMs.
    """
    plt.figure(figsize=(15, 10))

    # Plot the original digit
    plt.subplot(3, 2, 1)
    plt.imshow(image.reshape(8, 8), cmap='gray')
    plt.title(f"Original Digit: {digit_label}")

    directions = ['Down', 'Right', 'Up', 'Left']

    for i, direction in enumerate(range(4)):
        distribution = apply_gravity(image, direction)
        x = np.arange(len(distribution))

        # Plot the distribution
        plt.subplot(3, 2, i + 2)
        plt.bar(x, distribution, alpha=0.5, color='gray')
        plt.title(f"Gravity Direction: {directions[direction]}")

        # Fit and plot the GMM
        gmm_params = fit_gmm(distribution)

        # Generate points from the fitted GMM to visualize
        n_components = len(gmm_params['weights'])
        params_per_component = len(gmm_params['variances']) // n_components
        for j in range(n_components):
            mu = gmm_params['means'][j * 2]  # Means are 2D; take the position coordinate
            # Take the variance along the position axis and ensure it is
            # non-negative before taking the square root
            variance = gmm_params['variances'][j * params_per_component]
            sigma = np.sqrt(max(variance, 1e-10))
            weight = gmm_params['weights'][j]

            x_values = np.linspace(0, len(distribution), 1000)
            y_values = weight * np.exp(-(x_values - mu)**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))

            # Scale to match the height of the distribution
            y_values = y_values * np.max(distribution) / np.max(y_values) if np.max(y_values) > 0 else y_values

            plt.plot(x_values, y_values, linewidth=2)

    plt.tight_layout()
    plt.show()
# Main function to run the whole process
def main():
    # Load MNIST data
    X, y = load_mnist()

    # Split into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Sample a small subset for visualization
    sample_indices = np.random.choice(len(X_train), 5, replace=False)
    for idx in sample_indices:
        visualize_gravity_and_gmm(X_train[idx], y_train[idx])

    # Extract features for training data
    print("Extracting features for training data...")
    X_train_features = np.array([
        extract_features(X_train[i]) for i in range(len(X_train))
    ])

    # Train a classifier (Random Forest in this case)
    print("Training classifier...")
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train_features, y_train)

    # Extract features for test data
    print("Extracting features for test data...")
    X_test_features = np.array([
        extract_features(X_test[i]) for i in range(len(X_test))
    ])

    # Make predictions
    print("Making predictions...")
    y_pred = clf.predict(X_test_features)

    # Evaluate the model
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy: {accuracy:.4f}")
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred))
main()
Loading MNIST dataset...
Extracting features for training data...
Training classifier...
Extracting features for test data...
Making predictions...
Accuracy: 0.9028
Classification Report:
precision recall f1-score support
0 0.90 0.85 0.88 33
1 0.96 0.93 0.95 28
2 0.91 0.88 0.89 33
3 0.81 0.85 0.83 34
4 0.98 1.00 0.99 46
5 0.93 0.83 0.88 47
6 1.00 1.00 1.00 35
7 0.84 0.94 0.89 34
8 0.82 0.90 0.86 30
9 0.87 0.85 0.86 40
accuracy 0.90 360
macro avg 0.90 0.90 0.90 360
weighted avg 0.90 0.90 0.90 360
Each visualization shows:
- The original digit
- The “sand pile” distributions for all four directions
- The fitted Gaussian mixtures overlaid on each distribution
The surprisingly good news? This approach achieves over 90% accuracy on the test set without any complex neural networks or deep learning. That’s quite impressive for such a simple, physics-inspired method!
Why Does This Work?
The gravity method works because it captures structural information about the digits in a compact form. For example:
- A “1” will have a concentrated pile when gravity pulls down, but a much more spread-out distribution when gravity pulls to the right
- An “8” will typically have two peaks when gravity pulls down (from its two loops)
- A “7” will have a distinctive distribution when gravity pulls from the left
By capturing these characteristics from all four directions and encoding them as Gaussian parameters, we create a rich feature set that a classifier can use to distinguish between digits.
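A quick toy check of that intuition (hypothetical 8×8 patterns, not real digits from the dataset): a single vertical stroke piles into one narrow peak under downward gravity, while a shape with ink on both its left and right sides, like the loops of an “8”, produces two separated peaks.

import numpy as np

one_like = np.zeros((8, 8))
one_like[:, 4] = 1.0               # a crude vertical stroke, roughly a "1"

eight_like = np.zeros((8, 8))
eight_like[:, 2] = 1.0             # left side of the loops
eight_like[:, 5] = 1.0             # right side of the loops
eight_like[[0, 3, 7], 2:6] = 1.0   # top, waist, and bottom arcs

print(one_like.sum(axis=0))    # [0. 0. 0. 0. 8. 0. 0. 0.] -> one concentrated pile
print(eight_like.sum(axis=0))  # [0. 0. 8. 3. 3. 8. 0. 0.] -> two clear peaks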
Beyond MNIST
This “gravitational feature extraction” technique could potentially be applied to other image classification problems. It’s particularly interesting for cases where:
- The objects have distinctive shapes
- Computational resources are limited
- You want a more interpretable model than deep learning often provides
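As a sketch of what that could look like, here is a hypothetical generalization of apply_gravity to arbitrary H×W grayscale images, using the same direction convention as above (this is not part of the code in this post):

import numpy as np

def apply_gravity_any_shape(image, direction):
    """Gravity histogram for an arbitrary 2D grayscale image.

    Direction: 0 = down, 1 = right, 2 = up, 3 = left.
    """
    img = np.asarray(image, dtype=float)
    if img.ndim != 2:
        raise ValueError("Expected a 2D image array.")
    if direction == 0:      # down: pile along the bottom edge
        return img.sum(axis=0)
    elif direction == 1:    # right: pile along the right edge
        return img.sum(axis=1)
    elif direction == 2:    # up: same column sums, reversed
        return img.sum(axis=0)[::-1]
    elif direction == 3:    # left: same row sums, reversed
        return img.sum(axis=1)[::-1]
    raise ValueError("Invalid direction. Must be 0, 1, 2, or 3.")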
The physics-inspired approach also reminds us that sometimes looking at a problem through a different lens - in this case, literally imagining pixels as sand particles - can lead to creative solutions.
Conclusion
Machine learning doesn’t always have to mean complex neural networks and black-box models. Sometimes, a creative approach inspired by physical intuition can yield surprisingly good results.
The Gravity MNIST method is a perfect example of cross-disciplinary thinking: taking a concept from physics (gravity) and applying it to a classic machine learning problem. It’s a reminder that creativity still has an important place in the age of large models and big data.
Would this approach ever beat state-of-the-art deep learning models on MNIST? Probably not. But at 90% accuracy with a beautifully simple approach, it’s a wonderful example of how thinking differently can lead to elegant solutions.