PredefinedHoldoutSplit:兼容 scikit-learn 的留出法工具

根据用户指定的索引,将数据集分割为训练集和验证集用于验证。

from mlxtend.evaluate import PredefinedHoldoutSplit

概述

PredefinedHoldoutSplit 类是 scikit-learn 的 KFold 类的一种替代方案,它根据用户指定的验证索引将数据集分割为训练子集和验证子集,且不进行轮换。PredefinedHoldoutSplit 可用作 scikit-learn 的 GridSearchCV 等中的 cv 参数。

对于执行随机分割,请参见相关的 RandomHoldoutSplit 类。

示例 1 -- 遍历 PredefinedHoldoutSplit

from mlxtend.evaluate import PredefinedHoldoutSplit
from mlxtend.data import iris_data

X, y = iris_data()
h_iter = PredefinedHoldoutSplit(valid_indices=[0, 1, 99])

cnt = 0
for train_ind, valid_ind in h_iter.split(X, y):
    cnt += 1
    print(cnt)
1
print(train_ind[:5])
print(valid_ind[:5])
[2 3 4 5 6]
[ 0  1 99]

示例 2 -- GridSearch 中的 PredefinedHoldoutSplit

from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from mlxtend.evaluate import PredefinedHoldoutSplit
from mlxtend.data import iris_data

X, y = iris_data()


params = {'n_neighbors': [1, 2, 3, 4, 5]}

grid = GridSearchCV(KNeighborsClassifier(),
                    param_grid=params,
                    cv=PredefinedHoldoutSplit(valid_indices=[0, 1, 99]))

grid.fit(X, y)
GridSearchCV(cv=<mlxtend.evaluate.holdout.PredefinedHoldoutSplit object at 0x7fb300565610>,
             estimator=KNeighborsClassifier(),
             param_grid={'n_neighbors': [1, 2, 3, 4, 5]})

API

PredefinedHoldoutSplit(valid_indices)

用于 sklearn 的 GridSearchCV 等的训练/验证集分割器。

Uses user-specified train/validation set indices to split a dataset
into train/validation sets using user-defined or random
indices.

参数

  • valid_indices : 类似数组的对象,形状为 (num_examples,)

    训练集中用于验证的训练样本的索引。训练集中所有其他索引用于模型拟合的训练子集。

示例

有关使用示例,请参见 https://mlxtend.cn/mlxtend/user_guide/evaluate/PredefinedHoldoutSplit/

方法


get_n_splits(X=None, y=None, groups=None)

返回交叉验证器中的分割迭代次数

参数

  • X : 对象

    始终被忽略,仅为兼容性而存在。

  • y : 对象

    始终被忽略,仅为兼容性而存在。

  • groups : 对象

    始终被忽略,仅为兼容性而存在。

返回值

  • n_splits : 1

    返回交叉验证器中的分割迭代次数。始终返回 1。


split(X, y, groups=None)

生成将数据分割为训练集和测试集的索引。

参数

  • X : 类似数组的对象,形状为 (num_examples, num_features)

    训练数据,其中 num_examples 是样本数,num_features 是特征数。

  • y : 类似数组的对象,形状为 (num_examples,)

    监督学习问题的目标变量。分层是基于 y 标签进行的。

  • groups : 对象

    始终被忽略,仅为兼容性而存在。

生成

  • train_index : ndarray

    该分割的训练集索引。

  • valid_index : ndarray

    该分割的验证集索引。