pytorch十五:生成对抗网络-mnist

GAN解决了非监督学习中的著名问题:给定一批样本,训练一个系统能够生成类似的新样本。

生成对抗网络主要包含以下两个子网络:

  • 生成器:随机生成一个噪声,生成一张图片
  • 判别器:判断输入的图片是真图片还是假图片

 交替训练:

  • 训练判别器时,需要利用真实图片和生成器生成的假图片,判别器希望判别真实图片尽可能为真,判别生成器生成的图片尽可能为假。(判别器希望能够尽可能地判别真假)
  • 训练生成器时,只需要利用生成器生成的图片,将生成器生成的图片放到判别器中,判别器判别其尽可能为真。(生成器希望生成的图片尽可能为真)

训练到一定阶段,判别器和生成器会达到一个平衡 。即此时生成器生成的图片足以以假乱真,足以欺骗到判别器了。

对于生成器,其网络结构类似于下面,当然具体的通道数、步长、核尺寸、填充等,可根据具体的实例进行适当修改。

上面网络的输入是一个100维的噪声,输出是一个3x64x64的图片。这里的输入可以看成是一个100x1x1的图片,通过反卷积(转置卷积)慢慢增大为4x4、8x8、16x16、32x32、64x64。这种反卷积的做法可以理解为图片的信息保存于100个向量之中,神经网络根据这100个向量描述的信息,前几步的反卷积先勾勒出轮廓、色调等基础信息,后几步反卷积慢慢完善细节。网络越深,细节越详细。

转置卷积后特征图的尺寸大小为Hout = (Hin-1) x S + K - 2P  (S为步长Stride,K为核大小Kernel,P为填充层Padding)。

步骤:

扫描二维码关注公众号,回复: 6635192 查看本文章
  • 定义模型
  • 数据加载
  • 参数配置
  • 模型训练

方法一:使用全连接层神经网络 

定义模型:命名为mnist_model.py文件,放到models文件夹下

#判别器
#将图片28*28展开成784,然后通过多层感知器,最后接sigmoid激活得到0到1之间的概率进行二分类
from torch import nn
class NetD(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        
        self.main = nn.Sequential(
            nn.Linear(784,256),  #输入特征为784,输出为256
            nn.LeakyReLU(0.2,inplace=True),

            nn.Linear(256,256), #进行一个线性映射
            nn.LeakyReLU(0.2,inplace=True),

            nn.Linear(256,1),
            nn.Sigmoid()
        )
    
    def forward(self,x):
        x = self.main(x)
        return x

#随机输入一个100维的噪声,噪声为均值为0方差为1的高斯分布
class NetG(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        
        self.main = nn.Sequential(
            nn.Linear(100,256),#用线性变换将输入映射到256
            nn.ReLU(inplace=True),
            
            nn.Linear(256,256),
            nn.ReLU(inplace=True),
            
            nn.Linear(256,784),
            nn.Tanh()   #Tanh激活使得生成数据分布在【-1,1】之间
        )
    def forward(self,x):
        x = self.main(x)
        return x

数据加载:

from torch.utils import data
from torchvision import datasets
from torchvision import transforms as T

img_transform=T.Compose([
    T.ToTensor(),
    T.Normalize(mean=([0.5]),std=([0.5]))
    ])

mnist = datasets.MNIST(root='.data/',train=True,transform=img_transform)
mnist_loader = data.DataLoader(mnist,256,shuffle=True,drop_last=True)

参数配置:命名为config.py文件,放到config文件夹下

class Config(object):
    data_path = 'data/'#数据集存放路径
    num_workers = 4    #多进程加载数据所用的进程数
    image_size = 96    #图片尺寸
    batch_size = 256   #批量大小 
    max_epoch = 1000
    lr1 = 4e-4       #生成器的学习率
    lr2 = 4e-4       #判别器的学习率
    nz = 100         #噪声维度
    ngf = 64         #生成器feature map数
    ndf = 64         #判别器feature map数
    
    save_path = 'imgs/' #生成器图片保存路径
    
    d_every = 1         #每一个batch训练一次判别器
    g_every = 5         #每5个batch训练一次生成器
    decay_every = 10    #每10个epoch保存一次模型
    
    #预训练模型路径
    netd_path = None 
    netg_path = None
    
    #测试时用的参数
    gen_img = 'result.png'
    #从128张生成的图片中保存最好的24张
    gen_num = 24
    gen_search_num = 128
    gen_mean = 0       #噪声的均值
    gen_std  = 1        #噪声的标准差

模型训练:命名为main.py文件

#导入相关包
import fire
import torch as t
from torch.autograd import Variable as V
from torch.utils import data
from torchvision import datasets
from torchvision import transforms as T
from models.DCGAN_mnist_model import NetG,NetD
from config.config import Config

#数据加载
img_transform=T.Compose([
    T.ToTensor(),
    T.Normalize(mean=([0.5]),std=([0.5]))
    ])
mnist = datasets.MNIST(root='.data/',train=True,transform=img_transform)
mnist_loader = data.DataLoader(mnist,256,shuffle=True,drop_last=True)


#模型训练
opt = Config()
def train(**kwargs):
    for k_,v_ in kwargs.items():
        setattr(opt,k_,v_)
        
    #step1:模型
    netg = NetG()
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path))        
    netd = NetD()
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path))  
    
    #step2:定义优化器和损失
    optimizer_g = t.optim.Adam(netg.parameters(),lr = opt.lr1)
    optimizer_d = t.optim.Adam(netd.parameters(),lr = opt.lr2)
    
    #BCELoss:Binary CrossEntropyLoss的缩写,是CrossEntropyLoss的一个特例,只用于二分类问题
    criterion = t.nn.BCELoss()
    
    #真图片label为1,假图片label为0,noises为生成网络的输入噪声
    true_labels = V(t.ones(opt.batch_size))
    fake_labels = V(t.zeros(opt.batch_size))
    
    for epoch in range(opt.max_epoch):
        for i,(datas,labels) in enumerate(mnist_loader):
            num_imgs = len(datas)
            real_img = V(datas.view(num_imgs,-1))
            
            #训练判别器
            #尽可能把真图片判别为1
            output = netd(real_img)
            error_d_real = criterion(output,true_labels)
                
            #尽可能把假图片判别为0
            noises = V(t.randn(num_imgs,opt.nz))
            #训练判别器时需要对生成器生成的图片用detach操作进行计算图截断,避免反向传播将梯度传到生成器中,因为训练判别器时我们不需要训练生成器,也就不需要生成器的梯度。
            fake_img = netg(noises).detach()
            fake_out = netd(fake_img)
            error_d_fake = criterion(fake_out,fake_labels)
            
            d_loss = error_d_real + error_d_fake
            #梯度清零    
            optimizer_d.zero_grad()
            #反向传播
            d_loss.backward()
            #梯度更新
            optimizer_d.step()
                
            #训练生成器                        
            noises = V(t.randn(num_imgs,opt.nz))
            fake_img = netg(noises)
            fake_output = netd(fake_img)
            #尽可能让判别器把假图片也判别为1
            error_g = criterion(fake_output,true_labels)
            optimizer_g.zero_grad()
            error_g.backward()
            optimizer_g.step()
            
        #保存模型
        if epoch % opt.decay_every==0:
            t.save(netd.state_dict(),'checkpoints/netd_%s.pth' %epoch)
            t.save(netg.state_dict(),'checkpoints/netg_%s.pth' %epoch)

