import time
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
if torch.cuda.is_available():
    # Ask cuDNN for deterministic algorithms so seeded runs are reproducible.
    torch.backends.cudnn.deterministic = True

#########################
## SETTINGS
#########################

# Device: the script targets GPU index 2; when fewer than three GPUs are
# visible, fall back to the first GPU instead of failing on first use.
if torch.cuda.is_available():
    device = torch.device("cuda:2" if torch.cuda.device_count() > 2 else "cuda:0")
else:
    device = torch.device("cpu")

# Hyperparameters
random_seed = 123
generator_learning_rate = 0.001
discriminator_learning_rate = 0.001
num_epochs = 100
batch_size = 128

LATENT_DIM = 100            # length of the generator's input noise vector
IMG_SHAPE = (1, 28, 28)     # (channels, height, width)
# Number of values in one flattened image (product of IMG_SHAPE -> 784).
IMG_SIZE = int(np.prod(IMG_SHAPE))
#########################
## MNIST DATASET
#########################

# ToTensor() turns each PIL image into a float tensor with values in [0, 1].
train_dataset = datasets.MNIST(
    root='../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root='../data',
    train=False,
    transform=transforms.ToTensor(),
)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: pull a single training batch and report its shapes.
images, labels = next(iter(train_loader))
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)
# Output:
# Image batch dimensions: torch.Size([128, 1, 28, 28])
# Image label dimensions: torch.Size([128])
##############################
## MODEL
##############################

class GAN(torch.nn.Module):
    """Fully connected GAN with a one-hidden-layer generator and discriminator.

    The generator maps latent vectors of length ``latent_dim`` to flattened
    images of length ``img_size`` with values in [-1, 1] (Tanh output). The
    discriminator maps flattened images to a probability in [0, 1] (Sigmoid
    output).

    Args:
        latent_dim: Length of the generator's input noise vector. Defaults to
            the module-level ``LATENT_DIM``.
        img_size: Number of values in one flattened image. Defaults to the
            module-level ``IMG_SIZE``.
    """

    def __init__(self, latent_dim=None, img_size=None):
        super(GAN, self).__init__()
        # Resolve defaults at call time so explicit sizes work even when the
        # module-level constants are not defined.
        if latent_dim is None:
            latent_dim = LATENT_DIM
        if img_size is None:
            img_size = IMG_SIZE

        self.generator = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, img_size),
            nn.Tanh()
        )

        self.discriminator = nn.Sequential(
            nn.Linear(img_size, 128),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def generator_forward(self, z):
        """Map a batch of latent vectors ``z`` to flattened generated images."""
        img = self.generator(z)
        return img

    def discriminator_forward(self, img):
        """Return a 1-D tensor of "real" probabilities, one per row of ``img``."""
        # Use this instance's discriminator, not a global, so every GAN
        # instance scores images with its own weights.
        pred = self.discriminator(img)
        return pred.view(-1)
##############################
## TRAINING
##############################

# Seed right before building the model so weight initialisation is
# reproducible.
torch.manual_seed(random_seed)

model = GAN().to(device)

# NOTE(review): the optimizers were used but never created in this file;
# Adam with the configured learning rates is assumed here -- confirm.
optim_gener = torch.optim.Adam(model.generator.parameters(),
                               lr=generator_learning_rate)
optim_discr = torch.optim.Adam(model.discriminator.parameters(),
                               lr=discriminator_learning_rate)

start_time = time.time()
# Per-batch loss history as Python floats (storing the loss tensors would
# keep every batch's autograd graph alive).
discr_costs = []
gener_costs = []

for epoch in range(num_epochs):
    model = model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        # Rescale pixels from [0, 1] to [-1, 1] to match the generator's Tanh
        # output range, then flatten each image.
        features = (features - 0.5) * 2.
        features = features.view(-1, IMG_SIZE).to(device)
        n_examples = features.size(0)

        # Adversarial ground truths
        valid = torch.ones(n_examples, device=device)
        fake = torch.zeros(n_examples, device=device)

        ### FORWARD AND BACK PROP

        # ---------------------
        # Train Generator
        # ---------------------

        # Sample latent noise uniformly from [-1, 1).
        z = torch.zeros((n_examples, LATENT_DIM)).uniform_(-1.0, 1.0).to(device)
        generated_features = model.generator_forward(z)

        # Generator loss: how well the fakes fool the discriminator.
        discr_pred = model.discriminator_forward(generated_features)
        gener_loss = F.binary_cross_entropy(discr_pred, valid)

        optim_gener.zero_grad()
        # This backward also leaves gradients on the discriminator; they are
        # cleared by optim_discr.zero_grad() below before its own update.
        gener_loss.backward()
        optim_gener.step()

        # ---------------------
        # Train Discriminator
        # ---------------------

        # Discriminator loss: classify real images as real and fakes as fake.
        discr_pred_real = model.discriminator_forward(features)
        real_loss = F.binary_cross_entropy(discr_pred_real, valid)

        # detach(): do not backpropagate the discriminator loss into the generator.
        discr_pred_fake = model.discriminator_forward(generated_features.detach())
        fake_loss = F.binary_cross_entropy(discr_pred_fake, fake)

        discr_loss = 0.5 * (real_loss + fake_loss)

        optim_discr.zero_grad()
        discr_loss.backward()
        optim_discr.step()

        discr_costs.append(discr_loss.item())
        gener_costs.append(gener_loss.item())

        ### LOGGING
        if not batch_idx % 100:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Gen/Dis Loss: %.4f/%.4f'
                  % (epoch + 1, num_epochs, batch_idx, len(train_loader),
                     gener_costs[-1], discr_costs[-1]))

    print('Time elapsed: %.2f min' % ((time.time() - start_time) / 60))

print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))