🧠 AI with Python – 📊 Interactive Confusion Matrix with Plotly


Description:

Evaluating a machine learning model often begins with a confusion matrix, which summarizes the model’s predictions across different classes. But a confusion matrix rendered as a static image (for example, a Matplotlib heatmap) makes it hard to read exact values or explore the plot interactively.

In this project, we build an interactive confusion matrix using Plotly. You can hover over cells to read exact counts, and you can zoom and pan. This is especially useful for dashboards, presentations, and in-depth model analysis.


Understanding the Problem

For a binary classifier, a confusion matrix breaks classification performance down into four counts:

  • True Positives
  • False Positives
  • True Negatives
  • False Negatives
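As a small illustration (the labels below are made up), scikit-learn's confusion_matrix returns these four counts for a binary problem. With labels=[0, 1], the matrix is laid out as [[TN, FP], [FN, TP]], so .ravel() unpacks the counts in that order:

```python
from sklearn.metrics import confusion_matrix

# Made-up binary labels (1 = positive class, 0 = negative class)
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 0]

# For labels [0, 1], scikit-learn lays the matrix out as
# [[TN, FP],
#  [FN, TP]]
cm_binary = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm_binary.ravel()

print(cm_binary)
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")  # TN=3, FP=1, FN=1, TP=3
```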

For multi-class problems, the matrix expands to one row per actual class and one column per predicted class, so each off-diagonal cell counts a specific kind of error. Static heatmaps work, but interactivity makes interpretation easier, especially when there are many classes or the differences between them are subtle.
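You can recover the same four counts for each class of a multi-class matrix by treating each class as "positive" in turn. The sketch below uses a made-up 3×3 matrix (rows = actual, columns = predicted). The diagonal holds each class's true positives, and the column and row sums supply the false positives and false negatives:

```python
import numpy as np

# Made-up 3-class confusion matrix (rows = actual, columns = predicted)
cm_multi = np.array([
    [10, 0, 0],
    [0, 9, 1],
    [0, 2, 8],
])

tp = np.diag(cm_multi)                # correct predictions per class
fp = cm_multi.sum(axis=0) - tp        # predicted as class k, actually another class
fn = cm_multi.sum(axis=1) - tp        # actually class k, predicted as another class
tn = cm_multi.sum() - (tp + fp + fn)  # every remaining sample

for k in range(cm_multi.shape[0]):
    print(f"class {k}: TP={tp[k]}, FP={fp[k]}, FN={fn[k]}, TN={tn[k]}")
```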

Plotly brings clarity through:

  • hover tooltips
  • clearer color gradients
  • zoom & pan interactions
  • clean, customizable visuals

1. Load and Explore the Dataset

We use the well-known Iris dataset, which contains:

  • 150 samples
  • 4 numerical features
  • 3 flower species

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

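Iris has 50 samples per class, and stratify=y preserves those proportions in both splits. The 120-sample training set has 40 samples per class, and the 30-sample test set has 10.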
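For the "explore" part of this step, a quick look at shapes and per-class counts confirms what the split produced. The block repeats the loading code so it runs on its own:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(X.shape)                      # (150, 4)
print(iris.feature_names)           # the four measurement names
print(iris.target_names)            # ['setosa' 'versicolor' 'virginica']
print(X_train.shape, X_test.shape)  # (120, 4) (30, 4)
print(np.bincount(y_train))         # samples per class in training: [40 40 40]
print(np.bincount(y_test))          # samples per class in test: [10 10 10]
```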


2. Train a Classification Model

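We train a simple Logistic Regression model. Setting max_iter=5000 (the default is 100) gives the solver room to converge on the unscaled features.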

from sklearn.linear_model import LogisticRegression

model = LogisticRegression(max_iter=5000)
model.fit(X_train, y_train)

y_pred = model.predict(X_test)
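Before plotting, it can help to check overall accuracy and per-class metrics. This sketch repeats the split and training from above so it runs on its own. It uses scikit-learn's accuracy_score and classification_report:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42, stratify=iris.target
)

model = LogisticRegression(max_iter=5000)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Overall fraction of correct test predictions
acc = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {acc:.3f}")

# Per-class precision, recall and F1, labeled with the species names
print(classification_report(y_test, y_pred, target_names=iris.target_names))
```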

3. Generate the Confusion Matrix

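We compute the confusion matrix using scikit-learn. Following scikit-learn's convention, rows correspond to actual classes and columns to predicted classes.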

from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_test, y_pred)
labels = iris.target_names
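You can also normalize the raw counts. In the sketch below (again repeating the setup so it is self-contained), normalize="true" divides each row by the number of actual samples in that class, so the diagonal shows per-class recall. Summing the diagonal of the raw matrix and dividing by the total reproduces the accuracy score:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42, stratify=iris.target
)
y_pred = LogisticRegression(max_iter=5000).fit(X_train, y_train).predict(X_test)

cm = confusion_matrix(y_test, y_pred)

# Each row is divided by its total: cell (i, j) becomes the fraction of
# actual class i that was predicted as class j
cm_norm = confusion_matrix(y_test, y_pred, normalize="true")

# Correct predictions sit on the diagonal, so trace / total is the accuracy
acc_from_cm = np.trace(cm) / cm.sum()

print(np.round(cm_norm, 2))
print(acc_from_cm, accuracy_score(y_test, y_pred))
```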

4. Build the Interactive Confusion Matrix (Plotly)

px.imshow also accepts a plain NumPy array. Wrapping the matrix in a pandas DataFrame, however, lets Plotly pick up readable row and column labels automatically.

import pandas as pd
import plotly.express as px

cm_df = pd.DataFrame(
    cm,
    index=[f"Actual: {label}" for label in labels],
    columns=[f"Predicted: {label}" for label in labels]
)

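fig = px.imshow(
    cm_df,
    text_auto=True,                    # print each count inside its cell
    color_continuous_scale="Purples",  # darker purple = more samples
    title="Interactive Confusion Matrix – Iris Classification",
)

# Axis titles and a horizontally centered plot title
fig.update_layout(
    xaxis_title="Predicted Label",
    yaxis_title="Actual Label",
    title_x=0.5
)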

fig.show()

The result is an interactive heatmap. Hovering over a cell shows its row label, column label, and count. The toolbar supports zooming, panning, and downloading the plot as a PNG.


Key Takeaways

  1. Interactive visualizations improve interpretability, especially for multi-class models.
  2. Plotly provides hover tooltips and dynamic navigation — features missing from static plots.
  3. Confusion matrices are essential for understanding class-wise model performance.
  4. The same figure object can be embedded in web dashboards, for example with dcc.Graph(figure=fig) in Dash or st.plotly_chart(fig) in Streamlit.
  5. Plotly’s interactive features make error analysis more intuitive and precise.

Conclusion

Plotly transforms the traditional confusion matrix into a powerful interactive tool.

With just a few lines of code, you can create a dynamic visualization that highlights model performance, misclassifications, and class distributions.

This technique is especially valuable for reporting, debugging, teaching, or integrating into production dashboards.


Code Snippet:

import pandas as pd
import plotly.express as px
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split


# Load dataset
iris = load_iris()
X, y = iris.data, iris.target

# Train/test split (stratified so each class keeps its share)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Train the classifier and predict on the held-out test set
model = LogisticRegression(max_iter=5000)
model.fit(X_train, y_train)

y_pred = model.predict(X_test)

# Confusion matrix: rows = actual classes, columns = predicted classes
cm = confusion_matrix(y_test, y_pred)
labels = iris.target_names

# Prepare heatmap as a DataFrame for Plotly
cm_df = pd.DataFrame(
    cm,
    index=[f"Actual: {label}" for label in labels],
    columns=[f"Predicted: {label}" for label in labels]
)

fig = px.imshow(
    cm_df,
    text_auto=True,
    color_continuous_scale="Purples",
    title="Interactive Confusion Matrix – Iris Classification",
)

fig.update_layout(
    xaxis_title="Predicted Label",
    yaxis_title="Actual Label",
    title_x=0.5
)

fig.show()
