# 通过卷积变换实现自动编码器模型 — convolutional (denoising) autoencoder trained on MNIST

import torch
import torchvision
from torchvision import datasets,transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import numpy as np
# --- Data loading and a visual preview of the noise corruption -------------
# Per-channel normalisation constants for single-channel MNIST. They are
# defined once so the transform and the de-normalisation used for plotting
# below cannot drift apart.
mean = [0.5]
std = [0.5]
# ToTensor maps pixels to [0, 1]; Normalize then computes (x - mean) / std,
# which is [-1, 1] with the constants above.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
)
dataset_train = datasets.MNIST(root="./data", transform=transform, train=True, download=True)
dataset_test = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
train_load = DataLoader(dataset=dataset_train, batch_size=64, shuffle=True)
test_load = DataLoader(dataset=dataset_test, batch_size=64, shuffle=False)

# Take one shuffled training batch and tile it into a single image grid.
images, label = next(iter(train_load))
images_example = torchvision.utils.make_grid(images)
# The grid tensor is channels-first; matplotlib expects channels-last.
images_example = images_example.numpy().transpose(1, 2, 0)
# Undo Normalize so the grid is back in [0, 1] pixel space for display.
images_example = images_example * std + mean
print(images_example.shape)

# Preview Gaussian corruption (sigma = 0.5 in [0, 1] pixel space), clipped
# back into the displayable range.
noisy_images = images_example + 0.5 * np.random.randn(*images_example.shape)
noisy_images = np.clip(noisy_images, 0., 1.)

# Show clean and noisy grids side by side; previously the noisy grid was
# computed but never shown, and without plt.show() a plain script displays
# nothing. In a non-interactive script plt.show() blocks until the window is
# closed.
fig, (ax_clean, ax_noisy) = plt.subplots(1, 2, figsize=(12, 6))
ax_clean.imshow(images_example)
ax_clean.set_title("clean")
ax_clean.axis("off")
ax_noisy.imshow(noisy_images)
ax_noisy.set_title("noisy")
ax_noisy.axis("off")
plt.show()
class AutoEncode(torch.nn.Module):
    """Convolutional autoencoder for single-channel images.

    Encoder: two stages of Conv2d(3x3, stride 1, padding 1) -> ReLU ->
    MaxPool2d(2), taking channels 1 -> 64 -> 128 and halving H and W at each
    stage. Decoder: two stages of nearest-neighbour Upsample(x2) -> Conv2d
    (3x3, stride 1, padding 1), taking channels 128 -> 64 -> 1, with a ReLU
    between them. The final convolution has no activation, so outputs are
    unbounded. The output spatial size matches the input only when H and W
    are divisible by 4 (e.g. 28x28 MNIST).
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Layer order inside each Sequential is kept fixed on purpose: it
        # determines the state_dict keys (encoder.0, encoder.3, decoder.1,
        # decoder.4) that saved checkpoints rely on.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, input):
        """Encode *input* to the latent feature map, then decode it back."""
        latent = self.encoder(input)
        return self.decoder(latent)
    
# --- Denoising training -----------------------------------------------------
model = AutoEncode()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters())
loss_f = torch.nn.MSELoss()
epoch_n = 10

# Training tensors come out of Normalize(mean, std), so clean pixels lie in
# [(0 - mean) / std, (1 - mean) / std] (= [-1, 1] for mean = std = 0.5).
# Noisy inputs must be clamped to that same range. Clamping to [0, 1], as
# before, collapsed every negative value (the whole black background) to 0
# and fed the network inputs on a different scale than its targets.
noise_factor = 0.5
pixel_min = (0.0 - mean[0]) / std[0]
pixel_max = (1.0 - mean[0]) / std[0]

for epoch in range(epoch_n):
    model.train()
    running_loss = 0.0
    print("Epoch {}/{}".format(epoch + 1, epoch_n))
    print("-" * 10)
    # Labels are not needed: the reconstruction target is the clean image.
    for x_train, _ in train_load:
        x_train = x_train.to(device)
        # Corrupt directly on the target device. Variable wrapping is
        # unnecessary (a deprecated no-op), so tensors go straight in.
        noisy_x_train = x_train + noise_factor * torch.randn_like(x_train)
        noisy_x_train = torch.clamp(noisy_x_train, pixel_min, pixel_max)
        train_pre = model(noisy_x_train)
        loss = loss_f(train_pre, x_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # MSELoss (reduction='mean') averages over the whole batch. Weight it
        # by the batch size so dividing by the dataset size below gives the
        # true per-sample mean, even for the smaller final batch. .item()
        # keeps running_loss a plain float instead of a graph-attached tensor.
        running_loss += loss.item() * x_train.size(0)
    print("Loss is:{:.4f}".format(running_loss / len(dataset_train)))

# 猜你喜欢 ("You may also like" — leftover from the source web page)

# 转载自 (reposted from) blog.csdn.net/qq_43607118/article/details/129541136