pytorch如何实现inception_v3
这篇文章将为大家详细讲解有关pytorch如何实现inception_v3,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。
如下所示:
# Imports for the inception_v3 finetuning script (reconstructed: the scraped
# source had all whitespace stripped and did not parse).
from __future__ import print_function
from __future__ import division

import argparse
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms

print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Top level data directory. Here we assume the format of the directory conforms
# to the ImageFolder structure
数据集路径。该路径下的数据分为训练集(train)和验证集(val);train 下分为两类数据 1 和 2,val 集同理。
# Data directory; expects the ImageFolder layout described above
# (train/ and val/, each containing one sub-directory per class).
data_dir = "/home/dell/Desktop/data/切割图像"

# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "inception"

# Number of classes in the dataset (two classes: 1 and 2)
num_classes = 2

# Batch size for training (change depending on how much memory you have;
# too large a batch will run out of memory)
batch_size = 32

# Number of epochs to train for
num_epochs = 1000

# Flag for feature extracting. When False, we finetune the whole model,
# when True we only update the reshaped layer params
feature_extract = True

# Command-line arguments for output/checkpoint paths, Linux-style.
parser = argparse.ArgumentParser(description='PyTorch inception')
parser.add_argument('--outf', default='/home/dell/Desktop/dj/inception/',
                    help='folder to output images and model checkpoints')
parser.add_argument('--net', default='/home/dell/Desktop/dj/inception/inception.pth',
                    help="path to net (to continue training)")
