import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import skimage.io as io
from torchvision.utils import save_image
# Hyper parameters
EPOCH = 100              # number of passes over the dataset in the training loop
BATCH_SIZE = 64          # images per mini-batch handed out by the DataLoader
LR = 0.005               # learning rate given to the optimizer
DOWNLOAD_CIFAR10 = True  # let torchvision fetch CIFAR-10 into ./cifar10/ if needed

# CIFAR-10 image dataset; ToTensor() converts each image to a float tensor.
# NOTE(review): train=False selects the CIFAR-10 *test* split even though the
# data is used for training below -- confirm this is intentional (e.g. to
# train on the smaller split) or switch to train=True.
train_data = torchvision.datasets.CIFAR10(
    root='./cifar10/',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_CIFAR10,
)

# Shuffled mini-batch iterator over the dataset.
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
class AutoEncoder(torch.nn.Module):
    """Convolutional auto-encoder for 3-channel 32x32 images.

    The encoder applies three stride-2 convolutions, halving the spatial
    size each time (32 -> 16 -> 8 -> 4) while widening the channels
    (3 -> 64 -> 128 -> 256).  The decoder mirrors it with three stride-2
    transposed convolutions (4 -> 8 -> 16 -> 32, 256 -> 128 -> 64 -> 3).
    """

    def __init__(self):
        """Build the encoder and decoder stacks."""
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=3,
                stride=2,
                padding=1,
            ),  # 3x32x32 -> 64x16x16
            nn.LeakyReLU(),
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=3,
                stride=2,
                padding=1,
            ),  # 64x16x16 -> 128x8x8
            # num_features must equal the channel count of the preceding conv.
            nn.BatchNorm2d(num_features=128, affine=True),
            nn.LeakyReLU(),
            nn.Conv2d(
                in_channels=128,
                out_channels=256,
                kernel_size=3,
                stride=2,
                padding=1,
            ),  # 128x8x8 -> 256x4x4
            nn.BatchNorm2d(num_features=256, affine=True),
            nn.LeakyReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=256,
                out_channels=128,
                kernel_size=4,
                stride=2,
                padding=1,
            ),  # 256x4x4 -> 128x8x8
            nn.BatchNorm2d(num_features=128, affine=True),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(
                in_channels=128,
                out_channels=64,
                kernel_size=4,
                stride=2,
                padding=1,
            ),  # 128x8x8 -> 64x16x16
            nn.BatchNorm2d(num_features=64, affine=True),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(
                in_channels=64,
                out_channels=3,
                kernel_size=4,
                stride=2,
                padding=1,
            ),  # 64x16x16 -> 3x32x32
            nn.LeakyReLU(),
        )

    def forward(self, x):
        """Encode and reconstruct a batch of images.

        Args:
            x: tensor whose elements can be arranged as (batch, 3, 32, 32);
                a flattened (batch, 3*32*32) input is accepted as well.

        Returns:
            Tuple ``(encode, decode)``: the latent feature map of shape
            (batch, 256, 4, 4) and the reconstruction of shape
            (batch, 3, 32, 32).
        """
        # reshape (rather than view) also handles non-contiguous inputs.
        x = x.reshape(-1, 3, 32, 32)
        encode = self.encoder(x)
        decode = self.decoder(encode)
        return encode, decode
import os

# Model, optimizer and reconstruction loss.
autoencoder = AutoEncoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# save_image does not create missing directories, so make sure 'res/' exists
# before the first snapshot is written.
os.makedirs('res', exist_ok=True)

for epoch in range(EPOCH):
    # The labels (y) are not needed: the auto-encoder reconstructs its input.
    for step, (x, y) in enumerate(train_loader):
        b_x = x.view(-1, 3, 32, 32)  # input batch, shape (batch, 3, 32, 32)
        b_y = b_x.detach()           # reconstruction target: the input itself

        encoded, decoded = autoencoder(b_x)

        # Every 100 steps, dump the current reconstructions as an image grid.
        if step % 100 == 0:
            save_image(decoded.detach(), 'res/%s-%s.jpg' % (epoch, step))

        loss = loss_func(decoded, b_y)  # mean squared reconstruction error
        optimizer.zero_grad()           # clear gradients from the previous step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply the parameter update

        if step % 100 == 0:
            # loss is a 0-dim tensor; .item() extracts the Python float
            # (indexing with loss.data[0] fails on current PyTorch).
            print('Epoch:', epoch, '|', 'train loss %.4f:' % loss.item())