前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow | MNIST手写字识别

Tensorflow | MNIST手写字识别

作者头像
努力在北京混出人样
发布2019-02-18 15:52:57
1.5K0
发布2019-02-18 15:52:57
举报
文章被收录于专栏:祥子的故事

这次对最近学习tensorflow的总结,以理解MNIST手写字识别案例为例来说明

原始的网址:https://www.tensorflow.org/versions/r0.12/tutorials/mnist/beginners/index.html#mnist-for-ml-beginners

0、数据解释

数据为图片,每个图片是28像素*28像素,带有标签,类似于X和Y,X为28像素*28像素的数据,Y为该图片的真实数字,即标签。

1、数据的处理 以一个图片为例

先转为上图右方的矩阵,然后将矩阵摊平为1*784的向量,28*28 = 784,这里的X有784个特征。

标签数据为0-9这10个数,为了方便处理,也将数据向量化,例如,3处理为[0,0,0,1,0,0,0,0,0,0][0,0,0,1,0,0,0,0,0,0],5处理为[0,0,0,0,1,0,0,0,0,0][0,0,0,0,1,0,0,0,0,0]

2、数据的读入

用代码来下载数据并读取

代码语言:javascript
复制
#加载tensorflow包 
import tensorflow as tf 
#加载读取函数
from tensorflow.examples.tutorials.mnist import input_data 
#读数据,one_hot表示将矩阵处理为行向量,即28*28 => 1*784
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes. Extracting MNIST_data/train-images-idx3-ubyte.gz Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes. Extracting MNIST_data/train-labels-idx1-ubyte.gz Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes. Extracting MNIST_data/t10k-images-idx3-ubyte.gz Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes. Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

表示下载成功了。

训练数据集有55,000 条,即X为55,000 * 784的矩阵,那么Y为55,000 * 10的矩阵

3、模型 这里采用softmax 回归的方法,下面介绍softmax模型:

我们知道MNIST的每一张图片都表示一个数字,从0到9。我们希望得到给定图片代表每个数字的概率。比如说,我们的模型可能推测一张包含9的图片代表数字9的概率是80%但是判断它是8的概率是5%(因为8和9都有上半部分的小圆),然后给予它代表其他数字的概率更小的值。

这是一个使用softmax回归(softmax regression)模型的经典案例。softmax模型可以用来给不同的对象分配概率。即使在之后,我们训练更加精细的模型时,最后一步也需要用softmax来分配概率。

softmax回归(softmax regression)分两步:第一步

为了得到一张给定图片属于某个特定数字类的证据(evidence),我们对图片像素值进行加权求和。如果这个像素具有很强的证据说明这张图片不属于该类,那么相应的权值为负数,相反如果这个像素拥有有利的证据支持这张图片属于这个类,那么权值是正数。

我们也需要加入一个额外的偏置量(bias),因为输入往往会带有一些无关的干扰量。因此对于给定的输入图片 x 它代表的是数字 i 的证据可以表示为

evidencei=∑jWi,jxj+bi

evidence_{i} = \sum_{j} W_{i,j} x_{j} + b_{i}

其中 WiW_{i} 代表权重, bib_{i}代表数字 ii类的偏置量,jj 代表给定图片 x 的像素索引用于像素求和。然后用softmax函数可以把这些证据转换成概率 y:

y=softmax(evidence)

y = softmax(evidence)

这里的softmax可以看成是一个激励(activation)函数或者链接(link)函数,把我们定义的线性函数的输出转换成我们想要的格式,也就是关于10个数字类的概率分布。因此,给定一张图片,它对于每一个数字的吻合度可以被softmax函数转换成为一个概率值。softmax函数可以定义为:

softmax(x)=normalize(exp(x))

softmax(x)=normalize(exp(x))

即,

softmax(x)=exp(xi)∑jexp(xj)

softmax(x) = \frac{exp(x_{i})}{\sum_{j} exp(x_{j})}

这样得到的结果便是概率,从而获取了是0-9这10个数的概率,然后比较概率的大小,概率最大的即为模型得到的结果类别。

图示softmax模型:

转为矩阵的形式:

进而,有:

简化为:

y=softmax(Wx+b)

y = softmax(Wx + b)

4、代码实现

  • 定义变量

这里需要预定义

代码语言:javascript
复制
#定义X,浮点型,784列,None表示存在但为空值
#placeholder是占位符,
x = tf.placeholder("float", [None, 784])

预定义参数W和b

