Generative Adversarial Networks (GAN) and image resizing methods

1. GANs

1 Introduction

Generative Adversarial Networks (GANs for short) are one of the most important research hotspots in the field of artificial intelligence. Their outstanding generation ability can not only be used to generate various kinds of images and natural-language data, but has also inspired and promoted the development of various semi-supervised and unsupervised learning tasks. A GAN mainly consists of a generative model and a discriminative model. A discriminative model takes input variables and predicts an output from them (here, whether an image is real or generated). A generative model randomly generates observable data given some hidden (latent) information (here, an image from a random noise vector).

2. Model structure

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# Root folder of the training images; it is handed to ImageFolder below.
ROOT_TRAIN = r'D:\cnn\data1\train'

# Convert each image to a tensor, then normalise every channel with
# mean 0.5 and std 0.5 (the plotting code later maps -1..1 back to 0..1).
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load the training set and serve it in shuffled mini-batches of 4 images,
# loading data in the main process (num_workers=0).
train_dataset = ImageFolder(root=ROOT_TRAIN, transform=train_transform)
dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)


# Generator: the input is a noise vector (normally distributed random numbers,
# length 100 by default); the output is an image tensor (3*224*224 by default).
class Generator(nn.Module):
    """Fully connected generator mapping latent noise to an image.

    Args:
        latent_dim: Length of the input noise vector (default 100).
        img_shape: Shape ``(channels, height, width)`` of the generated
            image (default ``(3, 224, 224)``).

    The defaults reproduce the original hard-coded network, so
    ``Generator()`` behaves exactly as before.
    """

    def __init__(self, latent_dim=100, img_shape=(3, 224, 224)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)

        # Number of values in one flattened image.
        n_values = 1
        for dim in self.img_shape:
            n_values *= dim

        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1024),
            nn.ReLU(),
            nn.Linear(1024, n_values),
            nn.Tanh(),  # squashes every output value into [-1, 1]
        )

    def forward(self, x):
        """Map a batch of noise vectors ``x`` to a batch of images.

        Returns a tensor of shape ``(-1, *img_shape)``.
        """
        img = self.main(x)
        return img.view(-1, *self.img_shape)

# Discriminator: the input is an image tensor (3*224*224 by default); the
# output is a single probability per image (binary real/fake decision).
class Discriminator(nn.Module):
    """Fully connected discriminator scoring images with a sigmoid output.

    Args:
        img_shape: Shape ``(channels, height, width)`` of the input images
            (default ``(3, 224, 224)``).

    The default reproduces the original hard-coded network, so
    ``Discriminator()`` behaves exactly as before.
    """

    def __init__(self, img_shape=(3, 224, 224)):
        super().__init__()
        self.img_shape = tuple(img_shape)

        # Number of values in one flattened image.
        n_values = 1
        for dim in self.img_shape:
            n_values *= dim
        self.n_values = n_values

        self.main = nn.Sequential(
            nn.Linear(n_values, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # output in (0, 1), as required by BCELoss
        )

    def forward(self, x):
        """Return a ``(-1, 1)`` tensor of probabilities for the images ``x``."""
        # reshape (rather than view) also accepts non-contiguous inputs,
        # e.g. tensors that were permuted before being passed in.
        x = x.reshape(-1, self.n_values)
        return self.main(x)

# Run on the first CUDA device when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build both networks and move them to the selected device.
gen = Generator().to(device)
dis = Discriminator().to(device)

# One Adam optimizer per network: discriminator first, then generator.
d_optim = optim.Adam(dis.parameters(), lr=0.0001)
g_optim = optim.Adam(gen.parameters(), lr=0.001)

# Binary cross-entropy loss.
loss_fn = nn.BCELoss()

# Plotting helper: draws the images produced by the generator and saves them
# to disk; the training loop calls it once per epoch.
def gen_img_plot(model, epoch, test_input):
    """Generate images from ``test_input`` and save them as one PNG grid.

    Args:
        model: The generator; called as ``model(test_input)`` and expected to
            return a ``(batch, channels, height, width)`` tensor in [-1, 1].
        epoch: Value inserted into the output file name.
        test_input: Batch of noise vectors fed to the generator.

    The figure is written to ``./data/GANimage_at_<epoch>.png``; the
    ``./data`` folder is created when it does not exist yet.
    """
    import math
    import os

    # Run the generator without autograd so the result can be converted to
    # NumPy regardless of whether the caller is inside torch.no_grad().
    with torch.no_grad():
        # Move the channel dimension last, as imshow expects (H, W, C).
        prediction = model(test_input).detach().permute(0, 2, 3, 1).cpu().numpy()

    # Near-square grid that fits the whole batch (2x2 for a batch of 4).
    n_images = prediction.shape[0]
    cols = max(1, math.ceil(math.sqrt(n_images)))
    rows = max(1, math.ceil(n_images / cols))

    fig = plt.figure(figsize=(10, 10))
    for i, image in enumerate(prediction):
        plt.subplot(rows, cols, i + 1)
        plt.imshow((image + 1) / 2)  # rescale from -1..1 to 0..1
        plt.axis('off')

    os.makedirs('./data', exist_ok=True)
    plt.savefig('./data/GANimage_at_{}.png'.format(epoch))
    # Close the figure so that one figure per epoch does not pile up in memory.
    plt.close(fig)

# Fixed batch of 4 noise vectors of length 100, reused every epoch so the
# saved sample images are comparable across epochs.
test_input = torch.randn(4, 100, device=device)

# GAN training
D_loss = []  # average discriminator loss per epoch
G_loss = []  # average generator loss per epoch

for epoch in range(20):
    d_epoch_loss = 0  # running discriminator loss for this epoch
    g_epoch_loss = 0  # running generator loss for this epoch
    count = len(dataloader)  # number of batches per epoch
    for img, _ in dataloader:
        img = img.to(device)
        size = img.size(0)  # number of images in this batch
        random_noise = torch.randn(size, 100, device=device)  # generator input

        # ---- Discriminator step ----
        d_optim.zero_grad()
        real_output = dis(img)  # predictions on real images
        # Real images should be classified as 1.
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))
        d_real_loss.backward()

        gen_img = gen(random_noise)  # generated images
        # detach() cuts the graph so this step does not send gradients
        # into the generator.
        fake_output = dis(gen_img.detach())
        # Generated images should be classified as 0.
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss  # total discriminator loss
        d_optim.step()

        # ---- Generator step ----
        g_optim.zero_grad()
        fake_output = dis(gen_img)  # no detach: gradients must reach the generator
        # The generator wants its images to be classified as 1.
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
        g_loss.backward()
        g_optim.step()

        # .item() yields plain Python floats, so no autograd graph is involved.
        d_epoch_loss += d_loss.item()
        g_epoch_loss += g_loss.item()

    # Store the per-epoch average of the batch losses.
    g_epoch_loss /= count
    d_epoch_loss /= count
    D_loss.append(d_epoch_loss)
    G_loss.append(g_epoch_loss)
    print('Epoch:', epoch)
    with torch.no_grad():  # sample images are generated without gradients
        gen_img_plot(gen, epoch, test_input)  # saves one figure per epoch

