plot_learning_curves: 绘制来自训练集和测试集的学习曲线
绘制分类器学习曲线的函数。学习曲线对于分析模型是否存在过拟合或欠拟合(高方差或高偏差)非常有用。可以通过以下方式导入该函数:
from mlxtend.plotting import plot_learning_curves
此函数使用基于训练集和测试集(或验证集)的传统留出法。测试集保持不变,同时训练集大小逐渐增加。模型在训练集(大小可变)上拟合,并在同一测试集上进行评估。
学习曲线可以按如下方式用于诊断过拟合
- 如果训练集和测试集的性能之间存在较大差距,则模型可能存在过拟合。
- 如果训练误差和测试误差都非常大,则模型可能对数据存在欠拟合。
学习曲线也可用于判断收集更多数据是否有用。更多内容请参见下面的示例 1。
参考资料
-
示例 1
以下代码展示了如何为 MNIST 数据集的 5000 个样本子集构建学习曲线。其中 4000 个样本用于训练,1000 个样本保留用于测试。
from mlxtend.plotting import plot_learning_curves
import matplotlib.pyplot as plt
from mlxtend.data import mnist_data
from mlxtend.preprocessing import shuffle_arrays_unison
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
# Loading some example data
X, y = mnist_data()
X, y = shuffle_arrays_unison(arrays=[X, y], random_seed=123)
X_train, X_test = X[:4000], X[4000:]
y_train, y_test = y[:4000], y[4000:]
clf = KNeighborsClassifier(n_neighbors=7)
plot_learning_curves(X_train, y_train, X_test, y_test, clf)
plt.show()
从上图可以看出,KNN 模型可以从更多训练数据中受益。也就是说,曲线的斜率表明,如果我们有更大的训练集,测试集误差可能会进一步降低。
此外,根据训练集和测试集性能之间的差距,模型略有过拟合。这可以通过增加 KNN 中的邻居数量(n_neighbors
)来解决。
虽然这与分析分类器性能无关,但大约 20% 训练集大小的区域显示模型存在欠拟合(训练误差和测试误差都很大),这可能是由于数据集规模太小造成的。
API
plot_learning_curves(X_train, y_train, X_test, y_test, clf, train_marker='o', test_marker='^', scoring='misclassification error', suppress_plot=False, print_model=True, title_fontsize=12, style='default', legend_loc='best')
绘制分类器的学习曲线。
参数
-
X_train
: 类数组对象,形状 = [n_samples, n_features]训练数据集的特征矩阵。
-
y_train
: 类数组对象,形状 = [n_samples]训练数据集的真实类别标签。
-
X_test
: 类数组对象,形状 = [n_samples, n_features]测试数据集的特征矩阵。
-
y_test
: 类数组对象,形状 = [n_samples]测试数据集的真实类别标签。
-
clf
: 分类器对象。必须有 .predict 和 .fit 方法。 -
train_marker
: str (默认值: 'o')训练集线图的标记。
-
test_marker
: str (默认值: '^')测试集线图的标记。
-
scoring
: str (默认值: 'misclassification error')如果不是 'misclassification error',接受以下指标(来自 scikit-learn):{'accuracy', 'average_precision', 'f1_micro', 'f1_macro', 'f1_weighted', 'f1_samples', 'log_loss', 'precision', 'recall', 'roc_auc', 'adjusted_rand_score', 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error', 'r2'}
-
suppress_plot=False
: bool (默认值: False)如果为 True,则抑制 matplotlib 绘图。推荐用于测试目的。
-
print_model
: bool (默认值: True)如果为 True,则在图标题中打印模型参数。
-
style
: str (默认值: 'default')Matplotlib 样式。更多样式请参见 https://matplotlib.net.cn/stable/gallery/style_sheets/style_sheets_reference.html
-
legend_loc
: str (默认值: 'best')图例放置位置:{'best', 'upper left', 'upper right', 'lower left', 'lower right'}
返回值
errors
: (training_error, test_error): 列表元组
示例
有关使用示例,请参见 https://mlxtend.cn/mlxtend/user_guide/plotting/plot_learning_curves/