#加载训练好的模型,并利用噪声随机生成图片
def generate(**kwargs):
    for k,v in kwargs.items():
        setattr(opt,k,v)
    
    netg,netd = NetG().eval(),NetD().eval()
    noises = t.randn(opt.gen_search_num,opt.nz)
    with t.no_grad():
        nosies = V(noises)
    #加载预训练模型
    netd.load_state_dict(t.load(opt.netd_path))
    netg.load_state_dict(t.load(opt.netg_path))
    
    #生成图片,并计算图片在判别器的分数
    fake_img = netg(noises)
    scores = netd(fake_img).data.squeeze()
    #挑选最好的某几张
    indexs = scores.topk(opt.gen_num)[1].squeeze() #[0]为前k个最大的数,[1]为其对应的索引
    from torchvision.utils import save_image
    fake_img = fake_img*0.5 + 0.5
    fake_img = fake_img.clamp(0,1)
    for i in indexs:
        save_image((fake_img.data[i].view(28,28)),filename='imgs/%d.png' %(i))
    return fake_img

if __name__=='__main__':
        fire.Fire()

在命令行输入python main.py train进行训练,训练到一定的epoch可以人为终止,训练完后输入python generate netd_path='checkpoints/netd_50.pth' netg_path='checkpoints/netg_50.pth'

查看迭代50个epoch时生成对抗网络生成的mnist手写字体如下:

50epoch
50epoch

改变netd_path和netg_path路径,分别查看迭代100、150、200个epoch时GAN生成的mnist:

100epoch
150epoch
200epoch

方法二:使用深度卷积神经网络

模型定义:

from torch import nn
#生成器网络
class NetG(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100,128,3,1,0,bias=False), #输出3*3
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(128,64,3,2,0,bias=False),#输出7*7
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(64,32,3,2,0,bias=False),#输出15*15
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(32,1,3,2,2,bias=False),#输出单通道的27*27
            nn.Tanh()  #通过双曲正切函数将输出映射到【-1,1】之间
        )
    def forward(self,x):
        x = self.main(x)
        return x

#判别器网络
class NetD(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        
        self.main = nn.Sequential(
            nn.Conv2d(1,32,3,2,2,bias=False),#输出为15*15
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2,inplace=True),
            
            nn.Conv2d(32,64,3,2,0,bias=False), #输出为7*7
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64,128,3,2,0,bias=False), #输出为3*3
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2,inplace=True),
            
            nn.Conv2d(128,1,3,1,0,bias=False), #输出位1*1
            nn.Sigmoid()    #映射为0到1   
        )
    def forward(self,x):
        x = self.main(x)
        return x

