🧠 AI with Python – 🖼️🔢 Image Classifier with KNN on Digits Dataset


Description:

Image classification is one of the most popular tasks in machine learning. Before jumping into deep learning with CNNs, it’s useful to see how classical ML algorithms can handle image data.

In this post, we’ll use K-Nearest Neighbors (KNN) on scikit-learn’s digits dataset to classify handwritten digits from 0–9.


Why KNN for Images?

  • Instance-based learning: KNN classifies a new sample by looking at its closest training samples.
  • No explicit training phase: fit() only stores the training set (optionally in a search structure); the distance computations happen at prediction time.
  • Great for baselines: Works surprisingly well on simple datasets like digits.
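
To make the idea concrete, here is a minimal from-scratch sketch of the KNN rule, assuming Euclidean distance and a simple majority vote. The function name knn_predict and the toy data are made up for illustration; scikit-learn's own implementation is more elaborate (for example, it can use tree-based neighbor search and handles ties differently).

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x_new, k=3):
    """Classify one sample by majority vote among its k nearest training samples."""
    # Euclidean distance from x_new to every stored training sample
    distances = np.linalg.norm(X_train - x_new, axis=1)
    # Indices of the k smallest distances
    nearest = np.argsort(distances)[:k]
    # Majority vote over the neighbors' labels
    votes = Counter(y_train[nearest].tolist())
    return votes.most_common(1)[0][0]

# Toy 2-D data: two clusters labeled 0 and 1 (hypothetical values)
X_toy = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                  [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]])
y_toy = np.array([0, 0, 0, 1, 1, 1])

print(knn_predict(X_toy, y_toy, np.array([0.3, 0.3])))  # query near the first cluster
print(knn_predict(X_toy, y_toy, np.array([4.8, 5.0])))  # query near the second cluster
```

Note that nothing is "learned" here: all the work happens inside knn_predict at query time, which is why prediction cost grows with the size of the stored training set.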

Loading the Dataset

The digits dataset ships with scikit-learn. It contains 1,797 samples of 8×8 grayscale images (pixel intensities from 0 to 16). In digits.data, each image is already “flattened” into a 64-dimensional feature vector; the original 8×8 arrays are available as digits.images.

from sklearn.datasets import load_digits

digits = load_digits()
X, y = digits.data, digits.target
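
Before modeling, it can help to verify the numbers quoted above. The following is a small sanity-check sketch that uses only load_digits and NumPy; it confirms the shapes, the pixel range, and that each row of digits.data is the flattened version of the matching 8×8 image.

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()
X, y = digits.data, digits.target

print("Flattened features:", X.shape)          # (1797, 64)
print("Original images:", digits.images.shape)  # (1797, 8, 8)
print("Pixel range:", X.min(), "to", X.max())   # 0.0 to 16.0

# Each row of X is the corresponding 8x8 image laid out row by row
same = np.array_equal(digits.images[0].ravel(), X[0])
print("Row 0 equals flattened image 0:", same)

# Number of samples per digit class (roughly balanced)
class_counts = np.bincount(y)
print("Samples per class:", class_counts)
```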

Training the Classifier

We’ll split the dataset into train and test sets (75/25, stratified so each digit keeps the same share in both), then fit a KNN classifier with k=3.

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, stratify=y, random_state=42)

knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)
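
To see what the fitted model actually does for a single prediction, you can ask it for the neighbors it would vote with. This sketch repeats the split above so that it runs on its own, then uses the kneighbors method to list the three closest training samples for the first test image.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)
knn = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)

# Distances to, and training-set indices of, the 3 nearest neighbors
# of the first test sample
distances, indices = knn.kneighbors(X_test[:1])
neighbor_labels = y_train[indices[0]]
predicted = knn.predict(X_test[:1])[0]

print("Neighbor distances:", distances[0])
print("Neighbor labels:", neighbor_labels)
print("Predicted label:", predicted, "| True label:", y_test[0])
```

With the default uniform weights, the predicted label is simply the most common label among those neighbors.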

Evaluating the Model

We predict on the test set and compute accuracy.

from sklearn.metrics import accuracy_score

y_pred = knn.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

A classification report or confusion matrix helps you see which digits are harder for the model to distinguish.
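
As a rough illustration of that point, the sketch below prints the per-class report and then searches the confusion matrix for the most frequent mistake. It rebuilds the same split and model so that it is self-contained; the variable names errors, true_digit, and pred_digit are only for this example.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)
knn = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)
y_pred = knn.predict(X_test)

# Per-class precision, recall, and F1
print(classification_report(y_test, y_pred))

# Rows are true digits, columns are predicted digits
cm = confusion_matrix(y_test, y_pred)

# Zero the diagonal so only misclassifications remain
errors = cm.copy()
np.fill_diagonal(errors, 0)
true_digit, pred_digit = np.unravel_index(errors.argmax(), errors.shape)
print(f"Most frequent confusion: true {true_digit} predicted as {pred_digit} "
      f"({errors[true_digit, pred_digit]} times)")
```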


Visualizing Predictions

It’s always helpful to see actual digits. We can display a few test images along with their predicted and true labels.

import matplotlib.pyplot as plt

# Reshape the flat 64-value vector back to 8×8 before displaying it
plt.imshow(X_test[0].reshape(8, 8), cmap="gray")
plt.show()
print("Predicted:", y_pred[0], " True:", y_test[0])

Sample Output

On the digits dataset, KNN typically scores well above 90% accuracy; the exact figure depends on the train/test split and on the choice of k.

Example:

Accuracy: 0.9622

Predicted: 8   True: 8
Predicted: 1   True: 1
Predicted: 4   True: 4
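
Because a single split gives only one accuracy number, a quick way to see how the score varies is cross-validation over a few values of k. This is a sketch under assumptions: the candidate values (1, 3, 5, 7, 9) and the use of 5 folds are arbitrary choices for illustration, not settings from the post.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_digits(return_X_y=True)

cv_means = {}
for k in (1, 3, 5, 7, 9):
    # 5-fold cross-validated accuracy for this k
    scores = cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5)
    cv_means[k] = scores.mean()
    print(f"k={k}: mean accuracy {scores.mean():.4f} (+/- {scores.std():.4f})")

# k with the highest mean cross-validated accuracy
best_k = max(cv_means, key=cv_means.get)
print("Best k by mean CV accuracy:", best_k)
```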

Key Takeaways

  • KNN provides a quick-to-build baseline for image classification.
  • Accuracy is good on simple datasets, but scalability is limited: prediction cost grows with the size of the stored training set.
  • It is a great first step before experimenting with more advanced models like SVMs or CNNs.

Code Snippet:

# Import dataset, split utility, model, and metrics
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay

# Plotting
import matplotlib.pyplot as plt


# Load digits dataset
digits = load_digits()
X, y = digits.data, digits.target   # X: (n_samples, 64), y: digit labels 0..9

# (Optional) Peek at shapes
print("Feature matrix shape:", X.shape)
print("Targets shape:", y.shape)


# Stratified split keeps class balance across train/test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)


# Create KNN with k=3 neighbors (good starting point)
knn = KNeighborsClassifier(n_neighbors=3)

# Fit on training data
knn.fit(X_train, y_train)


# Predict on the held-out test set
y_pred = knn.predict(X_test)

# Overall accuracy
acc = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {acc:.4f}")

# Per-class precision/recall/F1
print("\nClassification Report:\n", classification_report(y_test, y_pred))


# Confusion matrix display
disp = ConfusionMatrixDisplay.from_predictions(y_test, y_pred, cmap="Blues")
disp.ax_.set_title("KNN (k=3) — Confusion Matrix")
plt.tight_layout()
plt.show()


# Plot a few test images with predictions
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
fig.suptitle("Sample Predictions (title: pred | true)")

for idx, ax in enumerate(axes.flat):
    # Show the first 10 test samples, one per subplot
    image = X_test[idx].reshape(8, 8)   # reshape back to 8x8
    ax.imshow(image, cmap="gray")
    # Label each image with its predicted and true digit
    ax.set_title(f"pred:{y_pred[idx]} | true:{y_test[idx]}", fontsize=10)
    ax.axis("off")

plt.tight_layout()
plt.show()
