heatmap: 在 matplotlib 中创建热力图

使用 matplotlib 创建热力图的实用函数

from mlxtend.plotting import heatmap

概览

一个简单的函数,默认使用 matplotlib 和 Viridis 调色板从 NumPy 数组创建美观的热力图。

参考资料

  • -

示例 1 -- 简单热力图

from mlxtend.plotting import heatmap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(123)
some_array = np.random.random((20, 30))
heatmap(some_array, figsize=(20, 10))
plt.show()

png

通过设置 cell_values=False 可以隐藏单元格值

heatmap(some_array, figsize=(20, 10), cell_values=False)
plt.show()

png

示例 2 -- 将相关矩阵作为热力图

from mlxtend.plotting import heatmap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.read_csv('https://raw.githubusercontent.com/rasbt/'
                 'python-machine-learning-book-2nd-edition'
                 '/master/code/ch10/housing.data.txt',
                 header=None,
                 sep='\s+')

df.columns = ['CRIM', 'ZN', 'INDUS', 'CHAS', 
              'NOX', 'RM', 'AGE', 'DIS', 'RAD', 
              'TAX', 'PTRATIO', 'B', 'LSTAT', 'MEDV']
df.head()
CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT MEDV
0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296.0 15.3 396.90 4.98 24.0
1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242.0 17.8 396.90 9.14 21.6
2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242.0 17.8 392.83 4.03 34.7
3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222.0 18.7 394.63 2.94 33.4
4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222.0 18.7 396.90 5.33 36.2
from matplotlib import cm

cols = ['LSTAT', 'INDUS', 'NOX', 'RM', 'MEDV']

corrmat = np.corrcoef(df[cols].values.T)
fig, ax = heatmap(corrmat, column_names=cols, row_names=cols, cmap=cm.PiYG)

# set colorbar cutoff at -1, 1
for im in ax.get_images():
    im.set_clim(-1, 1)

plt.show()

png

API

heatmap(matrix, hide_spines=False, hide_ticks=False, figsize=None, cmap=None, colorbar=True, row_names=None, column_names=None, column_name_rotation=45, cell_values=True, cell_fmt='.2f', cell_font_size=None, text_color_threshold=None)

使用 matplotlib 绘制热力图。

参数

  • conf_mat : 类数组对象, 形状 = [行数, 列数]

    任意的二维数组。

  • hide_spines : 布尔值 (默认值: False)

    如果为 True,则隐藏坐标轴的脊线。

  • hide_ticks : 布尔值 (默认值: False)

    如果为 True,则隐藏坐标轴的刻度。

  • figsize : 元组 (默认值: (2.5, 2.5))

    图的高度和宽度

  • cmap : matplotlib 颜色映射 (默认值: None)

    如果为 None,则使用 matplotlib.pyplot.cm.viridis

  • colorbar : 布尔值 (默认值: True)

    如果为 True,则显示颜色条

  • row_names : 类数组对象, 形状 = [行数] (默认值: None)

    用作 y 轴刻度标签的行名称列表。

  • column_names : 类数组对象, 形状 = [列数] (默认值: None)

    用作 x 轴刻度标签的列名称列表。

  • column_name_rotation : 整型 (默认值: 45)

    旋转列 x 轴刻度标签的角度(度)。

  • cell_values : 布尔值 (默认值: True)

    如果为 True,则绘制单元格值。

  • cell_fmt : 字符串 (默认值: '.2f')

    单元格值的格式规范 (如果 cell_values=True)

  • cell_font_size : 整型 (默认值: None)

    单元格值的字体大小 (如果 cell_values=True)

  • text_color_threshold : 浮点型 (默认值: None)

    文本注释的黑/白文本颜色阈值。默认值 (None) 尝试使用 np.max(normed_matrix) / 2 自动推断一个好的阈值。

返回值

  • fig, ax : matplotlib.pyplot 子图对象

    子图的图和坐标轴元素。

示例

有关用法示例,请参阅 https://mlxtend.cn/mlxtend/user_guide/plotting/heatmap/