代码语言:javascript
复制
#定义W,W为矩阵,784*10的矩阵
#Variable 表示可修改的张量
W = tf.Variable(tf.zeros([784,10]))
代码语言:javascript
复制
#预定义b,b矩阵,1*10的矩阵
b = tf.Variable(tf.zeros([10]))

注意,WW的维度是[784,10],因为我们想要用784维的图片向量乘以它以得到一个10维的证据值向量,每一位对应不同数字类。bb的形状是[10],所以我们可以直接把它加到输出上面。

模型的实现:

代码语言:javascript
复制
y = tf.nn.softmax(tf.matmul(x,W) + b)

首先,我们用tf.matmul(??X,W)表示xx乘以WW,对应之前等式里面的WxWx,这里xx是一个2维张量拥有多个输入。然后再加上bb,把和输入到tf.nn.softmax函数里面。

5、训练模型

为了训练我们的模型,我们首先需要定义一个指标来评估这个模型是好的。其实,在机器学习,我们通常定义指标来表示一个模型是坏的,这个指标称为成本(cost)或损失(loss),然后尽量最小化这个指标。但是,这两种方式是相同的。

一个非常常见的,非常漂亮的成本函数是“交叉熵”(cross-entropy)。交叉熵产生于信息论里面的信息压缩编码技术,但是它后来演变成为从博弈论到机器学习等其他领域里的重要技术手段。它的定义如下:

Hy′(y)=?∑iy′ilog(yi)

H_{y^{'}}(y) = - \sum_{i}y^{'}_{i} log(y_{i})

yy 是我们预测的概率分布, y′y^{'}是实际的分布(我们输入的one-hot vector)。比较粗糙的理解是,交叉熵是用来衡量我们的预测用于描述真相的低效性。

为了计算交叉熵,我们首先需要添加一个新的占位符用于输入正确值,定义为y_y\_ :

代码语言:javascript
复制
#为10列的矩阵
y_ = tf.placeholder("float", [None,10])

然后用?∑y′log(y)-\sum y^{'} log(y) 来计算交叉熵:

代码语言:javascript
复制
#交叉熵
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

用梯度下降法来训练模型:

TensorFlow用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵。梯度下降算法(gradient descent algorithm)是一个简单的学习过程,TensorFlow只需将每个变量一点点地往使成本不断降低的方向移动。

代码语言:javascript
复制
#基于交叉熵值最小
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

需要添加一个操作来初始化我们创建的变量:

代码语言:javascript
复制
#初始化
init = tf.initialize_all_variables()

在一个Session里面启动我们的模型,并且初始化变量:

代码语言:javascript
复制
sess = tf.Session()
sess.run(init)

开始训练模型,这里我们让模型循环训练1000次!

代码语言:javascript
复制
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

该循环的每个步骤中,我们都会随机抓取训练数据中的100个批处理数据点,然后我们用这些数据点作为参数替换之前的占位符来运行train_step

6、评估模型 tf.argmax(y,1),返回的是模型对于任一输入x预测的标签值为1的索引值。tf.argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。 tf.equal : 检测我们的预测是否真实标签匹配(索引位置一样表示匹配)

代码语言:javascript
复制
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

精度:

代码语言:javascript
复制
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

最后,我们计算所学习到的模型在测试数据集上面的正确率。

代码语言:javascript
复制
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

====================== 上面便是完整的思路之一,下面给出完整的代码:

代码语言:javascript
复制
#加载包
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#读数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
#预定义X
x = tf.placeholder("float", [None, 784])
#预定义W
W = tf.Variable(tf.zeros([784,10]))
#预定义b
b = tf.Variable(tf.zeros([10]))
#模型
y = tf.nn.softmax(tf.matmul(x,W) + b)
#预定义真实值
y_ = tf.placeholder("float", [None,10])
#交叉熵损失函数
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#基于梯度的优化
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#初始化模型的所有变量
init = tf.global_variables_initializer()
#启动一个模型,并初始化
sess = tf.Session()
sess.run(init)
#开始训练模型,训练1000次,每次抽取100个批数据
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#真实与预测
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
#准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
#输出结果
print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

Extracting MNIST_data/train-images-idx3-ubyte.gz Extracting MNIST_data/train-labels-idx1-ubyte.gz Extracting MNIST_data/t10k-images-idx3-ubyte.gz Extracting MNIST_data/t10k-labels-idx1-ubyte.gz 0.9123

准确率为0.9123

本文参与?腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2017年01月15日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客?前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与?腾讯云自媒体同步曝光计划? ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
http://www.vxiaotou.com