PyTorch: How to Implement a Simple GAN

In this post I'd like to share how to implement a simple GAN (generative adversarial network) in PyTorch. Since many readers may not be very familiar with it yet, I'm sharing this walkthrough for reference, and I hope you get a lot out of it. Let's take a look.

The full code is as follows:


# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
"""

import torch
from torch import nn
from torch.autograd import Variable

import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

plt.rcParams['figure.figsize'] = (10.0, 8.0)  # default figure size
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

def show_images(images):  # helper that plots a square grid of images
    images = np.reshape(images, [images.shape[0], -1])  # flatten each image to a vector
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return

def preprocess_img(x):
    x = tfs.ToTensor()(x)
    return (x - 0.5) / 0.5  # rescale from [0, 1] to [-1, 1]

def deprocess_img(x):
    return (x + 1.0) / 2.0  # rescale from [-1, 1] back to [0, 1]

class ChunkSampler(sampler.Sampler):  # sampler that reads a contiguous chunk of the dataset
    """Samples elements sequentially from some offset.
    Arguments:
        num_samples: # of desired datapoints
        start: offset where we should start selecting from
    """
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples

NUM_TRAIN = 50000
NUM_VAL = 5000

NOISE_DIM = 96
batch_size = 128

train_set = MNIST('E:/data', train=True, transform=preprocess_img)

train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))

val_set = MNIST('E:/data', train=True, transform=preprocess_img)

val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))

imgs = deprocess_img(next(iter(train_data))[0].view(batch_size, 784)).numpy().squeeze()  # visualize a batch of real images
show_images(imgs)

# Discriminator network
def discriminator():
    net = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1)
    )
    return net
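# (Comment added for clarity; not in the original post.) The discriminator takes a
# flattened 28x28 image (784 values) and returns one raw score (a logit) per image.
# There is deliberately no final nn.Sigmoid: the loss used below, nn.BCEWithLogitsLoss,
# applies the sigmoid internally, which is numerically more stable than Sigmoid + BCELoss.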

# Generator network
def generator(noise_dim=NOISE_DIM):
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784),
        nn.Tanh()
    )
    return net
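# (Comment added for clarity; not in the original post.) The generator maps a
# 96-dimensional noise vector to a flattened 28x28 image. The final Tanh keeps every
# pixel in [-1, 1], the same range that preprocess_img() produces for real MNIST
# images, so real and generated batches are on the same scale for the discriminator.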

# The discriminator's loss pushes its scores for real data toward 1 and its scores for fake
# data toward 0, while the generator's loss pushes the discriminator's scores for fake data toward 1.

bce_loss = nn.BCEWithLogitsLoss()  # binary cross-entropy loss on raw logits

def discriminator_loss(logits_real, logits_fake):  # discriminator loss
    size = logits_real.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    false_labels = Variable(torch.zeros(size, 1)).float()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss

def generator_loss(logits_fake):  # generator loss
    size = logits_fake.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    loss = bce_loss(logits_fake, true_labels)
    return loss

# Train with Adam: learning rate 3e-4, beta1 = 0.5, beta2 = 0.999
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer

def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
                noise_size=96, num_epochs=10):
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]
            # discriminator step
            real_data = Variable(x).view(bs, -1)  # real images, flattened to (bs, 784)
            logits_real = D_net(real_data)  # discriminator scores for real images

            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # uniform noise in [-1, 1]
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # generated (fake) images
            logits_fake = D_net(fake_images)  # discriminator scores for fake images

            d_total_error = discriminator_loss(logits_real, logits_fake)  # discriminator loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # update the discriminator

            # generator step
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # regenerate fake images for the generator update

            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)  # generator loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # update the generator

            if iter_count % show_every == 0:
                print('Iter: {}, D: {:.4}, G: {:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
                show_images(imgs_numpy[0:16])
                plt.show()
                print()
            iter_count += 1

D = discriminator()
G = generator()

D_optim = get_optimizer(D)
G_optim = get_optimizer(G)

train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
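Once training finishes, the trained generator can be sampled directly to produce new digits. The short sketch below is not part of the original script; it is a minimal example that reuses the NOISE_DIM constant and the deprocess_img/show_images helpers defined above, with G being the generator trained by train_a_gan.

# Minimal sampling sketch (added example, not in the original post)
with torch.no_grad():                                 # no gradients are needed for sampling
    noise = (torch.rand(16, NOISE_DIM) - 0.5) / 0.5   # same uniform [-1, 1] noise as in training
    samples = G(noise)                                # shape (16, 784), values in [-1, 1]

show_images(deprocess_img(samples).numpy())           # rescale to [0, 1] and plot a 4x4 grid
plt.show()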

That is all of "PyTorch: How to Implement a Simple GAN". Thanks for reading! I hope this walkthrough has given you a useful starting point; if you would like to learn more, feel free to follow the 恰卡编程网 industry news channel.
