pytorch中DataLoader()过程中会遇到的问题有哪些

这篇文章将为大家详细讲解有关pytorch中DataLoader()过程中会遇到的问题有哪些,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。

如下所示:

RuntimeError: stack expects each tensor to be equal size, but got [3, 60, 32] at entry 0 and [3, 54, 32] at entry 2

pytorch中DataLoader()过程中会遇到的问题有哪些

train_dataset=datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.Resize((224))###

原因是

transforms.Resize() 的参数设置问题,改为如下设置就可以了

train_dataset=datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.Resize((224,224)),

同理,val_dataset中也调整为transforms.Resize((224,224))。

补充:pytorch之dataloader深入剖析

- dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问;

- 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问;

- 也可以使用`for inputs, labels in dataloaders`进行可迭代对象的访问;

- 一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据;

① DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存

② Queue的特点

当队列里面没有数据时: queue.get() 会阻塞, 阻塞的时候,其它进程/线程如果有queue.put() 操作,本线程/进程会被通知,然后就可以 get 成功。

当数据满了: queue.put() 会阻塞

③ DataLoader是一个高效,简洁,直观的网络输入数据结构,便于使用和扩展

输入数据PipeLine

pytorch 的数据加载到模型的操作顺序是这样的:

① 创建一个 Dataset 对象

② 创建一个 DataLoader 对象

③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练

dataset=MyDataset()
dataloader=DataLoader(dataset)
num_epoches=100
forepochinrange(num_epoches):
forimg,labelindataloader:
....

所以,作为直接对数据进入模型中的关键一步, DataLoader非常重要。

首先简单介绍一下DataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…),该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

官方对DataLoader的说明是:“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。”关于iterator和iterable的区别和概念请自行查阅,在实现中的差别就是iterators有__iter__和__next__方法,而iterable只有__iter__方法。

1.DataLoader

先介绍一下DataLoader(object)的参数:

dataset(Dataset): 传入的数据集

batch_size(int, optional): 每个batch有多少个样本

shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序

sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…

如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each

workersubprocesswiththeworkerid(anintin[0,num_workers-1])as
input,afterseedingandbeforedataloading.(default:None)

- 首先dataloader初始化时得到datasets的采样list

classDataLoader(object):
r"""
Dataloader.Combinesadatasetandasampler,andprovides
single-ormulti-processiteratorsoverthedataset.
Arguments:
dataset(Dataset):datasetfromwhichtoloadthedata.
batch_size(int,optional):howmanysamplesperbatchtoload
(default:1).
shuffle(bool,optional):setto``True``tohavethedatareshuffled
ateveryepoch(default:False).
sampler(Sampler,optional):definesthestrategytodrawsamplesfrom
thedataset.Ifspecified,``shuffle``mustbeFalse.
batch_sampler(Sampler,optional):likesampler,butreturnsabatchof
indicesatatime.Mutuallyexclusivewithbatch_size,shuffle,
sampler,anddrop_last.
num_workers(int,optional):howmanysubprocessestousefordata
loading.0meansthatthedatawillbeloadedinthemainprocess.
(default:0)
collate_fn(callable,optional):mergesalistofsamplestoformamini-batch.
pin_memory(bool,optional):If``True``,thedataloaderwillcopytensors
intoCUDApinnedmemorybeforereturningthem.
drop_last(bool,optional):setto``True``todropthelastincompletebatch,
ifthedatasetsizeisnotdivisiblebythebatchsize.If``False``and
thesizeofdatasetisnotdivisiblebythebatchsize,thenthelastbatch
willbesmaller.(default:False)
timeout(numeric,optional):ifpositive,thetimeoutvalueforcollectingabatch
fromworkers.Shouldalwaysbenon-negative.(default:0)
worker_init_fn(callable,optional):IfnotNone,thiswillbecalledoneach
workersubprocesswiththeworkerid(anintin``[0,num_workers-1]``)as
input,afterseedingandbeforedataloading.(default:None)
..note::Bydefault,eachworkerwillhaveitsPyTorchseedsetto
``base_seed+worker_id``,where``base_seed``isalonggenerated
bymainprocessusingitsRNG.However,seedsforotherlibraies
maybeduplicateduponinitializingworkers(w.g.,NumPy),causing
eachworkertoreturnidenticalrandomnumbers.(See
:ref:`dataloader-workers-random-seed`sectioninFAQ.)Youmay
use``torch.initial_seed()``toaccessthePyTorchseedforeach
workerin:attr:`worker_init_fn`,anduseittosetotherseeds
beforedataloading.
..warning::If``spawn``startmethodisused,:attr:`worker_init_fn`cannotbean
unpicklableobject,e.g.,alambdafunction.
"""
__initialized=False
def__init__(self,dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,
num_workers=0,collate_fn=default_collate,pin_memory=False,drop_last=False,
timeout=0,worker_init_fn=None):
self.dataset=dataset
self.batch_size=batch_size
self.num_workers=num_workers
self.collate_fn=collate_fn
self.pin_memory=pin_memory
self.drop_last=drop_last
self.timeout=timeout
self.worker_init_fn=worker_init_fn
iftimeout<0:
raiseValueError('timeoutoptionshouldbenon-negative')
ifbatch_samplerisnotNone:
ifbatch_size>1orshuffleorsamplerisnotNoneordrop_last:
raiseValueError('batch_sampleroptionismutuallyexclusive'
'withbatch_size,shuffle,sampler,and'
'drop_last')
self.batch_size=None
self.drop_last=None
ifsamplerisnotNoneandshuffle:
raiseValueError('sampleroptionismutuallyexclusivewith'
'shuffle')
ifself.num_workers<0:
raiseValueError('num_workersoptioncannotbenegative;'
'usenum_workers=0todisablemultiprocessing.')
ifbatch_samplerisNone:
ifsamplerisNone:
ifshuffle:
sampler=RandomSampler(dataset)//将list打乱
else:
sampler=SequentialSampler(dataset)
batch_sampler=BatchSampler(sampler,batch_size,drop_last)
self.sampler=sampler
self.batch_sampler=batch_sampler
self.__initialized=True
def__setattr__(self,attr,val):
ifself.__initializedandattrin('batch_size','sampler','drop_last'):
raiseValueError('{}attributeshouldnotbesetafter{}is'
'initialized'.format(attr,self.__class__.__name__))
super(DataLoader,self).__setattr__(attr,val)
def__iter__(self):
return_DataLoaderIter(self)
def__len__(self):
returnlen(self.batch_sampler)

