swanlab.confusion_matrix

python
confusion_matrix(
    y_true: Union[List, np.ndarray],
    y_pred: Union[List, np.ndarray],
    class_names: List[str] = None,
) -> None
| Parameter | Description |
| --- | --- |
| y_true | (Union[List, np.ndarray]) True labels: the actual class labels of the samples. |
| y_pred | (Union[List, np.ndarray]) Predicted labels: the class labels predicted by the model. |
| class_names | (List[str]) Class names displayed on the confusion matrix axes. If None, numeric indices are used as labels. |

Introduction

Draws a confusion matrix for evaluating a classification model. The matrix cross-tabulates true labels against predicted labels, so it shows at a glance how many samples of each class were classified correctly and which classes the model confuses with one another.

The confusion matrix is a standard tool for evaluating classifiers and is especially useful for multi-class problems, where a single accuracy figure hides per-class errors.
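
To make the cross-tabulation concrete, here is a minimal pure-Python sketch of how the counts in a confusion matrix come about. The helper `count_confusion` and the toy labels are hypothetical and exist only for illustration; this is not SwanLab's implementation, which relies on scikit-learn.

```python
from typing import List, Sequence


def count_confusion(
    y_true: Sequence[int], y_pred: Sequence[int], n_classes: int
) -> List[List[int]]:
    """Count (true, predicted) pairs. Rows are true labels, columns are predictions."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


# Toy labels for a 3-class problem (illustrative values only)
y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0, 2]

matrix = count_confusion(y_true, y_pred, n_classes=3)
for row in matrix:
    print(row)
# [1, 1, 0]
# [0, 2, 0]
# [1, 0, 2]

# The diagonal holds correct predictions, so per-class recall is diagonal / row sum
recall = [row[i] / sum(row) for i, row in enumerate(matrix)]
print([round(r, 2) for r in recall])  # [0.5, 1.0, 0.67]
```

Off-diagonal cells show the error types: here one class-0 sample was predicted as class 1, and one class-2 sample was predicted as class 0.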

Basic Usage

python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab

# Load iris dataset
iris_data = load_iris()
X = iris_data.data
y = iris_data.target
class_names = iris_data.target_names.tolist()

# Split training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train model
model = xgb.XGBClassifier(objective='multi:softmax', num_class=len(class_names))
model.fit(X_train, y_train)

# Get predictions
y_pred = model.predict(X_test)

# Initialize SwanLab
swanlab.init(project="Confusion-Matrix-Demo", experiment_name="Confusion-Matrix-Example")

# Log confusion matrix
swanlab.log({
    "confusion_matrix": swanlab.confusion_matrix(y_test, y_pred, class_names)
})

swanlab.finish()

Using Custom Class Names

python
# Define custom class names
custom_class_names = ["Class A", "Class B", "Class C"]

# Log confusion matrix
confusion_matrix = swanlab.confusion_matrix(y_test, y_pred, custom_class_names)
swanlab.log({"confusion_matrix_custom": confusion_matrix})

Without Class Names

python
# Don't specify class names, numeric indices will be used
confusion_matrix = swanlab.confusion_matrix(y_test, y_pred)
swanlab.log({"confusion_matrix_default": confusion_matrix})
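
When labels are not already 0, 1, 2, ..., it helps to know which index stands for which label. The sketch below assumes the matrix is built with scikit-learn's confusion_matrix using its default label order (the sorted union of the values in y_true and y_pred); whether SwanLab displays positions or raw label values is not confirmed here. `default_label_order` is a hypothetical helper.

```python
def default_label_order(y_true, y_pred):
    """Sorted union of labels: the order scikit-learn uses when no labels are given."""
    return sorted(set(y_true) | set(y_pred))


# Illustrative labels that are not contiguous integers
y_true = [2, 5, 5, 9]
y_pred = [2, 5, 9, 9]

order = default_label_order(y_true, y_pred)
index_of = {label: i for i, label in enumerate(order)}

print(order)     # [2, 5, 9]
print(index_of)  # {2: 0, 5: 1, 9: 2}
```

Under that assumption, row and column index 0 corresponds to label 2, index 1 to label 5, and index 2 to label 9. Passing class_names in the same order keeps the axis labels aligned with the data.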

Binary Classification Example

python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab

# Generate binary classification data
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Train model
model = xgb.XGBClassifier(eval_metric='logloss')
model.fit(X_train, y_train)

# Get predictions
y_pred = model.predict(X_test)

# Initialize SwanLab (required before logging)
swanlab.init(project="Confusion-Matrix-Demo")

# Log confusion matrix; class names are ordered by label value (0, then 1)
swanlab.log({
    "confusion_matrix": swanlab.confusion_matrix(y_test, y_pred, ["Negative", "Positive"])
})

swanlab.finish()

Notes

  1. Data Format: y_true and y_pred can be Python lists or NumPy arrays.
  2. Multi-class Support: The function handles both binary and multi-class classification.
  3. Class Names: The length of class_names should match the number of classes.
  4. Dependencies: Requires the scikit-learn and pyecharts packages.
  5. Coordinate Axes: scikit-learn's confusion_matrix places (0,0) at the top-left corner, whereas a pyecharts heatmap places it at the bottom-left corner. The function handles this coordinate conversion automatically.
  6. Matrix Interpretation: Rows represent true labels and columns represent predicted labels.
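
The coordinate conversion in note 5 can be pictured with a small sketch. It assumes the heatmap takes (x, y, value) points with y counted from the bottom; `to_heatmap_points` is a hypothetical helper showing one way to do the flip, not SwanLab's actual code.

```python
from typing import List, Tuple


def to_heatmap_points(matrix: List[List[int]]) -> List[Tuple[int, int, int]]:
    """Convert a matrix whose row 0 is at the top into (x, y, value) points
    for a chart whose y axis starts at the bottom."""
    n_rows = len(matrix)
    points = []
    for row_idx, row in enumerate(matrix):
        for col_idx, value in enumerate(row):
            y = n_rows - 1 - row_idx  # flip so that row 0 is drawn at the top
            points.append((col_idx, y, value))
    return points


# Rows are true labels, columns are predicted labels (illustrative counts)
matrix = [[5, 1],
          [2, 7]]

points = to_heatmap_points(matrix)
print(points)  # [(0, 1, 5), (1, 1, 1), (0, 0, 2), (1, 0, 7)]
```

With this flip, the y axis labels have to be listed in reverse order as well, so that the first class still appears in the top row.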