pytorch如何实现多个Dataloader同时训练

小编给大家分享一下pytorch如何实现多个Dataloader同时训练,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!

看代码吧~

pytorch如何实现多个Dataloader同时训练

如果两个dataloader的长度不一样,那就加个:

fromitertoolsimportcycle

仅使用zip,迭代器将在长度等于最小数据集的长度时耗尽。 但是,使用cycle时,我们将再次重复最小的数据集,除非迭代器查看最大数据集中的所有样本。

pytorch如何实现多个Dataloader同时训练

补充:pytorch技巧:自定义数据集 torch.utils.data.DataLoader 及Dataset的使用

本博客中有可直接运行的例子,便于直观的理解,在torch环境中运行即可。

1. 数据传递机制

在 pytorch 中数据传递按一下顺序:

1、创建 datasets ,也就是所需要读取的数据集。

2、把 datasets 传入DataLoader。

3、DataLoader迭代产生训练数据提供给模型。

2. torch.utils.data.Dataset

Pytorch提供两种数据集:

Map式数据集 Iterable式数据集。其中Map式数据集继承torch.utils.data.Dataset,Iterable式数据集继承torch.utils.data.IterableDataset。

本文只介绍 Map式数据集。

一个Map式的数据集必须要重写 __getitem__(self, index)、 __len__(self) 两个方法,用来表示从索引到样本的映射(Map)。 __getitem__(self, index)按索引映射到对应的数据, __len__(self)则会返回这个数据集的长度。

基本格式如下:

importtorch.utils.dataasdata
classVOCDetection(data.Dataset):
'''
必须继承data.Dataset类
'''
def__init__(self):
'''
在这里进行初始化,一般是初始化文件路径或文件列表
'''
pass
def__getitem__(self,index):
'''
1.按照index,读取文件中对应的数据(读取一个数据!!!!我们常读取的数据是图片,一般我们送入模型的数据成批的,但在这里只是读取一张图片,成批后面会说到)
2.对读取到的数据进行数据增强(数据增强是深度学习中经常用到的,可以提高模型的泛化能力)
3.返回数据对(一般我们要返回图片,对应的标签)在这里因为我没有写完整的代码,返回值用0代替
'''
return0
def__len__(self):
'''
返回数据集的长度
'''
return0

可直接运行的例子:

importtorch.utils.dataasdata
importnumpyasnp
x=np.array(range(80)).reshape(8,10)#模拟输入,8个样本,每个样本长度为10
y=np.array(range(8))#模拟对应样本的标签,8个标签
classMydataset(data.Dataset):
def__init__(self,x,y):
self.x=x
self.y=y
self.idx=list()
foriteminx:
self.idx.append(item)
pass
def__getitem__(self,index):
input_data=self.idx[index]#可继续进行数据增强,这里没有进行数据增强操作
target=self.y[index]
returninput_data,target
def__len__(self):
returnlen(self.idx)
datasets=Mydataset(x,y)#初始化
print(datasets.__len__())#调用__len__()返回数据的长度
foriinrange(len(y)):
input_data,target=datasets.__getitem__(i)#调用__getitem__(index)返回读取的数据对
print('input_data%d='%i,input_data)
print('target%d='%i,target)

结果如下:

pytorch如何实现多个Dataloader同时训练

3. torch.utils.data.DataLoader

PyTorch中数据读取的一个重要接口是 torch.utils.data.DataLoader。

该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch_size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。

torch.utils.data.DataLoader(onject)的可用参数如下:

1.dataset(Dataset): 数据读取接口,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。

2.batch_size (int, optional): 批训练数据量的大小,根据具体情况设置即可。一般为2的N次方(默认:1)

3.shuffle (bool, optional):是否打乱数据,一般在训练数据中会采用。(默认:False)

4.sampler (Sampler, optional):从数据集中提取样本的策略。如果指定,“shuffle”必须为false。我没有用过,不太了解。

5.batch_sampler (Sampler, optional):和batch_size、shuffle等参数互斥,一般用默认。

6.num_workers:这个参数必须大于等于0,为0时默认使用主线程读取数据,其他大于0的数表示通过多个进程来读取数据,可以加快数据读取速度,一般设置为2的N次方,且小于batch_size(默认:0)

7.collate_fn (callable, optional): 合并样本清单以形成小批量。用来处理不同情况下的输入dataset的封装。

8.pin_memory (bool, optional):如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存中.

9.drop_last (bool, optional): 如果数据集大小不能被批大小整除,则设置为“true”以除去最后一个未完成的批。如果“false”那么最后一批将更小。(默认:false)

10.timeout(numeric, optional):设置数据读取时间限制,超过这个时间还没读取到数据的话就会报错。(默认:0)

11.worker_init_fn (callable, optional): 每个worker初始化函数(默认:None)

可直接运行的例子:

importtorch.utils.dataasdata
importnumpyasnp
x=np.array(range(80)).reshape(8,10)#模拟输入,8个样本,每个样本长度为10
y=np.array(range(8))#模拟对应样本的标签,8个标签
classMydataset(data.Dataset):
def__init__(self,x,y):
self.x=x
self.y=y
self.idx=list()
foriteminx:
self.idx.append(item)
pass
def__getitem__(self,index):
input_data=self.idx[index]
target=self.y[index]
returninput_data,target
def__len__(self):
returnlen(self.idx)
if__name__==('__main__'):
datasets=Mydataset(x,y)#初始化
dataloader=data.DataLoader(datasets,batch_size=4,num_workers=2)
fori,(input_data,target)inenumerate(dataloader):
print('input_data%d'%i,input_data)
print('target%d'%i,target)

结果如下:(注意看类别,DataLoader把数据封装为Tensor)

pytorch如何实现多个Dataloader同时训练

以上是“pytorch如何实现多个Dataloader同时训练”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注恰卡编程网行业资讯频道!

发布于 2021-05-30 14:10:06
收藏
分享
海报
0 条评论
172
上一篇:ThingJS粒子特效如何实现雨雪效果 下一篇:android控件Banner如何实现简单轮播图效果
目录

    0 条评论

    本站已关闭游客评论,请登录或者注册后再评论吧~

    忘记密码?

    图形验证码