heatmap: 在 matplotlib 中创建热力图
使用 matplotlib 创建热力图的实用函数
from mlxtend.plotting import heatmap
概览
一个简单的函数,默认使用 matplotlib 和 Viridis 调色板从 NumPy 数组创建美观的热力图。
参考资料
- -
示例 1 -- 简单热力图
from mlxtend.plotting import heatmap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(123)
some_array = np.random.random((20, 30))
heatmap(some_array, figsize=(20, 10))
plt.show()
通过设置 cell_values=False
可以隐藏单元格值
heatmap(some_array, figsize=(20, 10), cell_values=False)
plt.show()
示例 2 -- 将相关矩阵作为热力图
from mlxtend.plotting import heatmap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/rasbt/'
'python-machine-learning-book-2nd-edition'
'/master/code/ch10/housing.data.txt',
header=None,
sep='\s+')
df.columns = ['CRIM', 'ZN', 'INDUS', 'CHAS',
'NOX', 'RM', 'AGE', 'DIS', 'RAD',
'TAX', 'PTRATIO', 'B', 'LSTAT', 'MEDV']
df.head()
CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | MEDV | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.00632 | 18.0 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1 | 296.0 | 15.3 | 396.90 | 4.98 | 24.0 |
1 | 0.02731 | 0.0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242.0 | 17.8 | 396.90 | 9.14 | 21.6 |
2 | 0.02729 | 0.0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242.0 | 17.8 | 392.83 | 4.03 | 34.7 |
3 | 0.03237 | 0.0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222.0 | 18.7 | 394.63 | 2.94 | 33.4 |
4 | 0.06905 | 0.0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222.0 | 18.7 | 396.90 | 5.33 | 36.2 |
from matplotlib import cm
cols = ['LSTAT', 'INDUS', 'NOX', 'RM', 'MEDV']
corrmat = np.corrcoef(df[cols].values.T)
fig, ax = heatmap(corrmat, column_names=cols, row_names=cols, cmap=cm.PiYG)
# set colorbar cutoff at -1, 1
for im in ax.get_images():
im.set_clim(-1, 1)
plt.show()
API
heatmap(matrix, hide_spines=False, hide_ticks=False, figsize=None, cmap=None, colorbar=True, row_names=None, column_names=None, column_name_rotation=45, cell_values=True, cell_fmt='.2f', cell_font_size=None, text_color_threshold=None)
使用 matplotlib 绘制热力图。
参数
-
conf_mat
: 类数组对象, 形状 = [行数, 列数]任意的二维数组。
-
hide_spines
: 布尔值 (默认值: False)如果为 True,则隐藏坐标轴的脊线。
-
hide_ticks
: 布尔值 (默认值: False)如果为 True,则隐藏坐标轴的刻度。
-
figsize
: 元组 (默认值: (2.5, 2.5))图的高度和宽度
-
cmap
: matplotlib 颜色映射 (默认值:None
)如果为
None
,则使用 matplotlib.pyplot.cm.viridis -
colorbar
: 布尔值 (默认值: True)如果为 True,则显示颜色条
-
row_names
: 类数组对象, 形状 = [行数] (默认值: None)用作 y 轴刻度标签的行名称列表。
-
column_names
: 类数组对象, 形状 = [列数] (默认值: None)用作 x 轴刻度标签的列名称列表。
-
column_name_rotation
: 整型 (默认值: 45)旋转列 x 轴刻度标签的角度(度)。
-
cell_values
: 布尔值 (默认值: True)如果为 True,则绘制单元格值。
-
cell_fmt
: 字符串 (默认值: '.2f')单元格值的格式规范 (如果
cell_values=True
) -
cell_font_size
: 整型 (默认值: None)单元格值的字体大小 (如果
cell_values=True
) -
text_color_threshold
: 浮点型 (默认值: None)文本注释的黑/白文本颜色阈值。默认值 (None) 尝试使用
np.max(normed_matrix) / 2
自动推断一个好的阈值。
返回值
-
fig, ax
: matplotlib.pyplot 子图对象子图的图和坐标轴元素。
示例
有关用法示例,请参阅 https://mlxtend.cn/mlxtend/user_guide/plotting/heatmap/