loadlocal_mnist: 用于从原始 ubyte 文件加载 MNIST 的函数
一个从字节格式将 MNIST
数据集加载到 NumPy 数组中的实用函数。
from mlxtend.data import loadlocal_mnist
概述
MNIST 数据集是由美国国家标准与技术研究院 (NIST) 的两个数据集构建而成。训练集包含来自 250 个不同的人的手写数字,其中 50% 是高中生,50% 是人口普查局的雇员。请注意,测试集包含来自遵循相同比例的不同人员的手写数字。
MNIST 数据集可在 https://yann.lecun.com/exdb/mnist/ 公开获取,包含以下四个部分: - 训练集图像: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 60,000 样本) - 训练集标签: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 60,000 标签) - 测试集图像: t10k-images-idx3-ubyte.gz (1.6 MB, 7.8 MB, 解压后 10,000 样本) - 测试集标签: t10k-labels-idx1-ubyte.gz (5 KB, 解压后 10 KB, 10,000 标签)
特征
每个特征向量 (特征矩阵中的一行) 包含 784 个像素 (强度) -- 由原始 28x28 像素图像展开。
-
样本数: 50000 张图像
-
目标变量 (离散): 均匀分布的类别标签 0-9,对应于图像中显示的手写数字。
参考文献
- 来源: https://yann.lecun.com/exdb/mnist/
- Y. LeCun and C. Cortes. Mnist handwritten digit database. AT&T Labs [Online]. Available: https://yann. lecun. com/exdb/mnist, 2010.
示例 1 部分 1 - 下载 MNIST 数据集
1) 从 Y. LeCun 的网站下载 MNIST 文件
- https://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
- https://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
- https://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
- https://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
例如,通过
curl -O https://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
2) 解压下载的 gzip 档案
例如,通过
gunzip t*-ubyte.gz
示例 1 部分 2 - 将 MNIST 加载到 NumPy 数组中
from mlxtend.data import loadlocal_mnist
import platform
if not platform.system() == 'Windows':
X, y = loadlocal_mnist(
images_path='train-images-idx3-ubyte',
labels_path='train-labels-idx1-ubyte')
else:
X, y = loadlocal_mnist(
images_path='train-images.idx3-ubyte',
labels_path='train-labels.idx1-ubyte')
print('Dimensions: %s x %s' % (X.shape[0], X.shape[1]))
print('\n1st row', X[0])
Dimensions: 60000 x 784
1st row [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 3 18 18 18 126 136 175 26 166 255
247 127 0 0 0 0 0 0 0 0 0 0 0 0 30 36 94 154
170 253 253 253 253 253 225 172 253 242 195 64 0 0 0 0 0 0
0 0 0 0 0 49 238 253 253 253 253 253 253 253 253 251 93 82
82 56 39 0 0 0 0 0 0 0 0 0 0 0 0 18 219 253
253 253 253 253 198 182 247 241 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 80 156 107 253 253 205 11 0 43 154
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 14 1 154 253 90 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 139 253 190 2 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 11 190 253 70 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 35 241
225 160 108 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 81 240 253 253 119 25 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 45 186 253 253 150 27 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 16 93 252 253 187
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 249 253 249 64 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 46 130 183 253
253 207 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 39 148 229 253 253 253 250 182 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 24 114 221 253 253 253
253 201 78 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 23 66 213 253 253 253 253 198 81 2 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 18 171 219 253 253 253 253 195
80 9 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
55 172 226 253 253 253 253 244 133 11 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 136 253 253 253 212 135 132 16
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
import numpy as np
print('Digits: 0 1 2 3 4 5 6 7 8 9')
print('labels: %s' % np.unique(y))
print('Class distribution: %s' % np.bincount(y))
Digits: 0 1 2 3 4 5 6 7 8 9
labels: [0 1 2 3 4 5 6 7 8 9]
Class distribution: [5923 6742 5958 6131 5842 5421 5918 6265 5851 5949]
存储为 CSV 文件
np.savetxt(fname='images.csv',
X=X, delimiter=',', fmt='%d')
np.savetxt(fname='labels.csv',
X=y, delimiter=',', fmt='%d')
API
loadlocal_mnist(images_path, labels_path)
从 ubyte 文件读取 MNIST。
参数
-
images_path
: str测试或训练 MNIST ubyte 文件的路径
-
labels_path
: str测试或训练 MNIST 类别标签文件的路径
返回值
-
images
: [n_samples, n_pixels] numpy.array图像的像素值。
-
labels
: [n_samples] numpy array目标类别标签
示例
有关用法示例,请参阅 https://mlxtend.cn/mlxtend/user_guide/data/loadlocal_mnist/