Common problems encountered with DataLoader() in PyTorch
This article walks through the problems you may run into when using DataLoader() in PyTorch. The editor finds it quite practical and shares it here for reference; hopefully you will take something away after reading it.
As shown below:
RuntimeError: stack expects each tensor to be equal size, but got [3, 60, 32] at entry 0 and [3, 54, 32] at entry 2
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.Resize((224))  ###
The cause is the argument passed to transforms.Resize(): note that (224) is just the integer 224, and with a single int Resize only scales the shorter edge while keeping the aspect ratio, so images with different aspect ratios come out with different shapes. Changing it to the following setting fixes the problem:
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.Resize((224, 224)),
Likewise, change val_dataset to use transforms.Resize((224, 224)) as well.
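With a fixed (h, w) tuple every image is resized to exactly the same shape, so the default collate function can stack the batch. Below is a minimal sketch of the corrected pipeline; the ToTensor step and the placeholder traindir path are assumptions, not taken from the original snippet:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

traindir = "data/train"  # placeholder path (assumption)

train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.Resize((224, 224)),  # (h, w) tuple: every image gets exactly the same shape
        transforms.ToTensor(),          # so default_collate can torch.stack the batch
    ]))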
Supplement: a deeper look at PyTorch's DataLoader
- the dataloader is essentially an iterable object: access it through iter(); it cannot be consumed with next() directly;
- iter(dataloader) returns an iterator, which can then be consumed with next();
- you can also iterate over it directly, as in `for inputs, labels in dataloaders`;
- normally we implement a datasets object and pass it into the dataloader, which internally uses yield to return the data of each batch (a small sketch of these access patterns follows this list);
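Here is a minimal sketch of the three access patterns described above, using a tiny TensorDataset purely for illustration (the sizes and contents are made up):

import torch
from torch.utils.data import TensorDataset, DataLoader

# tiny illustrative dataset: 10 samples of shape [3] with integer labels
dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
dataloader = DataLoader(dataset, batch_size=4)

# next(dataloader) would raise a TypeError: the DataLoader is an iterable, not an iterator
it = iter(dataloader)          # iter() returns the iterator object
inputs, labels = next(it)      # now next() works and yields one batch

# the usual form: the for-loop calls iter() implicitly
for inputs, labels in dataloader:
    print(inputs.shape, labels.shape)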
① The DataLoader is essentially an iterable (just like Python's built-in types such as list); it uses multiple worker processes to speed up the preparation of batch data, and uses yield so that only a limited amount of memory is needed
② Characteristics of a Queue
When the queue is empty, queue.get() blocks; while it is blocked, as soon as another process/thread performs queue.put(), this process/thread is notified and the get succeeds.
When the queue is full, queue.put() blocks (a small sketch of this blocking behaviour follows).
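A minimal sketch of this blocking behaviour with multiprocessing.Queue; the queue size and the number of items are arbitrary choices for demonstration:

import multiprocessing as mp

def producer(q):
    for i in range(3):
        q.put(i)            # blocks whenever the queue is full (maxsize reached)

if __name__ == "__main__":
    q = mp.Queue(maxsize=2)
    p = mp.Process(target=producer, args=(q,))
    p.start()
    for _ in range(3):
        print(q.get())      # blocks until the producer has put() something
    p.join()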
③ The DataLoader is an efficient, concise and intuitive structure for feeding input data to the network, easy to use and to extend
The input data pipeline
The sequence for getting data into a PyTorch model is:
① create a Dataset object
② create a DataLoader object
③ loop over the DataLoader object and feed img, label into the model for training
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for img, label in dataloader:
        ....
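The MyDataset above stands for whatever the user implements; a minimal illustrative sketch is shown below (the random tensors, sizes and label range are made up for demonstration):

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """Illustrative dataset: 100 random 3x224x224 'images' with integer labels."""
    def __init__(self):
        self.images = torch.randn(100, 3, 224, 224)
        self.labels = torch.randint(0, 10, (100,))

    def __getitem__(self, index):
        # the DataLoader calls this for every sampled index
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.images)

dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
for img, label in dataloader:
    ...  # feed img, label into the model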
So, as the step that hands the data directly to the model, the DataLoader matters a great deal.
First, a brief introduction: DataLoader is an important interface for reading data in PyTorch, defined in dataloader.py. Practically every model trained with PyTorch uses this interface (unless the user rewrites it themselves…). Its purpose is to wrap a custom Dataset, according to the batch size, whether to shuffle, and so on, into batch-sized Tensors used for the subsequent training.
The official description of the DataLoader is: "Data loading is composed of a dataset and a sampler, and processes the data with Python single- or multi-process iterators." For the concepts of iterator vs. iterable, please look them up yourself; the practical difference is that an iterator has both __iter__ and __next__ methods, while an iterable only has __iter__.
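To make the iterator/iterable distinction concrete, here is a small, self-contained sketch; Counter and CounterIterator are illustrative names, not PyTorch classes:

class Counter:
    """Iterable: only defines __iter__, which hands back a fresh iterator."""
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return CounterIterator(self.n)

class CounterIterator:
    """Iterator: defines both __iter__ and __next__."""
    def __init__(self, n):
        self.i, self.n = 0, n

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= self.n:
            raise StopIteration
        self.i += 1
        return self.i

print(list(Counter(3)))   # [1, 2, 3] -- DataLoader plays the role of Counter,
                          # its iterator plays the role of CounterIterator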
1. DataLoader
First, the parameters of DataLoader(object):
dataset (Dataset)
: the dataset to load the data from
batch_size (int, optional)
: how many samples each batch contains
shuffle (bool, optional)
: reshuffle the data at the beginning of every epoch
sampler (Sampler, optional)
: a custom strategy for drawing samples from the dataset; if this is specified, shuffle must be False
batch_sampler (Sampler, optional)
: like sampler, but returns the indices of a whole batch at a time; note that once this is specified, batch_size, shuffle, sampler and drop_last may no longer be specified (they are mutually exclusive)
num_workers (int, optional)
: how many worker processes handle data loading; 0 means all data is loaded in the main process (default: 0)
collate_fn (callable, optional)
: the function that merges a list of samples into a mini-batch
pin_memory (bool, optional)
: if set to True, the data loader copies tensors into CUDA pinned memory before returning them
drop_last (bool, optional)
: if set to True, the last incomplete batch is dropped; for example, with batch_size=64 and an epoch of only 100 samples, the trailing 36 samples are simply thrown away during training…
If False (the default), training continues normally and the final batch is just smaller.
timeout (numeric, optional)
: if positive, how long to wait when collecting a batch from the workers; if a batch has not arrived within that time, the loader gives up on it and raises a timeout error. The value should always be >= 0. Default: 0.
worker_init_fn (callable, optional)
: per-worker initialization function. If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
A usage sketch combining several of these parameters follows this list.
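A minimal sketch putting several of the parameters together. The dataset is an arbitrary TensorDataset, and re-seeding NumPy inside worker_init_fn is a common pattern hinted at in the docstring note below, not something the original article shows:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.arange(100))  # illustrative data

def worker_init_fn(worker_id):
    # give each worker a distinct NumPy seed, derived from its per-worker PyTorch seed
    np.random.seed(torch.initial_seed() % 2**32)

loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,                    # mutually exclusive with sampler
    num_workers=4,                   # 4 worker subprocesses; 0 = load in the main process
    pin_memory=True,                 # copy batches into CUDA pinned memory
    drop_last=True,                  # throw away the final incomplete batch
    timeout=30,                      # give up if a batch takes more than 30 s to arrive
    worker_init_fn=worker_init_fn,
)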
- First, when the dataloader is initialized it obtains the sampling index list over the datasets
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  # shuffles the index list
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))
        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
Here RandomSampler and BatchSampler already provide the index lists used to draw each batch of data; the yield-batch mechanism is already in place right here, inside BatchSampler.__iter__! A short usage sketch follows the two classes below.
class RandomSampler(Sampler):
    r"""Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).tolist())

    def __len__(self):
        return len(self.data_source)
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
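A small sketch of what these samplers actually yield; the range(10) stand-in dataset and batch size are arbitrary choices for demonstration:

from torch.utils.data import RandomSampler, SequentialSampler, BatchSampler

data = range(10)  # stand-in for a dataset; the samplers only need its length

print(list(BatchSampler(SequentialSampler(data), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

print(list(BatchSampler(RandomSampler(data), batch_size=3, drop_last=True)))
# e.g. [[7, 2, 9], [0, 5, 3], [8, 1, 6]] -- shuffled indices, last partial batch dropped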
- _DataLoaderIter(self) takes a dataloader object as its input; the num_workers=0 case is easy to follow, while num_workers != 0 brings in a multi-process mechanism (worker processes) that speeds up data loading;
- without workers, `batch = self.collate_fn([self.dataset[i] for i in indices])` turns the indices into the actual data and returns (image, label); `self.dataset[i]` calls the datasets object's __getitem__() method (a small collate_fn sketch follows the source code below);
- with multiple workers, an index queue (index_queues) is created for every worker process, and all workers share a single worker_result_queue data queue! The data itself is loaded inside the _worker_loop method;
class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queues[i],
                          self.worker_result_queue, self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event,
                          self.pin_memory, maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    while True:
        try:
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples
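To make the collate step concrete, here is a small sketch of what a collate_fn does with the list of (image, label) samples that `[self.dataset[i] for i in batch_indices]` produces; my_collate_fn is an illustrative name, and the sample shapes are made up. This stacking step is also why the Resize error at the top of the article occurs: torch.stack requires every image in the batch to have the same shape.

import torch

def my_collate_fn(samples):
    """Roughly what the default collate does for (image, label) pairs:
    stack the images and labels of one batch into single tensors."""
    images = torch.stack([img for img, _ in samples], dim=0)
    labels = torch.tensor([label for _, label in samples])
    return images, labels

samples = [(torch.randn(3, 224, 224), i) for i in range(4)]  # illustrative samples
images, labels = my_collate_fn(samples)
print(images.shape, labels)   # torch.Size([4, 3, 224, 224]) tensor([0, 1, 2, 3])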
- The queues are used to buffer the data, which is what speeds the loading up!
This concludes the article on "Common problems encountered with DataLoader() in PyTorch". Hopefully the content above is of some help and lets you learn a bit more; if you found the article useful, please share it so that more people can see it.