其中:RandomSampler,BatchSampler已经得到了采用batch数据的index索引;yield batch机制已经在!!!

classRandomSampler(Sampler):
r"""Sampleselementsrandomly,withoutreplacement.
Arguments:
data_source(Dataset):datasettosamplefrom
"""
def__init__(self,data_source):
self.data_source=data_source
def__iter__(self):
returniter(torch.randperm(len(self.data_source)).tolist())
def__len__(self):
returnlen(self.data_source)
classBatchSampler(Sampler):
r"""Wrapsanothersamplertoyieldamini-batchofindices.
Args:
sampler(Sampler):Basesampler.
batch_size(int):Sizeofmini-batch.
drop_last(bool):If``True``,thesamplerwilldropthelastbatchif
itssizewouldbelessthan``batch_size``
Example:
>>>list(BatchSampler(SequentialSampler(range(10)),batch_size=3,drop_last=False))
[[0,1,2],[3,4,5],[6,7,8],[9]]
>>>list(BatchSampler(SequentialSampler(range(10)),batch_size=3,drop_last=True))
[[0,1,2],[3,4,5],[6,7,8]]
"""
def__init__(self,sampler,batch_size,drop_last):
ifnotisinstance(sampler,Sampler):
raiseValueError("samplershouldbeaninstanceof"
"torch.utils.data.Sampler,butgotsampler={}"
.format(sampler))
ifnotisinstance(batch_size,_int_classes)orisinstance(batch_size,bool)or\
batch_size<=0:
raiseValueError("batch_sizeshouldbeapositiveintegeralvalue,"
"butgotbatch_size={}".format(batch_size))
ifnotisinstance(drop_last,bool):
raiseValueError("drop_lastshouldbeabooleanvalue,butgot"
"drop_last={}".format(drop_last))
self.sampler=sampler
self.batch_size=batch_size
self.drop_last=drop_last
def__iter__(self):
batch=[]
foridxinself.sampler:
batch.append(idx)
iflen(batch)==self.batch_size:
yieldbatch
batch=[]
iflen(batch)>0andnotself.drop_last:
yieldbatch
def__len__(self):
ifself.drop_last:
returnlen(self.sampler)//self.batch_size
else:
return(len(self.sampler)+self.batch_size-1)//self.batch_size

- 其中 _DataLoaderIter(self)输入为一个dataloader对象;如果num_workers=0很好理解,num_workers!=0引入多线程机制,加速数据加载过程;

- 没有多线程时:batch = self.collate_fn([self.dataset[i] for i in indices])进行将index转化为data数据,返回(image,label);self.dataset[i]会调用datasets对象的

__getitem__()方法

- 多线程下,会为每个线程创建一个索引队列index_queues;共享一个worker_result_queue数据队列!在_worker_loop方法中加载数据;

