confusion_matrix:创建用于模型评估的混淆矩阵

用于生成混淆矩阵的函数。

from mlxtend.evaluate import confusion_matrix
from mlxtend.plotting import plot_confusion_matrix

概述

混淆矩阵

混淆矩阵(或错误矩阵)是总结二元分类任务中分类器性能的一种方式。这个方阵由列和行组成,以绝对或相对的“实际类别”与“预测类别”比率的形式列出实例数量。

为类别 1 的标签,且为第二类的标签,或在多类别设置中,为所有非类别 1 的标签。

参考

  • -

示例 1 - 二元分类

from mlxtend.evaluate import confusion_matrix

y_target =    [0, 0, 1, 0, 0, 1, 1, 1]
y_predicted = [1, 0, 1, 0, 0, 0, 0, 1]

cm = confusion_matrix(y_target=y_target, 
                      y_predicted=y_predicted)
cm
array([[3, 1],
       [2, 2]])

要使用 matplotlib 可视化混淆矩阵,请参阅工具函数 mlxtend.plotting.plot_confusion_matrix

import matplotlib.pyplot as plt
from mlxtend.plotting import plot_confusion_matrix

fig, ax = plot_confusion_matrix(conf_mat=cm)
plt.show()

png

示例 2 - 多类别分类

from mlxtend.evaluate import confusion_matrix

y_target =    [1, 1, 1, 0, 0, 2, 0, 3]
y_predicted = [1, 0, 1, 0, 0, 2, 1, 3]

cm = confusion_matrix(y_target=y_target, 
                      y_predicted=y_predicted, 
                      binary=False)
cm
array([[2, 1, 0, 0],
       [1, 2, 0, 0],
       [0, 0, 1, 0],
       [0, 0, 0, 1]])

要使用 matplotlib 可视化混淆矩阵,请参阅工具函数 mlxtend.plotting.plot_confusion_matrix

import matplotlib.pyplot as plt
from mlxtend.evaluate import confusion_matrix

fig, ax = plot_confusion_matrix(conf_mat=cm)
plt.show()

png

示例 3 - 从多类别到二元

通过设置 binary=True,所有非正类别标签的类标签都会被汇总为类别 0。正类别标签变为类别 1。

import matplotlib.pyplot as plt
from mlxtend.evaluate import confusion_matrix

y_target =    [1, 1, 1, 0, 0, 2, 0, 3]
y_predicted = [1, 0, 1, 0, 0, 2, 1, 3]

cm = confusion_matrix(y_target=y_target, 
                      y_predicted=y_predicted, 
                      binary=True, 
                      positive_label=1)
cm
array([[4, 1],
       [1, 2]])

要使用 matplotlib 可视化混淆矩阵,请参阅工具函数 mlxtend.plotting.plot_confusion_matrix

from mlxtend.plotting import plot_confusion_matrix

fig, ax = plot_confusion_matrix(conf_mat=cm)
plt.show()

png

API

confusion_matrix(y_target, y_predicted, binary=False, positive_label=1)

计算混淆矩阵/列联表。

参数

  • y_target : 类似数组,形状=[n_samples]

    真实类别标签。

  • y_predicted : 类似数组,形状=[n_samples]

    预测类别标签。

  • binary : 布尔值 (默认: False)

    将多类别问题映射到二元混淆矩阵,其中正类别为 1,所有其他类别为 0。

  • positive_label : 整型 (默认: 1)

    正类别的标签。

返回值

  • mat : 类似数组,形状=[n_classes, n_classes]

示例

有关使用示例,请参见 https://mlxtend.cn/mlxtend/user_guide/evaluate/confusion_matrix/