A simple example of implementing a GAN in PyTorch

Below is a simple example of a GAN implemented in PyTorch. It assumes you already understand the basic principles of GANs and have PyTorch installed.

First, we define the generator and the discriminator. The generator is a neural network that takes random noise as input and outputs a fake image. The discriminator is a second network that takes an image as input and outputs a score indicating how likely the image is to be real rather than generated.
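The input/output contract described above can be illustrated with a minimal sketch. The two networks here are hypothetical stand-ins (single linear layers, not the architecture used in this post), and `latent_dim = 100`, the batch size of 8, and the random "real" batch are assumptions chosen only for illustration. The sketch also shows one common way the discriminator's score is used, which this post does not spell out: binary cross-entropy with real images labeled 1 and fakes labeled 0, plus the non-saturating generator loss that rewards the generator when its fakes are scored as real.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in networks: tiny single-layer models used only to show
# the input/output shapes, NOT the convolutional architecture from this post.
latent_dim = 100           # assumed noise size
img_shape = (3, 32, 32)    # 32x32 RGB images
img_numel = 3 * 32 * 32

generator = nn.Sequential(
    nn.Linear(latent_dim, img_numel),
    nn.Tanh(),                       # pixel values in [-1, 1]
    nn.Unflatten(1, img_shape),      # (B, 3072) -> (B, 3, 32, 32)
)

discriminator = nn.Sequential(
    nn.Flatten(),                    # (B, 3, 32, 32) -> (B, 3072)
    nn.Linear(img_numel, 1),
    nn.Sigmoid(),                    # probability that the input is real
)

batch_size = 8
z = torch.randn(batch_size, latent_dim)      # random noise
fake_images = generator(z)                   # fake images
scores = discriminator(fake_images)          # one score per image

# Placeholder "real" batch in [-1, 1] (assumption: stands in for a dataset).
real_images = torch.rand(batch_size, *img_shape) * 2 - 1

bce = nn.BCELoss()
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

# Discriminator loss: push real scores toward 1 and fake scores toward 0.
# detach() stops this loss from backpropagating into the generator.
d_loss = bce(discriminator(real_images), real_labels) + \
         bce(discriminator(fake_images.detach()), fake_labels)

# Non-saturating generator loss: the generator wants its fakes scored as real.
g_loss = bce(scores, real_labels)

print(fake_images.shape, scores.shape)
# prints: torch.Size([8, 3, 32, 32]) torch.Size([8, 1])
```

Because `nn.BCELoss` expects probabilities, the discriminator ends in `nn.Sigmoid()`. An alternative is to drop the sigmoid and use `nn.BCEWithLogitsLoss`, which combines both steps and is more numerically stable.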

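```python
import torch
import torch.nn as nn
import torch.optim as optim

class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.model = nn.Sequential(
            # Input is random noise; output is (batch_size, 128, 4, 4)
            nn.Linear(latent_dim, 128 * 4 * 4),
            nn.BatchNorm1d(128 * 4 * 4),
            nn.LeakyReLU(0.01),
            # PyTorch has no nn.Reshape; nn.Unflatten turns (B, 2048) into (B, 128, 4, 4)
            nn.Unflatten(1, (128, 4, 4)),
            # Output is (batch_size, 128, 8, 8)
            nn.ConvTranspose2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.01),
            # Output is (batch_size, 128, 16, 16)
            nn.ConvTranspose2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.01),
            # Output is (batch_size, 3, 32, 32), pixel values in [-1, 1]
            nn.ConvTranspose2d(128, 3, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()
        self.model = nn.Sequential(
            # Input is (batch_size, 3, 32, 32); output is (batch_size, 128, 16, 16)
            nn.Conv2d(3, 128, 4, 2, 1),
            # The original post is cut off here; the layers below are an
            # assumed completion so the code runs, not the author's code.
            nn.LeakyReLU(0.01),
            nn.Conv2d(128, 128, 4, 2, 1),  # -> (batch_size, 128, 8, 8)
            nn.LeakyReLU(0.01),
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 1),
            nn.Sigmoid(),  # probability that the image is real
        )

    def forward(self, img):
        return self.model(img)
```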

Source: blog.csdn.net/weixin_42609225/article/details/129502144