前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >python多线程结合DataLoader加载数据

python多线程结合DataLoader加载数据

原创
作者头像
languageX
发布2021-09-14 21:13:28
2.7K0
发布2021-09-14 21:13:28
举报
文章被收录于专栏:计算机视觉CV计算机视觉CV

在模型训练过程中,通常大家都会将注意力集中在模型加速以及提升GPU使用率,但是有时我们的耗时瓶颈也会在读取数据上,gpu处理太快,反而cpu喂数据跟不上。当然框架也会提供一些数据读取加速方案,比如tensorflow的 tf.data.TFRecordDataset,pytorch的DataLoader使用num_workers参数内部采用多线程方案等,还有些代码是将所有数据制作到一个二进制文件读入内存,然后从内存中快速读取数据,但是这种方案无法处理大数据项目。

tensorflow的record也需要先生成record文件格式然后读取,pytorch的DataLoader在设置num_workers时特别在windows中有些版本设置为非0会存在一些问题,本文介绍自己使用python的多线程来处理数据的一种方案,然后结合pytorch的Dataset和DataLoader获取数据,供大家参考。

一 创建buffer类

先建立一个buffer类,其中读写数据需要使用两个锁

代码语言:txt
复制
import threading
import random

class Buffer:
    def __init__(self, size):
        self.size = size
        self.buffer = []
        self.lock = threading.Lock()
        self.has_data = threading.Condition(self.lock)
        self.has_pos = threading.Condition(self.lock)

    def get_size(self):
        return self.size

    def get(self):
        with self.has_data:
            while len(self.buffer) == 0:
                self.has_data.wait()
            result = self.buffer[0]
            # print("get buffer size", len(self.buffer))
            del self.buffer[0]
            self.has_pos.notify_all()
        return result

    def put(self, data):
        with self.has_pos:
            while len(self.buffer) >= self.size:
                self.has_pos.wait()
            self.buffer.append(data)
            self.has_data.notify_all()

# test
def get():
    while True:
        get_data = buffer.get()
# test
def put():
    while True:
        data = random.randint(0, 9)
        buffer.put(a)

buffer类参考:/developer/article/1724559

二 创建Dataset

生成一个DataReader创建多线程写数据,以及单线程读数据。以下为多线程的关键代码

代码语言:txt
复制
class DataReader:
    def __init__(self, max_buffer_size=5000):
        self.audio_files = files_to_list(training_files)
        random.shuffle(self.audio_files)
        self.buffer = Buffer(max_buffer_size)
    # 消费数据
    def comsume(self):
        while True:
            result = self.buffer.get()
    # 生产数据 
    def produce(self):
        while True:
            global index
            index += 1
            if index >= len(self.audio_files)-1:
                index = 0
            start = time.time()
            file = self.audio_files[index]
            audio = load_wav(file)
            end = time.time()
            self.buffer.put(audio)

    def run_produce(self, thread_num=16):
        # 多线程生产
        for _ in range(thread_num):
            th = threading.Thread(target=self.produce)
            th.start()

    def get_item(self, index):
        result = self.buffer.get()
        return result
       

下面使用一个Dataset来使用DataReader获取数据

代码语言:txt
复制
class AudioDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data_reader = DataReader()
        self.data_reader.run_produce()
        
    def __getitem__(self, index):
        # 从buffer中获取一个数据
        start = time.time()
        audio = self.data_reader.get_item(index)
        # 进行数据处理
        ...
        audio = torch.from_numpy(audio).float()
        end = time.time()
        # print("get item time cost", (end - start) * 1000, audio.shape)
        return audio.unsqueeze(0)
    def __len__(self):
        return len(self.audio_files)

三 创建DataLoader

最后就可以通过DataLoader从DataSet中循环获取batch数据输入到模型进行训练了

代码语言:python
复制
dataset = AudioDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一 创建buffer类
  • 二 创建Dataset
  • 三 创建DataLoader
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
http://www.vxiaotou.com