PyTorch study notes, July 18 --- DCGAN code learning (2)

I. Dataset Processing

1. Import the torchvision libraries

import torchvision.datasets as dset
import torchvision.transforms as transforms

2. The dset.* dataset constructors

Example:

dataset = dset.CIFAR10(root='../data/', download=True, transform=None)

Explanation:

This loads all the data in the cifar-10-batches-py folder under the relative path ../data/ into memory (50,000 images as the training set). When download is True, the data is downloaded from the Internet and unpacked automatically.

Notes on the parameters:

(1) root: the relative directory from which the CIFAR-10 data is loaded.

(2) train: whether to load the training set; when False, the test set is loaded instead.

(3) download: whether to download the CIFAR dataset automatically.

(4) transform: the preprocessing applied to each sample; None means no preprocessing.
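Whatever the dataset, indexing it returns an (image, label) pair. A minimal sketch of the same indexing protocol with synthetic data (the ToyImageDataset class and its sizes are invented for illustration, so nothing has to be downloaded):

```python
import torch
from torch.utils.data import Dataset

class ToyImageDataset(Dataset):
    """Stands in for dset.CIFAR10: 3x32x32 float images with integer labels."""
    def __init__(self, n=10):
        self.images = torch.rand(n, 3, 32, 32)   # values already in [0, 1]
        self.labels = torch.randint(0, 10, (n,)) # 10 classes, like CIFAR-10
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

ds = ToyImageDataset()
img, label = ds[0]          # dataset[i] -> (image, label)
print(img.shape)            # torch.Size([3, 32, 32])
```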

 

3. transform preprocessing

Example:

transform = transforms.Compose([transforms.Scale(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

Explanation:

(1) transforms.Scale(size) resizes the image to the given size (renamed transforms.Resize in newer torchvision releases).

(2) transforms.ToTensor()

converts a PIL.Image or numpy.ndarray to a torch.FloatTensor and scales pixel values to [0, 1.0].

(3) transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

normalizes the data channel-wise with the formula below:

channel = (channel - mean) / std
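With mean = std = 0.5 per channel, this formula maps the [0, 1] range produced by ToTensor() onto [-1, 1], which matches the Tanh output range of the DCGAN generator. A quick worked check:

```python
# channel = (channel - mean) / std, with mean = std = 0.5
def normalize(value, mean=0.5, std=0.5):
    return (value - mean) / std

print(normalize(0.0))  # -1.0  (black pixel)
print(normalize(0.5))  #  0.0  (mid gray)
print(normalize(1.0))  #  1.0  (white pixel)
```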

 

4. DataLoader

Example:

loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=opt.batchSize,
                                     shuffle=True)

Explanation:

(1) dataset: the Dataset instance to draw from (here, the one already wrapped with the transform above).

(2) batch_size: how many samples each batch contains.

(3) shuffle: a boolean indicating whether to shuffle the data each epoch.

(4) num_workers (not passed here, so it defaults to 0): how many worker processes to use for loading data.
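Iterating the loader yields one batch per step. A self-contained sketch with a synthetic TensorDataset standing in for the CIFAR-10 dataset (the sizes 100 and 25 are chosen only for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

images = torch.rand(100, 3, 32, 32)          # fake "CIFAR-like" images
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=25, shuffle=True)

for batch_images, batch_labels in loader:    # 100 / 25 = 4 batches per epoch
    print(batch_images.shape)                # torch.Size([25, 3, 32, 32])
    break
```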

 

5. Dataset notes

(1) CIFAR-10


CIFAR-10 is an image dataset published by the University of Toronto. The images are compressed to a 32x32 resolution, fall into 10 classes, and all of them are labeled, which makes the dataset well suited to supervised learning.

 

6. Full code

###############   DATASET   ##################
if(opt.dataset == 'CIFAR'):
    dataset = dset.CIFAR10(root='../data/', download=True,
                           transform=transforms.Compose([
                               transforms.Scale(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
else:
    dataset = dset.MNIST(root='../data/',
                         transform=transforms.Compose([
                             transforms.Scale(opt.imageSize),
                             transforms.ToTensor(),
                         ]),
                         download=True)

loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=opt.batchSize,
                                     shuffle=True)

 

 

II. Model Design

1. Generator

class Generator(nn.Module):
    def __init__(self, nc, ngf, nz):
        super(Generator, self).__init__()
        # input: nz x 1 x 1
        self.layer1 = nn.Sequential(nn.ConvTranspose2d(nz, ngf*4, kernel_size=4),
                                    nn.BatchNorm2d(ngf*4),
                                    nn.ReLU())
        # 4 x 4
        self.layer2 = nn.Sequential(nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=4, stride=2, padding=1),
                                    nn.BatchNorm2d(ngf*2),
                                    nn.ReLU())
        # 8 x 8
        self.layer3 = nn.Sequential(nn.ConvTranspose2d(ngf*2, ngf, kernel_size=4, stride=2, padding=1),
                                    nn.BatchNorm2d(ngf),
                                    nn.ReLU())
        # 16 x 16
        self.layer4 = nn.Sequential(nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1),
                                    nn.Tanh())
        # 32 x 32

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out
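To see why the comments read 4 x 4, 8 x 8, 16 x 16: each stride-2 transposed convolution (kernel 4, padding 1) doubles the spatial size. A shape trace using a compact re-statement of the generator, assuming the common DCGAN defaults nz=100, ngf=64, nc=3 (these values are assumptions, not given in this post):

```python
import torch
import torch.nn as nn

# Same layer stack as the Generator above, with nz=100, ngf=64, nc=3 plugged in.
netG = nn.Sequential(
    nn.ConvTranspose2d(100, 256, kernel_size=4), nn.BatchNorm2d(256), nn.ReLU(),  # 1 -> 4
    nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),        # 4 -> 8
    nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),          # 8 -> 16
    nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh(),                                # 16 -> 32
)
z = torch.randn(2, 100, 1, 1)   # a batch of 2 noise vectors, each nz x 1 x 1
fake = netG(z)
print(fake.shape)               # torch.Size([2, 3, 32, 32])
```

The Tanh at the end keeps outputs in [-1, 1], the same range the Normalize transform maps real images into.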

 

2. Discriminator

class Discriminator(nn.Module):
    def __init__(self, nc, ndf):
        super(Discriminator, self).__init__()
        # 32 x 32
        self.layer1 = nn.Sequential(nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
                                    nn.BatchNorm2d(ndf),
                                    nn.LeakyReLU(0.2, inplace=True))
        # 16 x 16
        self.layer2 = nn.Sequential(nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1),
                                    nn.BatchNorm2d(ndf*2),
                                    nn.LeakyReLU(0.2, inplace=True))
        # 8 x 8
        self.layer3 = nn.Sequential(nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1),
                                    nn.BatchNorm2d(ndf*4),
                                    nn.LeakyReLU(0.2, inplace=True))
        # 4 x 4
        self.layer4 = nn.Sequential(nn.Conv2d(ndf*4, 1, kernel_size=4, stride=1, padding=0),
                                    nn.Sigmoid())

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out
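The discriminator is the mirror image: stride-2 convolutions halve the spatial size 32 -> 16 -> 8 -> 4, and the final 4x4 convolution collapses it to a single Sigmoid score per image. A shape trace with a compact re-statement, again assuming nc=3 and ndf=64 (an assumption, as above):

```python
import torch
import torch.nn as nn

# Same layer stack as the Discriminator above, with nc=3, ndf=64 plugged in.
netD = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2, inplace=True),     # 32 -> 16
    nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),  # 16 -> 8
    nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True), # 8 -> 4
    nn.Conv2d(256, 1, 4, 1, 0), nn.Sigmoid(),                                           # 4 -> 1
)
x = torch.rand(2, 3, 32, 32)    # a batch of 2 images
score = netD(x)
print(score.shape)              # torch.Size([2, 1, 1, 1]), each value in (0, 1)
```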

 

3. Model instantiation

###############   MODEL   ####################
ndf = opt.ndf
ngf = opt.ngf
nc = 1
if(opt.dataset == 'CIFAR'):
    nc = 3
netD = Discriminator(nc, ndf)
netG = Generator(nc, ngf, opt.nz)
if(opt.cuda):
    netD.cuda()
    netG.cuda()
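The training loop itself is not covered in this post. As a preview of how netD and netG are used together, here is a minimal one-iteration sketch with BCELoss and Adam; the tiny stand-in networks make it self-contained, and the hyperparameters lr=2e-4, betas=(0.5, 0.999) are the usual DCGAN choices, assumed here rather than taken from the post:

```python
import torch
import torch.nn as nn

# Tiny stand-ins for netG / netD with the same input and output shapes.
netG = nn.Sequential(nn.ConvTranspose2d(100, 3, 32), nn.Tanh())        # z -> 3x32x32
netD = nn.Sequential(nn.Conv2d(3, 1, 32), nn.Sigmoid(), nn.Flatten())  # 3x32x32 -> 1
criterion = nn.BCELoss()
optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))

real = torch.rand(8, 3, 32, 32)      # one batch from the loader
z = torch.randn(8, 100, 1, 1)        # one batch of noise

# --- update D: push D(real) toward 1 and D(G(z)) toward 0 ---
optD.zero_grad()
loss_real = criterion(netD(real), torch.ones(8, 1))
fake = netG(z)
loss_fake = criterion(netD(fake.detach()), torch.zeros(8, 1))  # detach: no grad into G
(loss_real + loss_fake).backward()
optD.step()

# --- update G: push D(G(z)) toward 1 ---
optG.zero_grad()
loss_g = criterion(netD(fake), torch.ones(8, 1))
loss_g.backward()
optG.step()
print(float(loss_g))
```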

 

Source code: https://github.com/sunshineatnoon/Paper-Implementations/tree/master/dcgan


Reposted from blog.csdn.net/weixin_42445501/article/details/81103851