首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

Tensorflow hello world<一>

1、下载mnist数据集(三大特征)

1.1 包含训练集(55000)、验证集(5000)和测试集(10000)

1.2 每张图片都是28*28维的向量

1.3 label是个10维的热编码向量(one_hot)

数据集shape打印结果如下所示

(55000, 784)

(55000, 10)

(5000, 784)

(5000, 10)

(10000, 784)

(10000, 10)

2、建立分类模型及检测模型准确度

使用前几天讲过的Tensorflow基础建立初步模型~

fromtensorflow.examples.tutorials.mnistimportinput_data

importtensorflowastf

#下载数据集

mnist=input_data.read_data_sets("mnist_data/",one_hot=True)

#定义输入数据占位符

train_input=tf.placeholder(tf.float32,[None,784])

train_label=tf.placeholder(tf.float32,[None,10])

#定义初始化参数

W=tf.Variable(tf.random_normal([784,10],stddev=0.1))

b=tf.Variable(tf.zeros([10]))

#定义训练模型

pred_y=tf.nn.softmax(tf.matmul(train_input,W)+b)

#定义损失函数

cross_entropy=tf.reduce_mean(-tf.reduce_sum(train_label*tf.log(pred_y)))

#定义损失函数的优化方法

train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

#建立会话

sess=tf.Session()

#初始化变量

sess.run(tf.global_variables_initializer())

#训练数据集

foriinrange(1000):

batch_xs,batch_ys=mnist.train.next_batch(128)

sess.run(train_step,feed_dict={train_input:batch_xs,train_label:batch_ys})

#计算模型准确性

correct_pre=tf.equal(tf.argmax(train_label,1),tf.argmax(pred_y,1))

acc=tf.reduce_mean(tf.cast(correct_pre,tf.float32))

print(sess.run(acc,feed_dict={train_input:mnist.test.images,train_label:mnist.test.labels}))

准确率91%左右

课后作业:卷积神经网络的数学基础

动手试一试上面的小示例,每天学习一点点~

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180531G037SH00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券
http://www.vxiaotou.com