生成对抗网络GAN及图片reshape方法

一、GAN

1.介绍

生成对抗网络(Generative Adversarial Networks, 简称GAN)是当前人工智能学界最为重要的研究热点之一。其突出的生成能力不仅可用于生成各类图像和自然语言数据,还启发和推动了各类半监督学习和无监督学习任务的发展。主要包含生成模型( Generative Model)和判别模型(Discriminative Model)。判别模型需要输入变量 ,通过某种模型来预测 。生成模型是给定某种隐含信息,来随机产生观测数据。

2.模型结构

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torchvision.datasets import ImageFolder
from tqdm import tqdm

ROOT_TRAIN = r'D:\cnn\data1\train'

train_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform)  # 加载训练集
dataloader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=4,
                                           shuffle=True,
                                           num_workers=0)


# 定义生成器,输入是长度为100的噪声(正态分布随机数)
# 输出为3*224*224的图片(tensor)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 1024),
            nn.ReLU(),
            nn.Linear(1024, 3*224*224),
            nn.Tanh(),
        )
    def forward(self, x): #x为噪声输入
        img = self.main(x)
        img = img.view(-1, 3, 224, 224)
        return img

# 定义判别器,输入为3*224*224的图片,输出为二分类概率值
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(3*224*224, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
    def forward(self, x):
        x = x.view(-1, 3*224*224)
        x = self.main(x)
        return x

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

gen = Generator().to(device)
dis = Discriminator().to(device)

# 判别器优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)
# 生成器优化器
g_optim = torch.optim.Adam(gen.parameters(), lr=0.001)

loss_fn = torch.nn.BCELoss() # 二元交叉熵损失

# 绘图函数,将每一个epoch中生成器生成的图片绘制
def gen_img_plot(model, epoch, test_input): # model为Generator,test_input代表生成器输入的随机数
    # prediction = np.squeeze(model(test_input).detach().cpu().numpy()) #squeeze为去掉通道维度
    prediction = model(test_input).permute(0, 2, 3, 1).cpu().numpy() #将通道维度放在最后
    plt.figure(figsize=(10, 10))
    for i in range(prediction.shape[0]): #prediction.shape[0]=test_input的batchsize
        plt.subplot(2, 2, i + 1)
        plt.imshow((prediction[i]+1)/2) #从-1~1 --> 0~1
        plt.axis('off')
    plt.savefig('./data/GANimage_at_{}.png'.format(epoch)) #把每一轮生成的图片保存到文件夹data中

test_input = torch.randn(4, 100, device=device) # 16个长度为100的随机数

# GAN训练
D_loss = []
G_loss = []

for epoch in range(20):
    d_epoch_loss = 0 #判别器损失
    g_epoch_loss = 0 #生成器损失
    count = len(dataloader) #len(dataloader)返回批次数
    count1 = len(train_dataset) #len(train_dataset)返回样本数
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0) #该批次包含多少张图片
        random_noise = torch.randn(size, 100, device=device) #创建生成器的输入

        d_optim.zero_grad() #判别器梯度清0
        real_output = dis(img) #将真实图像放到判别器上进行判断,得到对真实图像的预测结果
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) #得到判别器在真实图像上的损失
        d_real_loss.backward() #计算梯度

        gen_img = gen(random_noise) #得到生成图像
        fake_output = dis(gen_img.detach()) #将生成图像放到判别器上进行判断,得到对生成图像的预测结果,detach()为截断梯度
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) #得到判别器在生成图像上的损失
        d_fake_loss.backward()  # 计算梯度

        d_loss = d_real_loss + d_fake_loss #判别器的损失包含两部分
        d_optim.step() #判别器优化

        # 生成器
        g_optim.zero_grad() #生成器梯度清零
        fake_output = dis(gen_img) #将生成图像放到判别器上进行判断
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output)) #此处希望生成的图像能被判定为1
        g_loss.backward()  # 计算梯度
        g_optim.step() #生成器优化

        with torch.no_grad(): # loss累加的过程不需要计算梯度
            d_epoch_loss += d_loss.item() #将每一个批次的损失累加
            g_epoch_loss += g_loss.item() #将每一个批次的损失累加

    with torch.no_grad():  # loss累加的过程不需要计算梯度
        g_epoch_loss /= count
        d_epoch_loss /= count
        D_loss.append(d_epoch_loss) #保存每一个epoch的平均loss
        G_loss.append(g_epoch_loss) #保存每一个epoch的平均loss
        print('Epoch:', epoch)
        gen_img_plot(gen, epoch, test_input) #每个epoch会生成一张图

二、图片reshape

1.直接进行resize

# -*- coding: utf-8 -*-
from PIL import Image
import os

def image_resize(image_path, new_path):
    for img_name in os.listdir(image_path):
        img_path = image_path + "/" + img_name  # 获取该图片全称
        image = Image.open(img_path)  # 打开图片
        width = image.size[0] #宽
        high = image.size[1] #高
        mask = Image.new('RGB', (width, high))  # 新建一个正方形mask,RGB代表3*8位像素
        mask.paste(image, (0, 0))
        mask = mask.resize((224, 224))
        mask.save(new_path + '/' + img_name)


if __name__ == '__main__':
    ori_path = r"D:\cnn\All Classfication\AlexNet\data\train\Cat"  # 输入图片的文件夹路径
    new_path = 'D:\cnn\All Classfication\AlexNet\data1/train\Cat'  # resize之后的文件夹路径
    image_resize(ori_path, new_path)

2.不失真的resize

# -*- coding: utf-8 -*-
from PIL import Image
import os

def image_resize(image_path, new_path):
    for img_name in os.listdir(image_path):
        img_path = image_path + "/" + img_name
        image = Image.open(img_path)  # 打开一张图片
        temp = max(image.size)
        mask = Image.new('RGB', (temp, temp), (255, 255, 255))  # 新建一个正方形mask,RGB代表3*8位像素,255为填充白色
        mask.paste(image, (0, 0))
        mask = mask.resize((224, 224))
        mask.save(new_path + '/' + img_name)


if __name__ == '__main__':
    ori_path = r"D:\cnn\All Classfication\AlexNet\data\train\Cat"  # 输入图片的文件夹路径
    new_path = 'D:\cnn\All Classfication\AlexNet\data1/train\Cat'  # resize之后的文件夹路径
    image_resize(ori_path, new_path)

                             

猜你喜欢

转载自blog.csdn.net/m0_56247038/article/details/130252435