2. Image resizing

1. Resize directly

# -*- coding: utf-8 -*-
from PIL import Image
import os

def image_resize(image_path, new_path):
    """Resize every image in ``image_path`` to 224x224 and save it to ``new_path``.

    The image is stretched directly to 224x224, so its aspect ratio is not
    preserved. Each output keeps the file name of its input.

    Args:
        image_path: Folder containing the input images.
        new_path: Folder that receives the resized images; it is created
            when it does not exist yet.
    """
    os.makedirs(new_path, exist_ok=True)
    for img_name in os.listdir(image_path):
        img_path = os.path.join(image_path, img_name)  # full path of this entry
        if not os.path.isfile(img_path):
            continue  # skip sub-folders; only files can be opened as images
        # The context manager closes the underlying file once it was copied.
        with Image.open(img_path) as image:
            width, high = image.size
            # New RGB canvas (3 x 8-bit channels) with the same size as the image.
            mask = Image.new('RGB', (width, high))
            mask.paste(image, (0, 0))
        mask = mask.resize((224, 224))
        mask.save(os.path.join(new_path, img_name))


if __name__ == '__main__':
    # Raw strings keep the Windows backslashes literal and avoid
    # invalid-escape warnings such as "\c" or "\A".
    ori_path = r"D:\cnn\All Classfication\AlexNet\data\train\Cat"  # folder with the input images
    new_path = r'D:\cnn\All Classfication\AlexNet\data1/train\Cat'  # folder for the resized images
    image_resize(ori_path, new_path)

2. Resize without distortion (pad to a square first)

# -*- coding: utf-8 -*-
from PIL import Image
import os

def image_resize(image_path, new_path):
    """Pad every image in ``image_path`` to a square, then resize it to 224x224.

    The image is pasted at the top-left corner of a white square whose side
    equals the longer image edge, so the content keeps its aspect ratio when
    the square is scaled to 224x224. Each output keeps the file name of its
    input.

    Args:
        image_path: Folder containing the input images.
        new_path: Folder that receives the resized images; it is created
            when it does not exist yet.
    """
    os.makedirs(new_path, exist_ok=True)
    for img_name in os.listdir(image_path):
        img_path = os.path.join(image_path, img_name)  # full path of this entry
        if not os.path.isfile(img_path):
            continue  # skip sub-folders; only files can be opened as images
        # The context manager closes the underlying file once it was copied.
        with Image.open(img_path) as image:
            side = max(image.size)  # longer edge of the image
            # New square RGB canvas (3 x 8-bit channels) filled with white.
            mask = Image.new('RGB', (side, side), (255, 255, 255))
            mask.paste(image, (0, 0))
        mask = mask.resize((224, 224))
        mask.save(os.path.join(new_path, img_name))


if __name__ == '__main__':
    # Raw strings keep the Windows backslashes literal and avoid
    # invalid-escape warnings such as "\c" or "\A".
    ori_path = r"D:\cnn\All Classfication\AlexNet\data\train\Cat"  # folder with the input images
    new_path = r'D:\cnn\All Classfication\AlexNet\data1/train\Cat'  # folder for the resized images
    image_resize(ori_path, new_path)

                             

 

 

Guess you like

Origin blog.csdn.net/m0_56247038/article/details/130252435