A Simple Record: Autoencoders with PyTorch + MNIST

Table of contents

Introduction

NET

Multilayer Perceptron Edition

Convolution version

train

visualization

test


Introduction

The autoencoder (AutoEncoder) was originally used as a data compression method. Its characteristics are as follows:
(1) It is highly data-specific: an autoencoder can only compress data similar to its training data, because the features extracted by a neural network are generally strongly tied to the original training set. For example, an autoencoder trained on human faces will perform poorly when compressing pictures of animals in nature, because it has only learned the characteristics of human faces, not those of natural images.
(2) The compression is lossy, because information is inevitably lost in the process of dimensionality reduction. Around 2012, it was found that layer-by-layer pre-training with autoencoders made it possible to train deeper convolutional neural networks, but it soon became clear that a good initialization strategy is much more effective than complex layer-by-layer pre-training. Batch Normalization, which appeared in 2014, also made it possible to train deeper networks effectively, and by the end of 2015 networks of essentially any depth could be trained using residual learning (ResNet).
So today autoencoders are mainly used for two things: data denoising, and dimensionality reduction for visualization. Autoencoders can also be used to generate data.

The design here is a multilayer-perceptron autoencoder for the MNIST handwritten-digit dataset, with four fully connected layers in the encoder and four in the decoder.

NET

Multilayer Perceptron Edition

import torch.nn as nn
class Autoencoder(nn.Module):
    """Fully connected autoencoder with a 3-value bottleneck.

    The encoder takes a flattened 28*28 input (784 features) through
    128, 64 and 12 units down to a 3-value code.  The decoder mirrors those
    widths back up to 784 features and ends with ``Tanh``, so every
    reconstructed value lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        widths = (28 * 28, 128, 64, 12, 3)

        # Encoder: Linear layers separated by ReLU, with no activation after
        # the final (bottleneck) layer.
        down = []
        for n_in, n_out in zip(widths, widths[1:]):
            if down:
                down.append(nn.ReLU(inplace=True))
            down.append(nn.Linear(n_in, n_out))
        self.encoder = nn.Sequential(*down)

        # Decoder: the same widths in reverse, closed by Tanh.
        # NOTE: the attribute name is spelled ``decodere``; the spelling is
        # part of the module's attribute and state-dict key names.
        mirrored = widths[::-1]
        up = []
        for n_in, n_out in zip(mirrored, mirrored[1:]):
            if up:
                up.append(nn.ReLU(inplace=True))
            up.append(nn.Linear(n_in, n_out))
        up.append(nn.Tanh())
        self.decodere = nn.Sequential(*up)

    def forward(self, x):
        """Return ``(code, reconstruction)`` for the input batch ``x``."""
        code = self.encoder(x)
        reconstruction = self.decodere(code)
        return code, reconstruction

Convolution version

class DCautoencoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The trailing size comments give the spatial side length after each
    layer when the input is a 28 x 28 image.  The decoder ends with
    ``Tanh``, so reconstructed values lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: two conv + max-pool stages (sizes assume a 28 x 28 input).
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=3, padding=1),   # 10
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                  # 5
            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),   # 3
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),                  # 2
        )

        # Decoder: three transposed convolutions that upsample back to 28.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2),             # 5
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=3, padding=1),  # 15
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2, padding=1),   # 28
            nn.Tanh(),
        )

    def forward(self, x):
        """Return ``(code, reconstruction)`` for the image batch ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)

train

import torch
import six_Net
import torch.nn as nn
import tqdm
import os

import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torch import optim
from torchvision.utils import save_image

# Hyper-parameters and training setup.
batch_size = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ToTensor yields values in [0, 1]; Normalize with mean 0.5 and std 0.5 then
# maps them to [-1, 1], the same range the decoder's final Tanh produces.
im_tfs = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# download=True only fetches MNIST when it is not already under ./data, so a
# first run on a fresh machine no longer fails with "Dataset not found".
train_set = mnist.MNIST('./data', train=True, transform=im_tfs, download=True)
# shuffle stays off so the last batch (whose reconstructions are saved after
# every epoch) is the same set of images each time.
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=False)

