pytorch:怎么实现简单的GAN
作者
小编给大家分享一下pytorch:怎么实现简单的GAN,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!
代码如下
#-*-coding:utf-8-*- """ CreatedonSatOct1310:22:452018 """ importtorch fromtorchimportnn fromtorch.autogradimportVariable importtorchvision.transformsastfs fromtorch.utils.dataimportDataLoader,sampler fromtorchvision.datasetsimportMNIST importnumpyasnp importmatplotlib.pyplotasplt importmatplotlib.gridspecasgridspec plt.rcParams['figure.figsize']=(10.0,8.0)#设置画图的尺寸 plt.rcParams['image.interpolation']='nearest' plt.rcParams['image.cmap']='gray' defshow_images(images):#定义画图工具 images=np.reshape(images,[images.shape[0],-1]) sqrtn=int(np.ceil(np.sqrt(images.shape[0]))) sqrtimg=int(np.ceil(np.sqrt(images.shape[1]))) fig=plt.figure(figsize=(sqrtn,sqrtn)) gs=gridspec.GridSpec(sqrtn,sqrtn) gs.update(wspace=0.05,hspace=0.05) fori,imginenumerate(images): ax=plt.subplot(gs[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_aspect('equal') plt.imshow(img.reshape([sqrtimg,sqrtimg])) return defpreprocess_img(x): x=tfs.ToTensor()(x) return(x-0.5)/0.5 defdeprocess_img(x): return(x+1.0)/2.0 classChunkSampler(sampler.Sampler):#定义一个取样的函数 """Sampleselementssequentiallyfromsomeoffset. Arguments: num_samples:#ofdesireddatapoints start:offsetwhereweshouldstartselectingfrom """ def__init__(self,num_samples,start=0): self.num_samples=num_samples self.start=start def__iter__(self): returniter(range(self.start,self.start+self.num_samples)) def__len__(self): returnself.num_samples NUM_TRAIN=50000 NUM_VAL=5000 NOISE_DIM=96 batch_size=128 train_set=MNIST('E:/data',train=True,transform=preprocess_img) train_data=DataLoader(train_set,batch_size=batch_size,sampler=ChunkSampler(NUM_TRAIN,0)) val_set=MNIST('E:/data',train=True,transform=preprocess_img) val_data=DataLoader(val_set,batch_size=batch_size,sampler=ChunkSampler(NUM_VAL,NUM_TRAIN)) imgs=deprocess_img(train_data.__iter__().next()[0].view(batch_size,784)).numpy().squeeze()#可视化图片效果 show_images(imgs) #判别网络 defdiscriminator(): net=nn.Sequential( nn.Linear(784,256), nn.LeakyReLU(0.2), nn.Linear(256,256), nn.LeakyReLU(0.2), nn.Linear(256,1) ) returnnet #生成网络 defgenerator(noise_dim=NOISE_DIM): net=nn.Sequential( nn.Linear(noise_dim,1024), nn.ReLU(True), nn.Linear(1024,1024), nn.ReLU(True), nn.Linear(1024,784), nn.Tanh() ) returnnet #判别器的loss就是将真实数据的得分判断为1,假的数据的得分判断为0,而生成器的loss就是将假的数据判断为1 bce_loss=nn.BCEWithLogitsLoss()#交叉熵损失函数 defdiscriminator_loss(logits_real,logits_fake):#判别器的loss size=logits_real.shape[0] true_labels=Variable(torch.ones(size,1)).float() false_labels=Variable(torch.zeros(size,1)).float() loss=bce_loss(logits_real,true_labels)+bce_loss(logits_fake,false_labels) returnloss defgenerator_loss(logits_fake):#生成器的loss size=logits_fake.shape[0] true_labels=Variable(torch.ones(size,1)).float() loss=bce_loss(logits_fake,true_labels) returnloss #使用adam来进行训练,学习率是3e-4,beta1是0.5,beta2是0.999 defget_optimizer(net): optimizer=torch.optim.Adam(net.parameters(),lr=3e-4,betas=(0.5,0.999)) returnoptimizer deftrain_a_gan(D_net,G_net,D_optimizer,G_optimizer,discriminator_loss,generator_loss,show_every=250, noise_size=96,num_epochs=10): iter_count=0 forepochinrange(num_epochs): forx,_intrain_data: bs=x.shape[0] #判别网络 real_data=Variable(x).view(bs,-1)#真实数据 logits_real=D_net(real_data)#判别网络得分 sample_noise=(torch.rand(bs,noise_size)-0.5)/0.5#-1~1的均匀分布 g_fake_seed=Variable(sample_noise) fake_images=G_net(g_fake_seed)#生成的假的数据 logits_fake=D_net(fake_images)#判别网络得分 d_total_error=discriminator_loss(logits_real,logits_fake)#判别器的loss D_optimizer.zero_grad() d_total_error.backward() D_optimizer.step()#优化判别网络 #生成网络 g_fake_seed=Variable(sample_noise) fake_images=G_net(g_fake_seed)#生成的假的数据 gen_logits_fake=D_net(fake_images) g_error=generator_loss(gen_logits_fake)#生成网络的loss G_optimizer.zero_grad() g_error.backward() G_optimizer.step()#优化生成网络 if(iter_count%show_every==0): print('Iter:{},D:{:.4},G:{:.4}'.format(iter_count,d_total_error.item(),g_error.item())) imgs_numpy=deprocess_img(fake_images.data.cpu().numpy()) show_images(imgs_numpy[0:16]) plt.show() print() iter_count+=1 D=discriminator() G=generator() D_optim=get_optimizer(D) G_optim=get_optimizer(G) train_a_gan(D,G,D_optim,G_optim,discriminator_loss,generator_loss)
以上是“pytorch:怎么实现简单的GAN”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注恰卡编程网行业资讯频道!
目录
推荐阅读
-
Pytorch中model.eval()的作用是什么
Pytorch中model.eval()的作用是什么这篇文章主要介...
-
怎么使用pytorch读取数据集
-
pytorch中的view()函数怎么使用
pytorch中的view()函数怎么使用这篇文章主要介绍了pyt...
-
PyTorch中的torch.cat怎么用
PyTorch中的torch.cat怎么用这篇文章主要介绍PyTo...
-
pytorch中的hook机制是什么
pytorch中的hook机制是什么本篇内容介绍了“pytorch...
-
pytorch中的广播语义是什么
pytorch中的广播语义是什么这篇文章主要介绍“pytorch中...
-
PyTorch梯度下降反向传播实例分析
-
python中的Pytorch建模流程是什么
python中的Pytorch建模流程是什么小编给大家分享一下py...
-
pytorch如何实现多项式回归
-
pytorch怎样实现线性回归
0 条评论
本站已关闭游客评论,请登录或者注册后再评论吧~