PyTorch 7月22日学习——CGAN 代码学习（1）

一. 参数

# Command-line hyperparameters for CGAN training.
# `opt` is parsed from sys.argv when this code runs and is read as a
# module-level global by the Generator and Discriminator classes.
parser = argparse.ArgumentParser()

# Training schedule.
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')

# Optimizer settings; per the help text these are Adam's learning rate and
# its two beta coefficients (first- and second-moment decay rates).
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
# Fixed: b2 is the decay of the SECOND-order moment; the old help text said "first".
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')

parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')

# Model dimensions. latent_dim + n_classes is the generator's input width;
# n_classes is also the size and width of both label-embedding tables.
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
# Presumably combined elsewhere into img_shape = (channels, img_size, img_size) — TODO confirm.
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')

parser.add_argument('--sample_interval', type=int, default=400, help='interval between image sampling')

opt = parser.parse_args()
print(opt)

 

二. 生成器

class Generator(nn.Module):
    """Conditional GAN generator: maps (noise, class label) to an image.

    Each label is looked up in a learnable embedding table of width
    ``opt.n_classes``. The embedding is placed in front of the noise vector,
    and the result goes through an MLP. The MLP's flat output is reshaped
    to ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        # One learnable vector of length opt.n_classes per class label.
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        # Layer widths of the MLP. Every hidden layer except the first is
        # followed by BatchNorm; all hidden layers use LeakyReLU(0.2).
        widths = [opt.latent_dim + opt.n_classes, 128, 256, 512, 1024]
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx > 0:
                # NOTE(review): BatchNorm1d's second positional parameter is
                # `eps`, so this is eps=0.8, not momentum=0.8. It is kept as-is
                # to preserve behaviour; confirm whether momentum was intended.
                layers.append(nn.BatchNorm1d(fan_out, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # The output layer produces one value per pixel, squashed by Tanh.
        layers.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate a batch of images conditioned on ``labels``.

        Args:
            noise: latent vectors; presumably shape (batch, opt.latent_dim),
                TODO confirm against the caller.
            labels: integer class indices, one per sample.

        Returns:
            A tensor of shape (batch, *img_shape) with values in [-1, 1],
            because the final layer is Tanh.
        """
        # Concatenate the label embedding (first) and the noise along the last dim.
        conditioned = torch.cat((self.label_emb(labels), noise), -1)
        flat = self.model(conditioned)
        batch = flat.size(0)
        return flat.view(batch, *img_shape)

分析:

一些函数:

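1. nn.Embedding( )

按整数索引查表，返回对应的可学习向量。nn.Embedding(10, 10) 表示 10 个类别，每个类别对应一个 10 维向量；下面的例子取出第 4 类的向量。权重在创建时随机初始化，所以每次输出的数值不同。注：自 PyTorch 0.4 起 Variable 已弃用，直接写 label(torch.LongTensor([4])) 即可。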

>>> import torch

>>> import torch.nn as nn

>>> from torch.autograd import Variable

>>> label=nn.Embedding(10,10)

>>> label(Variable(torch.LongTensor([4])))

tensor([[ 1.1337, -0.1294,  0.5304, -0.2917, -0.6218, -0.2924, -0.8620,

          1.1923, -1.1058, -0.1389]])

2. np.prod( )

连乘函数：返回所有元素的乘积。例如 np.prod((1, 32, 32)) = 1024。代码中用 int(np.prod(img_shape)) 得到图像展平后的元素个数。

3. torch.cat( )：沿指定维度拼接张量

>>> x = torch.randn(2, 3)

>>> x

tensor([[ 0.4090,  1.2366, -1.1015],

        [ 1.3993, -2.5494,  0.1595]])

>>> y = torch.randn(2, 3)

>>> y

tensor([[-1.0984,  0.8635, -0.6777],

        [-1.0466, -0.9148,  0.7278]])

>>> torch.cat((x, y), 0)

tensor([[ 0.4090,  1.2366, -1.1015],

        [ 1.3993, -2.5494,  0.1595],

        [-1.0984,  0.8635, -0.6777],

        [-1.0466, -0.9148,  0.7278]])

>>> torch.cat((x, y), 1)

tensor([[ 0.4090,  1.2366, -1.1015, -1.0984,  0.8635, -0.6777],

        [ 1.3993, -2.5494,  0.1595, -1.0466, -0.9148,  0.7278]])

>>> torch.cat((x, y), -1)

tensor([[ 0.4090,  1.2366, -1.1015, -1.0984,  0.8635, -0.6777],

        [ 1.3993, -2.5494,  0.1595, -1.0466, -0.9148,  0.7278]])

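torch.cat((x, y), 0) 沿第 0 维拼接。对二维张量来说，就是"竖着"按行堆叠。

torch.cat((x, y), 1) 沿第 1 维拼接，即"横着"按列拼接。dim=-1 表示最后一维，只有对二维张量才等同于 dim=1。代码中的 torch.cat(..., -1) 就是把标签嵌入和噪声（或展平后的图像）在最后一维上拼接。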

三. 判别器（Discriminator）

class Discriminator(nn.Module):
    """Conditional GAN discriminator: scores an (image, class label) pair.

    The image is flattened and its label embedding is appended after it.
    An MLP with dropout then reduces this to one raw score per sample.
    No sigmoid is applied to the output.
    """

    def __init__(self):
        super().__init__()

        # One learnable vector of length opt.n_classes per class label.
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        in_features = opt.n_classes + int(np.prod(img_shape))
        stages = [nn.Linear(in_features, 512), nn.LeakyReLU(0.2, inplace=True)]
        # Two identical hidden stages: Linear -> Dropout(0.4) -> LeakyReLU(0.2).
        for _ in range(2):
            stages += [
                nn.Linear(512, 512),
                nn.Dropout(0.4),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Single output unit, left un-squashed.
        stages.append(nn.Linear(512, 1))

        self.model = nn.Sequential(*stages)

    def forward(self, img, labels):
        """Return one validity score per (image, label) pair.

        Args:
            img: a batch of images; every dimension after the first is flattened.
            labels: integer class indices, one per sample.

        Returns:
            A tensor of shape (batch, 1) holding unbounded real scores.
        """
        flat_img = img.view(img.size(0), -1)
        label_vec = self.label_embedding(labels)
        # The image features come first, then the label embedding.
        d_in = torch.cat((flat_img, label_vec), -1)
        return self.model(d_in)

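分析:

判别器先用 img.view(img.size(0), -1) 把图像展平，再在最后一维拼接标签嵌入。拼接结果依次经过三个 512 维的全连接隐藏层，中间使用 LeakyReLU(0.2)，后两层前还加了 Dropout(0.4)。最后由 nn.Linear(512, 1) 为每个样本输出一个分数。注意最后没有 Sigmoid，输出是未归一化的实数。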

源代码网址:https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/cgan

猜你喜欢

转载自blog.csdn.net/weixin_42445501/article/details/81172208