基于python实现的生成对抗网络GAN

项目简介

这篇文章主要介绍了生成对抗网络(Generative Adversarial Network),简称 GAN。

GAN 可以看作是一种可以生成特定分布数据的模型。

2.生成人脸图像

下面的代码是使用 Generator 来生成人脸图像,Generator 已经训练好保存在 pkl 文件中,只需要加载参数即可。由于模型是在多 GPU 的机器上训练的,因此加载参数后需要使用remove_module()函数来修改state_dict中的key

def remove_module(state_dict_g):
    # remove module.
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict_g.items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    return new_state_dict

把随机的高斯噪声输入到模型中,就可以得到人脸输出,最后进行可视化。全部代码如下:

import os
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from common_tools i

猜你喜欢

转载自blog.csdn.net/qq_38735017/article/details/128890180