model = six_Net.Autoencoder().to(device)

# Optional sanity check of the code shape:
# x = Variable(torch.randn(1, 28*28))
# code ,_ =model(x)
# print(code.shape)

# Reconstruction loss (mean squared error) optimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

def to_img(x):
    """Convert a batch of network outputs back into image tensors.

    Values are mapped from [-1, 1] to [0, 1], clipped to that interval, and
    the batch is reshaped to ``(batch, 1, 28, 28)``.
    """
    batch = x.shape[0]
    rescaled = 0.5 * (x + 1.)
    # Clip to the valid pixel range: minimum 0, maximum 1.
    return rescaled.clamp(0, 1).view(batch, 1, 28, 28)

epoches = 20
# Start training.
print('开始训练')
for epoch in range(epoches):
    for im, _ in tqdm.tqdm(train_data):
        # Flatten each image to a 784-vector for the fully connected model.
        im = im.view(im.shape[0], -1).to(device)

        # Forward pass: the target of the reconstruction is the input itself.
        _, out = model(im)
        # nn.MSELoss() already averages over all elements, so no extra
        # division by the batch size is needed.
        loss = criterion(out, im)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # After every epoch: report the loss of the last batch and save that
    # batch's reconstructions as an image grid.
    print()
    print(f'训练第 {epoch + 1} 个 epoch 中,loss : {loss.item()}')
    pic = to_img(out.detach().cpu())
    os.makedirs('./out', exist_ok=True)
    save_image(pic, f'./out/image_{epoch + 1}.png')

# Make sure the checkpoint directory exists; otherwise torch.save fails only
# after all epochs have already been trained.
os.makedirs('./params', exist_ok=True)
torch.save(model.state_dict(), './params/encoder.pth')

visualization

# Visualize the learned 3-D codes of the first 200 training images.
# The raw uint8 pixels are scaled by hand the same way im_tfs does it:
# [0, 255] -> [0, 1] -> [-1, 1].
view_data = ((train_set.data[:200].type(torch.FloatTensor).view(-1, 28*28) / 255. - 0.5) / 0.5).to(device)
with torch.no_grad():               # inference only, no autograd graph needed
    encode, _ = model(view_data)    # extract the compressed feature values
# Move the codes to the CPU first: .numpy() raises on CUDA tensors.
codes = encode.cpu().numpy()
fig = plt.figure(2)
# add_subplot attaches the 3-D axes to the figure; constructing Axes3D(fig)
# directly no longer does so in current matplotlib releases.
ax = fig.add_subplot(111, projection='3d')
# x, y, z coordinates: one latent dimension per axis.
X = codes[:, 0]
Y = codes[:, 1]
Z = codes[:, 2]
values = train_set.targets[:200].numpy()  # digit labels
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255*s/9))    # colour chosen by digit label (0-9)
    ax.text(x, y, z, str(s), backgroundcolor=c)  # draw the label at its code position
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()

test


# Testing: decode a hand-picked latent code back into an image.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine (and vice versa).
model.load_state_dict(torch.load('./params/encoder.pth', map_location=device))
code = torch.FloatTensor([[1.19, -3.36, -2.06]]).to(device)  # latent code (1.19, -3.36, -2.06)
with torch.no_grad():  # inference only
    decode = model.decodere(code)
decode_img = to_img(decode).squeeze()
print(decode_img.shape)
# to_img returns values in [0, 1]; scale to [0, 255] for display as uint8.
decode_img = decode_img.cpu().numpy() * 255
plt.imshow(decode_img.astype('uint8'), cmap='gray')  # the original author reports this renders a "3"
plt.show()

Guess you like

Origin blog.csdn.net/qq_42792802/article/details/126127691