前言
既然已经踏上了漫漫的炼丹之路,那么接下来的修行就必不可少。上一篇用一些实例来演示了 TensorFlow 是如何工作的,但那只是简单的回归而已,不能算是真正的炼丹。这一篇,我们就聊一聊几乎每个人都会用来作为 TensorFlow 实践的第一个项目的——手写数字识别。
MNIST 数据集
既然是训练一个能够识别手写数字的模型,那么手写数字的数据就必不可少。MNIST 数据集是一个几乎被每个炼丹师都用过的数据集了,每一个 TensorFlow 的教程一个都会对它下手。下面来介绍一下 MNIST 数据集。
MNIST 数据集在 http://yann.lecun.com/exdb/mnist/ 可以下载,来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集包含了 250 个不同人的手写数字,其中 50% 是高中学生,50% 来自人口普查局的工作人员。测试集的数据是同样的比例。
数据分为4个文件,分别是「训练集图像」、「训练集标签」、「测试集图像」、「测试集标签」
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
训练集(mnist.train)有 55000 行,验证集(mnist.validation)有 5000 行,测试集(mnist.test)有 10000 行。训练集的每一张图片(mnist.train.images)有 28 * 28 = 784 个像素点,即整个训练集可以转化为 [55000, 784] 的矩阵。
通过以下的简单操作就可以看到图片了:
1 | img = np.array(mnist.train.images[1]) |
上图就是训练集的前两张图片。图片长什么样子其实根本不重要啦,毕竟训练的时候不过都是些矩阵而已,最终也可以根据标签来判断准确率,没必要把图片都展示出来。
坯子
想得到一个能够识别手写数字的模型,我们先来一个没有中间层的吧!输入层 784 个节点之间连接到输出层的 10 个节点上(输出层的 10 个节点对应 10 个数字,选择数值最大的节点作为结果),这样的模型应该算得上是简单暴力了吧!当然,这样做的准确率一定是不高的,不过我们可以之后慢慢调整嘛。炼丹怎么会一次成功呢?
1 | import tensorflow as tf |
程序解释
- MNIST 数据集可以由 tensorflow.examples 下载,其中 “MNIST_data” 是放数据的目录名(没有该目录会创建目录),one_hot 指标签采用 [0,0,1,0,…] 这样的形式,结果是几,第几位的数字就是 1 ,其他 9 位都是 0 。
- batch_size 表示数据一批一批的传入,每一批 100 组数据
- softmax 函数可以将输出映射为 0~1 的值,正好对应概率,同时,该函数会隐藏比较小的数而放大比较大的数。更多关于 softmax 函数的介绍可以参考维基百科 。
程序结果
1 | Iter 0 , Training Accuracy 0.92636365 , Testing Accuracy 0.9262 |
上面是程序运行的结果,可以看到 93.15% 的测试集都识别成功了。但是这显然是不够的,我们还可以优化程序,让模型的正确率更高。
那么下一步,让我们修改模型,让预测的准确率(测试集)提升到 95% 以上吧,漫漫炼丹路就要开始咯~~
未完待续