Below is a simple example of a GAN (generative adversarial network) implemented in PyTorch. It assumes you already understand the basic principles of GANs and have PyTorch installed.
First, we define the generator and the discriminator. The generator is a neural network that takes random noise as input and outputs a fake image. The discriminator is a second neural network that takes an image as input and outputs a score indicating whether the image is real or generated (fake).
```python
import torch
import torch.nn as nn
import torch.optim as optim
class Generator(nn.Module):
    """Generator network mapping a latent noise vector to a 3x32x32 image.

    A fully connected projection is reshaped to a (128, 4, 4) feature map.
    Three stride-2 transposed convolutions then upsample it to 8x8, 16x16
    and finally 32x32 with 3 output channels. The final ``Tanh`` bounds
    outputs to ``[-1, 1]``.

    Args:
        latent_dim: Length of each input noise vector ``z``.
        img_shape: Target image shape, stored as ``self.img_shape``.
            NOTE(review): it is not used to size the layers; the network
            always emits (3, 32, 32), so callers should pass a matching shape.
    """

    def __init__(self, latent_dim, img_shape):
        # Must be the dunder ``__init__`` (not ``init``) so it runs as the
        # constructor and nn.Module can register the submodules below.
        super().__init__()
        self.img_shape = img_shape

        self.model = nn.Sequential(
            # Input is random noise (batch_size, latent_dim);
            # project to (batch_size, 128 * 4 * 4).
            nn.Linear(latent_dim, 128 * 4 * 4),
            nn.BatchNorm1d(128 * 4 * 4),
            nn.LeakyReLU(0.01),
            # PyTorch has no nn.Reshape; Unflatten splits dim 1 so the
            # output is (batch_size, 128, 4, 4).
            nn.Unflatten(1, (128, 4, 4)),
            # Output is (batch_size, 128, 8, 8)
            nn.ConvTranspose2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.01),
            # Output is (batch_size, 128, 16, 16)
            nn.ConvTranspose2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.01),
            # Output is (batch_size, 3, 32, 32), squashed to [-1, 1]
            nn.ConvTranspose2d(128, 3, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch_size, latent_dim) to images.

        Returns:
            Tensor of shape (batch_size, 3, 32, 32) with values in [-1, 1].
        """
        return self.model(z)
class Discriminator(nn.Module): def init(self, img_shape): super().init() self.model = nn.Sequential( # 输入为 (batch_size, 3, 32, 32) nn.Conv2d(3, 128, 4, 2, 1),