plot_confusion_matrix: 可视化混淆矩阵

通过 matplotlib 可视化混淆矩阵的实用函数

from mlxtend.plotting import plot_confusion_matrix

概览

混淆矩阵

有关混淆矩阵的更多信息,请参阅 mlxtend.evaluate.confusion_matrix

参考资料

  • -

示例 1 - 二元

from mlxtend.plotting import plot_confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

binary1 = np.array([[4, 1],
                    [1, 2]])

fig, ax = plot_confusion_matrix(conf_mat=binary1)
plt.show()

png

binary2 = np.array([[21, 1],
                    [3, 1]])

fig, ax = plot_confusion_matrix(conf_mat=binary2, figsize=(2, 2))
plt.show()

png

示例 2 - 带有颜色条的二元绝对值和相对值

binary = np.array([[4, 1],
                   [1, 2]])

fig, ax = plot_confusion_matrix(conf_mat=binary,
                                show_absolute=True,
                                show_normed=True,
                                colorbar=True)
plt.show()

png

示例 3 - 多类别相对值

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                show_absolute=False,
                                show_normed=True)
plt.show()

png

示例 4 - 添加类别名称

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

class_names = ['class a', 'class b', 'class c', 'class d']

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                show_absolute=False,
                                show_normed=True,
                                class_names=class_names)
plt.show()

png

示例 5 - 更改颜色映射和字体颜色

可以通过 cmap 参数选择 Matplotlib 颜色映射作为备选颜色映射。颜色映射列表可以在此处找到:https://matplotlib.net.cn/stable/tutorials/colors/colormaps.html

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                cmap='summer')

plt.show()

png

如上所示,字体颜色阈值可能不适用于某些颜色映射。默认情况下,所有大于单元格最大值 0.5 倍的值都转换为白色,所有小于或等于单元格最大值 0.5 倍的值都转换为黑色。

如果您想将所有值都更改为例如白色,可以将颜色阈值设置为负数。或者,如果您想使所有字体颜色都为黑色,选择一个大于或等于 1 的阈值。

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                fontcolor_threshold=1,
                                cmap='summer')

plt.show()

png

示例 6 - 归一化颜色映射以突出非对角线元素

假设我们有一个高精度分类器的混淆矩阵如下

class_dict = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog'}

cmat = np.array([[972, 0, 1, 1, 1, 1, 3],
                 [0, 1123, 3, 1, 0, 1, 2],
                 [2, 0, 1025, 0, 0, 0, 1],
                 [0, 0, 0, 1005, 0, 2, 0],
                 [0, 1, 1, 0, 967, 0, 4],
                 [0, 0, 0, 6, 0, 881, 3],
                 [2, 3, 0, 1, 3, 4, 941]])

fig, ax = plot_confusion_matrix(
    conf_mat=cmat,
    class_names=class_dict.values(),
)

png

很难注意到模型出错的单元格。使用对数归一化颜色映射,非对角线上的这些错误变得更容易一目了然

import matplotlib

fig, ax = plot_confusion_matrix(
    conf_mat=cmat,
    class_names=class_dict.values(),
    norm_colormap=matplotlib.colors.LogNorm()  
)

png

API

plot_confusion_matrix(conf_mat, hide_spines=False, hide_ticks=False, figsize=None, cmap=None, colorbar=False, show_absolute=True, show_normed=False, class_names=None, figure=None, axis=None, fontcolor_threshold=0.5)

通过 matplotlib 绘制混淆矩阵。

参数

  • conf_mat : 类似数组,形状 = [n_classes, n_classes]

    来自 evaluate.confusion matrix 的混淆矩阵。

  • hide_spines : 布尔值 (默认: False)

    如果为 True,则隐藏坐标轴边框。

  • hide_ticks : 布尔值 (默认: False)

    如果为 True,则隐藏坐标轴刻度线

  • figsize : 元组 (默认: (2.5, 2.5))

    图形的高度和宽度

  • cmap : matplotlib 颜色映射 (默认: None)

    如果为 None,则使用 matplotlib.pyplot.cm.Blues

  • colorbar : 布尔值 (默认: False)

    如果为 True,则显示颜色条

  • show_absolute : 布尔值 (默认: True)

    如果为 True,则显示混淆矩阵的绝对值系数。show_absoluteshow_normed 至少有一个必须为 True。

  • show_normed : 布尔值 (默认: False)

    如果为 True,则显示归一化混淆矩阵系数。归一化混淆矩阵系数表示每个类别中被分配到正确标签的训练样本比例。show_absoluteshow_normed 至少有一个必须为 True。

  • class_names : 类似数组,形状 = [n_classes] (默认: None)

    类别名称列表。如果不为 None,则将刻度设置为这些值。

  • figure : None 或 Matplotlib 图形对象 (默认: None)

    如果为 None,则将创建一个新图形。

  • axis : None 或 Matplotlib 图形坐标轴对象 (默认: None)

    如果为 None,则将创建一个新坐标轴。

  • fontcolor_threshold : 浮点数 (默认: 0.5)

    设置单元格字体颜色选择黑白的阈值。默认情况下,所有大于单元格最大值 0.5 倍的值都转换为白色,所有小于或等于单元格最大值 0.5 倍的值都转换为黑色。

返回值

  • fig, ax : matplotlib.pyplot 子图对象

    子图的图形和坐标轴元素。

示例

有关使用示例,请参阅 https://mlxtend.cn/mlxtend/user_guide/plotting/plot_confusion_matrix/