class_DataLoaderIter(object):
r"""IteratesonceovertheDataLoader'sdataset,asspecifiedbythesampler"""
def__init__(self,loader):
self.dataset=loader.dataset
self.collate_fn=loader.collate_fn
self.batch_sampler=loader.batch_sampler
self.num_workers=loader.num_workers
self.pin_memory=loader.pin_memoryandtorch.cuda.is_available()
self.timeout=loader.timeout
self.done_event=threading.Event()
self.sample_iter=iter(self.batch_sampler)
base_seed=torch.LongTensor(1).random_().item()
ifself.num_workers>0:
self.worker_init_fn=loader.worker_init_fn
self.index_queues=[multiprocessing.Queue()for_inrange(self.num_workers)]
self.worker_queue_idx=0
self.worker_result_queue=multiprocessing.SimpleQueue()
self.batches_outstanding=0
self.worker_pids_set=False
self.shutdown=False
self.send_idx=0
self.rcvd_idx=0
self.reorder_dict={}
self.workers=[
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset,self.index_queues[i],
self.worker_result_queue,self.collate_fn,base_seed+i,
self.worker_init_fn,i))
foriinrange(self.num_workers)]
ifself.pin_memoryorself.timeout>0:
self.data_queue=queue.Queue()
ifself.pin_memory:
maybe_device_id=torch.cuda.current_device()
else:
#donotinitializecudacontextifnotnecessary
maybe_device_id=None
self.worker_manager_thread=threading.Thread(
target=_worker_manager_loop,
args=(self.worker_result_queue,self.data_queue,self.done_event,self.pin_memory,
maybe_device_id))
self.worker_manager_thread.daemon=True
self.worker_manager_thread.start()
else:
self.data_queue=self.worker_result_queue
forwinself.workers:
w.daemon=True#ensurethattheworkerexitsonprocessexit
w.start()
_update_worker_pids(id(self),tuple(w.pidforwinself.workers))
_set_SIGCHLD_handler()
self.worker_pids_set=True
#primetheprefetchloop
for_inrange(2*self.num_workers):
self._put_indices()
def__len__(self):
returnlen(self.batch_sampler)
def_get_batch(self):
ifself.timeout>0:
try:
returnself.data_queue.get(timeout=self.timeout)
exceptqueue.Empty:
raiseRuntimeError('DataLoadertimedoutafter{}seconds'.format(self.timeout))
else:
returnself.data_queue.get()
def__next__(self):
ifself.num_workers==0:#same-processloading
indices=next(self.sample_iter)#mayraiseStopIteration
batch=self.collate_fn([self.dataset[i]foriinindices])
ifself.pin_memory:
batch=pin_memory_batch(batch)
returnbatch
#checkifthenextsamplehasalreadybeengenerated
ifself.rcvd_idxinself.reorder_dict:
batch=self.reorder_dict.pop(self.rcvd_idx)
returnself._process_next_batch(batch)
ifself.batches_outstanding==0:
self._shutdown_workers()
raiseStopIteration
whileTrue:
assert(notself.shutdownandself.batches_outstanding>0)
idx,batch=self._get_batch()
self.batches_outstanding-=1
ifidx!=self.rcvd_idx:
#storeout-of-ordersamples
self.reorder_dict[idx]=batch
continue
returnself._process_next_batch(batch)
next=__next__#Python2compatibility
def__iter__(self):
returnself
def_put_indices(self):
assertself.batches_outstanding<2*self.num_workers
indices=next(self.sample_iter,None)
ifindicesisNone:
return
self.index_queues[self.worker_queue_idx].put((self.send_idx,indices))
self.worker_queue_idx=(self.worker_queue_idx+1)%self.num_workers
self.batches_outstanding+=1
self.send_idx+=1
def_process_next_batch(self,batch):
self.rcvd_idx+=1
self._put_indices()
ifisinstance(batch,ExceptionWrapper):
raisebatch.exc_type(batch.exc_msg)
returnbatch
def_worker_loop(dataset,index_queue,data_queue,collate_fn,seed,init_fn,worker_id):
global_use_shared_memory
_use_shared_memory=True
#IntializeCsidesignalhandlersforSIGBUSandSIGSEGV.Pythonsignal
#module'shandlersareexecutedafterPythonreturnsfromClow-level
#handlers,likelywhenthesamefatalsignalhappenedagainalready.
#https://docs.python.org/3/library/signal.htmlSec.18.8.1.1
_set_worker_signal_handlers()
torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed)
ifinit_fnisnotNone:
init_fn(worker_id)
watchdog=ManagerWatchdog()
whileTrue:
try:
r=index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
exceptqueue.Empty:
ifwatchdog.is_alive():
continue
else:
break
ifrisNone:
break
idx,batch_indices=r
try:
samples=collate_fn([dataset[i]foriinbatch_indices])
exceptException:
data_queue.put((idx,ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx,samples))
delsamples

- 需要对队列操作,缓存数据,使得加载提速!

关于“pytorch中DataLoader()过程中会遇到的问题有哪些”这篇文章就分享到这里了,希望以上内容可以对大家有一定的帮助,使各位可以学到更多知识,如果觉得文章不错,请把它分享出去让更多的人看到。

发布于 2021-05-30 14:08:28
收藏
分享
海报
0 条评论
185
上一篇:Spring Native项目的示例分析 下一篇:如何利用Python识别图片中的文字
目录

    0 条评论

    本站已关闭游客评论,请登录或者注册后再评论吧~

    忘记密码?

    图形验证码