Pytorch implementiert den AutoEncoder-Autoencoder

1.Konzepte und Prinzipien

Autoencoder (AE) ist eine Art künstlicher neuronaler Netze (ANNs), die beim halbüberwachten und unüberwachten Lernen verwendet werden. Seine Funktion besteht darin, Repräsentationslernen auf den Eingabeinformationen durchzuführen, indem die Eingabeinformationen als Lernziel verwendet werden (Repräsentationslernen).

Das algorithmische Modell besteht aus zwei Hauptteilen:Encoder(编码器)和Decoder(解码器).

Die Funktion des Encoders besteht darin, die hochdimensionale Eingabe Das Netzwerk lernt die aussagekräftigsten quantitativen Eigenschaften. Die Funktion des Decoders besteht darin, die latente Variable h der verborgenen Schicht auf die ursprüngliche Dimension wiederherzustellen. Der beste Zustand ist, dass die Ausgabe des Decoders perfekt sein kann. Stellen Sie die ursprüngliche Eingabe genau wieder her oder etwa. Eine Reduzierung der Dimensionalität kanndurch AE erreicht werden..

2. Code-Implementierung

AE_main.py (Hauptfunktion und Anzahl der Iterationen definieren)

import  torch
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
from torch import nn,optim
from AutoEncoder import AE

import visdom

def main():
    mnist_train=datasets.MNIST('mnist',True,transform=transforms.Compose([
        transforms.ToTensor()
    ]),download=True)
    mnist_train=DataLoader(mnist_train,batch_size=32,shuffle=True)

    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x,_=iter(mnist_train).next()
    print('x:',x.shape)

    model=AE()

    criteon=nn.MSELoss()
    optimizer=optim.Adam(model.parameters(),lr=1e-3)
    print(model)

    viz=visdom.Visdom()

    for epoch in range(1000):
        for batch_size,(x,_) in enumerate(mnist_train):
            x_hat=model(x)
            loss=criteon(x_hat,x)

            #backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(epoch,'loss:',loss.item())
        x,_=iter(mnist_test).next() #其中x是标签,_是label
        with torch.no_grad():
            x_hat=model(x)
        viz.images(x,nrow=8,win='x',opts=dict(title='x'))
        viz.images(x_hat,nrow=8,win='x_hat',opts=dict(title='x_hat'))

if __name__ == '__main__':
    main()

 AutoEncoder.py (Netzwerkarchitektur definieren)

import torch
from torch import nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()

        #[b,784]=>[b,20]
        self.encoder=nn.Sequential(
            nn.Linear(784,256),
            nn.ReLU(),
            nn.Linear(256,64),
            nn.ReLU(),
            nn.Linear(64,20),
            nn.ReLU()
        )
        #[b,20]=>[b,784]
        self.decoder=nn.Sequential(
            nn.Linear(20,64),
            nn.ReLU(),
            nn.Linear(64,256),
            nn.ReLU(),
            nn.Linear(256,784),
            nn.Sigmoid()
        )
    def forward(self,x):
        batch_size=x.size(0)
        #flatten
        x=x.view(batch_size,784)
        #encoder
        x=self.encoder(x)
        # decoder
        x=self.decoder(x)
        #reshape
        x=x.view(batch_size,1,28,28)
        return x

3. Vorsichtsmaßnahmen

Bevor Sie die Datei AE_main.py ausführen, müssen Sie den folgenden Code auf der Konsole ausführen und visdom öffnen

python -m visdom.server

 

Öffnen Sie diese http://localhost:8097-Website, um Bilder des Trainingsprozesses anzusehen.  

4. Operationsergebnisse 

 

 

 

 

 

Supongo que te gusta

Origin blog.csdn.net/weixin_41477928/article/details/127012686
Recomendado
Clasificación