数据加载:

#深度卷积网络数据预处理方法
img_transform2=T.Compose([
    T.CenterCrop(27),
    T.ToTensor(),
    T.Normalize(mean=([0.5]),std=([0.5]))
    ])
mnist = datasets.MNIST(root='.data/',train=True,transform=img_transform2)
mnist_loader = data.DataLoader(mnist,256,shuffle=True,num_workers=4,drop_last=True)

模型训练:

#导入相关包
import fire
import torch as t
from torch.utils import data
from torch.autograd import Variable as V
from torchvision import datasets
from torchvision import transforms as T
from models.DCGAN_mnist_model import NetG,NetD
from config.config import Config


#数据加载
opt = Config()

img_transform=T.Compose([
    T.CenterCrop(27),
    T.ToTensor(),
    T.Normalize(mean=([0.5]),std=([0.5]))
    ])
mnist = datasets.MNIST(root='.data/',train=True,transform=img_transform)
mnist_loader = data.DataLoader(mnist,256,shuffle=True,num_workers=opt.num_workers,drop_last=True)


#模型训练
def train(**kwargs):
    for k_,v_ in kwargs.items():
        setattr(opt,k_,v_)
        
    #step1:模型
    netg = NetG()
    if opt.netg_path:
        netg.load_state_dict(t.load(None))        
    netd = NetD()
    if opt.netd_path:
        netd.load_state_dict(t.load(None))  
    
    #step2:定义优化器和损失
    optimizer_g = t.optim.Adam(netg.parameters(),lr = opt.lr1)
    optimizer_d = t.optim.Adam(netd.parameters(),lr = opt.lr2)
    
    #BCELoss:Binary CrossEntropyLoss的缩写,是CrossEntropyLoss的一个特例,只用于二分类问题
    criterion = t.nn.BCELoss()
    
    #真图片label为1,假图片label为0,noises为生成网络的输入噪声
    true_labels = V(t.ones(opt.batch_size))
    fake_labels = V(t.zeros(opt.batch_size))
    
    for epoch in range(opt.max_epoch):
        for i,(datas,labels) in enumerate(mnist_loader):
            num_imgs = len(datas)
            real_img = V(datas) #用全连接这个地方要改成V(datas.view(num_imgs,-1))
            
            #训练判别器
            #尽可能把真图片判别为1
            output = netd(real_img)
            error_d_real = criterion(output,true_labels) 
            #尽可能把假图片判别为0
            noises = V(t.randn(num_imgs,opt.nz,1,1)) #用全连接这个地方要改成V(t.randn(num_imgs,opt.nz)
            fake_img = netg(noises).detach()
            fake_out = netd(fake_img)
            error_d_fake = criterion(fake_out,fake_labels)
            
            d_loss = error_d_real + error_d_fake
            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()
                
            #训练生成器                        
            fake_img = netg(noises)
            fake_output = netd(fake_img)
            #尽可能让判别器把假图片也判别为1
            error_g = criterion(fake_output,true_labels)
            optimizer_g.zero_grad()
            error_g.backward()
            optimizer_g.step()
            
        #保存模型
        if epoch % opt.decay_every==0:
            print('epoch:{迭代次数}'.format(迭代次数=epoch))
            t.save(netd.state_dict(),'checkpoints2/netd_%s.pth' %epoch)
            t.save(netg.state_dict(),'checkpoints2/netg_%s.pth' %epoch)


#加载训练好的模型,并利用噪声随机生成图片  
def generate(**kwargs):
    for k,v in kwargs.items():
        setattr(opt,k,v)
    
    netg,netd = NetG().eval(),NetD().eval()
    noises = t.randn(opt.gen_search_num,opt.nz)
    with t.no_grad():
        nosies = V(noises)
    #加载预训练模型
    netd.load_state_dict(t.load(opt.netd_path))
    netg.load_state_dict(t.load(opt.netg_path))
    
    #生成图片,并计算图片在判别器的分数
    fake_img = netg(noises)
    scores = netd(fake_img).data.squeeze()
    #挑选最好的某几张
    indexs = scores.topk(opt.gen_num)[1].squeeze() #[0]为前k个最大的数,[1]为其对应的索引
    from torchvision.utils import save_image
    fake_img = fake_img*0.5 + 0.5
    fake_img = fake_img.clamp(0,1)
    #fake_img = fake_img.view(-1,1,28,28)
    for i in indexs:
        save_image((fake_img.data[i].view(28,28)),filename='imgs/%d.png' %(i))

if __name__=='__main__':
        fire.Fire()

迭代10epoch和20epoch生成的mnist字体如下:

10epoch
20epoch

 深度卷积神经网络迭代20epoch就能取得不错的效果而且几乎没有噪音,而全卷积神经网络迭代200epoch才取得不错的效果还存在噪音。所以深度卷积神经网络能取得更好的效果。

猜你喜欢

转载自blog.csdn.net/qq_24946843/article/details/89818062