args = parser.parse_args()
训练函数
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    """Train and validate `model`, keeping the weights with the best val accuracy.

    Writes per-epoch metrics to log.txt, per-epoch val accuracy to acc.txt,
    the running best score to best_acc.txt, and a checkpoint to ``args.outf``
    every 50 epochs.  When ``is_inception`` is True, the auxiliary classifier
    loss is added with weight 0.4 during training.

    Returns:
        (model loaded with the best weights, list of per-epoch val accuracies)

    NOTE(review): relies on module-level ``device`` and ``args``.
    """
    since = time.time()
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    print("Start Training, InceptionV3!")
    with open("acc.txt", "w") as f1:
        with open("log.txt", "w") as f2:
            for epoch in range(num_epochs):
                print('Epoch {}/{}'.format(epoch + 1, num_epochs))
                print('*' * 10)
                # Each epoch has a training and validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()   # Set model to evaluate mode
                    running_loss = 0.0
                    running_corrects = 0
                    # Iterate over data.
                    for inputs, labels in dataloaders[phase]:
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        # zero the parameter gradients
                        optimizer.zero_grad()
                        # forward; track history only in train
                        with torch.set_grad_enabled(phase == 'train'):
                            if is_inception and phase == 'train':
                                # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                                outputs, aux_outputs = model(inputs)
                                loss1 = criterion(outputs, labels)
                                loss2 = criterion(aux_outputs, labels)
                                loss = loss1 + 0.4 * loss2
                            else:
                                outputs = model(inputs)
                                loss = criterion(outputs, labels)
                            _, preds = torch.max(outputs, 1)
                            # backward + optimize only if in training phase
                            if phase == 'train':
                                loss.backward()
                                optimizer.step()
                        # statistics
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += torch.sum(preds == labels.data)
                    epoch_loss = running_loss / len(dataloaders[phase].dataset)
                    epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
                    print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
                    f2.write('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
                    f2.write('\n')
                    f2.flush()
                    if phase == 'val':
                        # Periodic checkpoint every 50 epochs.
                        if (epoch + 1) % 50 == 0:
                            torch.save(model.state_dict(), '%s/inception_%03d.pth' % (args.outf, epoch + 1))
                        f1.write("EPOCH=%03d,Accuracy=%.3f%%" % (epoch + 1, epoch_acc))
                        f1.write('\n')
                        f1.flush()
                    # Deep-copy the model whenever validation accuracy improves.
                    if phase == 'val' and epoch_acc > best_acc:
                        # `with` ensures the file is closed even on error
                        # (the original leaked the handle on exceptions).
                        with open("best_acc.txt", "w") as f3:
                            f3.write("EPOCH=%d,best_acc=%.3f%%" % (epoch + 1, epoch_acc))
                        best_acc = epoch_acc
                        best_model_wts = copy.deepcopy(model.state_dict())
                    if phase == 'val':
                        val_acc_history.append(epoch_acc)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history


def set_parameter_requires_grad(model, feature_extracting):
    """When feature extracting, freeze every parameter of `model`
    so only layers added afterwards are trained."""
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build one of the supported torchvision models with its classifier head
    reshaped to `num_classes` outputs.

    Returns:
        (model, input_size) where input_size is the expected square image size.
    """
    # These variables are set in the if-statement below; each is model specific.
    model_ft = None
    input_size = 0
    if model_name == "resnet":
        # Resnet18
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "alexnet":
        # Alexnet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "vgg":
        # VGG11_bn
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "squeezenet":
        # Squeezenet: classifier is a 1x1 conv, not a Linear layer.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224
    elif model_name == "densenet":
        # Densenet
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "inception":
        # Inception v3: be careful, expects (299,299) sized images and
        # has an auxiliary output.
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299
    else:
        print("Invalid model name, exiting...")
        exit()
    return model_ft, input_size


# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
# Print the model we just instantiated
# print(model_ft)

# Data augmentation and normalization for training; just resize + center-crop
# for validation.  Normalization constants are the ImageNet statistics.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

print("Initializing Datasets and Dataloaders...")
# Create training and validation datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                   shuffle=True, num_workers=0)
                    for x in ['train', 'val']}

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

'''
# Optionally resume from a previously trained checkpoint:
we = '/home/dell/Desktop/dj/inception_050.pth'
model_ft.load_state_dict(torch.load(we))
'''

# Send the model to GPU
model_ft = model_ft.to(device)

# Gather the parameters to be optimized: only the reshaped layers when
# feature extracting, otherwise all parameters.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)

# Observe that all parameters gathered above are being optimized
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.95)

# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# Train and evaluate
model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft,
                             num_epochs=num_epochs, is_inception=(model_name == "inception"))

'''
# Training from random initialization, for comparison.
# Initialize the non-pretrained version of the model used for this run
scratch_model, _ = initialize_model(model_name, num_classes, feature_extract=False, use_pretrained=False)
scratch_model = scratch_model.to(device)
scratch_optimizer = optim.SGD(scratch_model.parameters(), lr=0.001, momentum=0.9)
scratch_criterion = nn.CrossEntropyLoss()
_, scratch_hist = train_model(scratch_model, dataloaders_dict, scratch_criterion, scratch_optimizer,
                              num_epochs=num_epochs, is_inception=(model_name == "inception"))

# Plot the training curves of validation accuracy vs. number
# of training epochs for the transfer learning method and
# the model trained from scratch
ohist = [h.cpu().numpy() for h in hist]
shist = [h.cpu().numpy() for h in scratch_hist]
plt.title("Validation Accuracy vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Accuracy")
plt.plot(range(1, num_epochs + 1), ohist, label="Pretrained")
plt.plot(range(1, num_epochs + 1), shist, label="Scratch")
plt.ylim((0, 1.))
plt.xticks(np.arange(1, num_epochs + 1, 1.0))
plt.legend()
plt.show()
'''
pytorch的优点
1. PyTorch是相当简洁且高效快速的框架;2. 设计追求最少的封装;3. 设计符合人类思维,让用户尽可能地专注于实现自己的想法;4. 与Google的TensorFlow类似,FAIR的支持足以确保PyTorch获得持续的开发更新;5. PyTorch作者亲自维护的论坛,供用户交流和求教问题;6. 入门简单。
关于“pytorch如何实现inception_v3”这篇文章就分享到这里了,希望以上内容可以对大家有一定的帮助,使各位可以学到更多知识,如果觉得文章不错,请把它分享出去让更多的人看到。
推荐阅读
-
Pytorch中model.eval()的作用是什么
Pytorch中model.eval()的作用是什么这篇文章主要介...
-
怎么使用pytorch读取数据集
-
pytorch中的view()函数怎么使用
pytorch中的view()函数怎么使用这篇文章主要介绍了pyt...
-
PyTorch中的torch.cat怎么用
PyTorch中的torch.cat怎么用这篇文章主要介绍PyTo...
-
pytorch中的hook机制是什么
pytorch中的hook机制是什么本篇内容介绍了“pytorch...
-
pytorch中的广播语义是什么
pytorch中的广播语义是什么这篇文章主要介绍“pytorch中...
-
PyTorch梯度下降反向传播实例分析
-
python中的Pytorch建模流程是什么
python中的Pytorch建模流程是什么小编给大家分享一下py...
-
pytorch如何实现多项式回归
-
pytorch怎样实现线性回归