前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow读取数据(一)

Tensorflow读取数据(一)

作者头像
languageX
发布2021-01-26 22:41:16
1.1K0
发布2021-01-26 22:41:16
举报
文章被收录于专栏:计算机视觉CV计算机视觉CV

数据算法是深度学习最重要的两大块。而更基础的首先是要熟练掌握一个框架来支撑算法的执行。 我个人使用最多的是tensorflow平台。就从最基础的数据输入开始记录吧。

AI算法基本流程

个人总结的AI项目基础流程(除开更复杂的工程化工作) (1)数据预处理:get每个迭代的输入和标签。图像,音频,文本对数据处理方式又各有不同;不同的需求对标签的格式也不相同。 (2)算法建模:设计网络模型,输入:训练数据;输出:预测值 (3)优化参数:通过输出和真实label设计loss,还需要设计一个优化算法,让网络参数去学习得到最优解 (4)迭代训练:不断更新数据,在大数据上优化参数 (5)保存网络参数以及设计评价指标 以上步骤还只是算法部分,而且每个模块都有很可以展开出很多内容,其他更多工程上模块就不提了~

数据模块

今天先从数据模块下手。在训练过程中,我们对需求就是要不断的从所有数据中取一个batch数据输入到模型中。如果是python,那比较简单,伪代码如下:

代码语言:javascript
复制
#随机从datas里面抽取batch_size个数据
def get_batch(batch_size,datas):
batch_datas = []
datas.shuffle()
for i in range(batch_size):
	batch_datas.append(datas[i])
return batch_datas

但是在tensorflow框架中,我们就要利用它的优势来进行数据的读取。今天先介绍通过tf.Coordinatortf.QueueRunner来利用多线程管理数据。 tf.QueueRunner()就是负责开启线程以及线程队列 tf.train.Coordinator()就是创建一个线程管理器,管理我们开启的线程

准备数据

我们先准备两类图片数据,结构如下

在这里插入图片描述
在这里插入图片描述

为了方便,我们建立数据集文件夹Images,里面两类图片数据1,2。 然后我们生成一个文件列表,代码如下:

代码语言:javascript
复制
# -*- coding: utf-8 -*-
# @Time    : 2019-09-21 22:35
# @Author  : LanguageX
import os

root_dir = os.getcwd()
fw = open("./train.txt","w")
for root, dirs, files in os.walk(root_dir):
    for file in files:
        if file.endswith("jpg") or file.endswith("png"):
            filename = os.path.join(root, file)
            class_name = filename.split("/")[-2]
            print(class_name,filename)
            fw.write(filename+" "+class_name+"\n")

目的就是生成train.txt文本列表(格式:图片路径–类别)

在这里插入图片描述
在这里插入图片描述

数据准备好了~下面就可以开始实现取数据的代码了~

代码框架比较简单,添加了比较详细的注释,就直接上代码吧:

代码语言:javascript
复制
# -*- coding: utf-8 -*-
# @Time    : 2019-09-21 22:24
# @Author  : LanguageX

import tensorflow as tf
import os

class DataReader:

    def get_data_lines(self, filename):
        with open(filename) as txt_file:
            lines = txt_file.readlines()
            return lines

    def gen_datas(self, train_files):
        paths = []
        labels = []
        for line in train_files:
            line = line.replace("\n","")
            path, label = line.split(" ")
            paths.append(path)
            labels.append(label)
        return paths, labels

    def __init__(self,root_dir,train_filepath,batch_size,img_size):
         self.dir = root_dir
         self.batch_size = batch_size
         self.img_size = img_size
         #读取生成的path-label列表
         self.train_files = self.get_data_lines(train_filepath)
         #获取对应的paths和labels
         self.paths,self.labels = self.gen_datas(self.train_files)
         self.data_nums  = len(self.train_files)



    def get_batch(self, batch_size):
        self.paths = tf.cast(self.paths, tf.string)
        self.labels = tf.cast(self.labels, tf.string)
        #slice_input_producer构建了取数据队列
        input_queue = tf.train.slice_input_producer([self.paths, self.labels], num_epochs=10, shuffle=True)

        # 从文件名称队列中读取文件放入文件队列
        image_batch, label_batch= tf.train.batch(input_queue, batch_size=batch_size, num_threads=2, capacity=64,
                                                  allow_smaller_final_batch=False)

        return image_batch, label_batch



if __name__ == '__main__':
    root_dir = "../images/"
    filename = "./images/train.txt"
    batch_size = 4
    image_size = 256
    dataset = DataReader(root_dir,filename,batch_size,image_size)

    images,labels = dataset.get_batch(batch_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        #coord线程管理器
        coord = tf.train.Coordinator()
        #tf的线程队列
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(5):
            _imgs,_labesl = sess.run([images,labels])
            print("_imgs ", _imgs)
            print("_labes ", _labesl)
        #通知线程停止
        coord.request_stop()
        coord.join(threads)
        sess.close()

运行就可以在每个迭代获取到batch_size个数据了。基本本文获取数据的基本框架,其他任务的数据读取都可以举一反三添加业务需求了~

本文参与?腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019-10-09 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客?前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与?腾讯云自媒体分享计划? ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • AI算法基本流程
  • 数据模块
    • 准备数据
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
    http://www.vxiaotou.com