PyTorch Implementation of an Autoencoder
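This post implements a convolutional autoencoder in PyTorch and trains it on CIFAR10: three strided Conv2d layers compress a 3x32x32 image down to a 256x4x4 feature map, three ConvTranspose2d layers decode it back to the original size, and the network is trained with MSE reconstruction loss and Adam. Reconstructed batches are saved to disk periodically with save_image.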

```python
import os
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from torchvision.utils import save_image

# Hyper parameters
EPOCH = 100
BATCH_SIZE = 64
LR = 0.005
DOWNLOAD_CIFAR10 = True

# CIFAR10 dataset (downloaded to ./cifar10/ on first run)
train_data = torchvision.datasets.CIFAR10(
    root='./cifar10/',
    train=False,   # note: this loads the 10k test split; use train=True for the 50k training split
    transform=torchvision.transforms.ToTensor(),   # scales pixel values to [0, 1]
    download=DOWNLOAD_CIFAR10,
)

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=3,
                stride=2,
                padding=1,
            ),  # 3x32x32 -> 64x16x16
            nn.LeakyReLU(),

            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=3,
                stride=2,
                padding=1,
            ),  # 64x16x16 -> 128x8x8
            nn.BatchNorm2d(num_features=128, affine=True),  # num_features = number of channels
            nn.LeakyReLU(),

            nn.Conv2d(
                in_channels=128,
                out_channels=256,
                kernel_size=3,
                stride=2,
                padding=1,
            ),  # 128x8x8 -> 256x4x4
            nn.BatchNorm2d(num_features=256, affine=True),
            nn.LeakyReLU(),
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=256,
                out_channels=128,
                kernel_size=4,
                stride=2,
                padding=1,
            ),  # 256x4x4 -> 128x8x8
            nn.BatchNorm2d(num_features=128, affine=True),
            nn.LeakyReLU(),

            nn.ConvTranspose2d(
                in_channels=128,
                out_channels=64,
                kernel_size=4,
                stride=2,
                padding=1,
            ),  # 128x8x8 -> 64x16x16
            nn.BatchNorm2d(num_features=64, affine=True),
            nn.LeakyReLU(),

            nn.ConvTranspose2d(
                in_channels=64,
                out_channels=3,
                kernel_size=4,
                stride=2,
                padding=1,
            ),  # 64x16x16 -> 3x32x32
            nn.LeakyReLU(),  # note: a Sigmoid here would bound outputs to [0, 1], matching ToTensor inputs
        )


    def forward(self, x):
        x = x.view(-1, 3, 32, 32)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

autoencoder = AutoEncoder()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

os.makedirs('res', exist_ok=True)   # save_image needs the output directory to exist

for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        b_x = x.view(-1, 3, 32, 32)   # batch input, shape (batch, 3, 32, 32)
        b_y = b_x.detach()            # the reconstruction target is the input itself

        encoded, decoded = autoencoder(b_x)

        if step % 100 == 0:
            save_image(decoded.data, 'res/%s-%s.jpg' % (epoch, step))

        loss = loss_func(decoded, b_y)   # mean squared reconstruction error
        optimizer.zero_grad()            # clear gradients for this training step
        loss.backward()                  # backpropagation, compute gradients
        optimizer.step()                 # apply gradients

        if step % 100 == 0:
            print('Epoch:', epoch, '| train loss: %.4f' % loss.item())
```
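
After training, the model can be reused for inference; the encoder output doubles as a compact 256x4x4 feature representation of each image. A minimal sketch, assuming the script above has already run (the file name `autoencoder.pth` is illustrative):

```python
# Inference sketch: reuses autoencoder, train_loader, and save_image from above.
autoencoder.eval()                     # switch BatchNorm to its running statistics
with torch.no_grad():                  # no gradients needed at inference time
    x, _ = next(iter(train_loader))
    encoded, decoded = autoencoder(x)
    print(encoded.shape)               # e.g. torch.Size([64, 256, 4, 4]) -- the latent code
    print(decoded.shape)               # e.g. torch.Size([64, 3, 32, 32]) -- the reconstruction
    save_image(decoded, 'res/reconstruction.jpg')

# Persist the weights for later reuse (file name is illustrative).
torch.save(autoencoder.state_dict(), 'autoencoder.pth')
```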


Reposted from blog.csdn.net/qq_38213612/article/details/78906738