RandomHoldoutSplit: 将数据集分割成训练子集和验证子集用于验证

随机将数据集分割为训练子集和验证子集用于验证。

from mlxtend.evaluate import RandomHoldoutSplit

概述

RandomHoldoutSplit 类是 scikit-learn 中 KFold 类的一个替代品,它将数据集分割为训练子集和验证子集,但不进行轮换。RandomHoldoutSplit 可作为 scikit-learn 的 GridSearchCV 等函数中 cv 参数的输入。

RandomHoldoutSplit 中的“随机”一词源于分割由 random_seed 指定,而不是像 mlxtend 中的 PredefinedHoldoutSplit 类那样手动指定训练集和验证集的索引。

示例 1 -- 迭代 RandomHoldoutSplit

from mlxtend.evaluate import RandomHoldoutSplit
from mlxtend.data import iris_data

X, y = iris_data()
h_iter = RandomHoldoutSplit(valid_size=0.3, random_seed=123)

cnt = 0
for train_ind, valid_ind in h_iter.split(X, y):
    cnt += 1
    print(cnt)
1
print(train_ind[:5])
print(valid_ind[:5])
[ 60  16  88 130   6]
[ 72 125  80  86 117]

示例 2 -- 在 GridSearch 中使用 RandomHoldoutSplit

from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from mlxtend.evaluate import RandomHoldoutSplit
from mlxtend.data import iris_data

X, y = iris_data()

params = {'n_neighbors': [1, 2, 3, 4, 5]}

grid = GridSearchCV(KNeighborsClassifier(),
                    param_grid=params,
                    cv=RandomHoldoutSplit(valid_size=0.3, random_seed=123))

grid.fit(X, y)
GridSearchCV(cv=<mlxtend.evaluate.holdout.RandomHoldoutSplit object at 0x7fae707f6610>,
             estimator=KNeighborsClassifier(),
             param_grid={'n_neighbors': [1, 2, 3, 4, 5]})

API

RandomHoldoutSplit(valid_size=0.5, random_seed=None, stratify=False)

用于 sklearn 的 GridSearchCV 等函数的训练/验证集分割器。

Provides train/validation set indices to split a dataset
into train/validation sets using random indices.

参数

  • valid_size : 浮点数 (默认值: 0.5)

    指定为验证样本的示例比例。1-valid_size 将自动分配为训练集样本。

  • random_seed : 整数 (默认值: None)

    用于将数据分割为训练集和验证集分区的随机种子。

  • stratify : 布尔值 (默认值: False)

    True 或 False,是否执行分层分割。

示例

有关用法示例,请参见 https://mlxtend.cn/mlxtend/user_guide/evaluate/RandomHoldoutSplit/

方法


get_n_splits(X=None, y=None, groups=None)

返回交叉验证器中的分割迭代次数。

参数

  • X : 对象

    始终被忽略,仅为兼容性而存在。

  • y : 对象

    始终被忽略,仅为兼容性而存在。

  • groups : 对象

    始终被忽略,仅为兼容性而存在。

返回

  • n_splits : 1

    返回交叉验证器中的分割迭代次数。始终返回 1。


split(X, y, groups=None)

生成索引以将数据分割为训练集和测试集。

参数

  • X : 类似数组的对象, 形状 (样本数, 特征数)

    训练数据,其中样本数是训练样本的数量,特征数是特征的数量。

  • y : 类似数组的对象, 形状 (样本数,)

    监督学习问题的目标变量。分层基于 y 标签进行。

  • groups : 对象

    始终被忽略,仅为兼容性而存在。

产生

  • train_index : ndarray

    该分割的训练集索引。

  • valid_index : ndarray

    该分割的验证集索引。