MNIST手写数字数据集

# Imports for the script. `torch`, `nn` and `optim` are used later in the
# post but were missing from the original excerpt.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# training parameters
batch_size = 128
lr = 0.0002        # Adam learning rate (DCGAN-paper convention)
train_epoch = 20

# data_loader: upscale 28x28 MNIST digits to 64x64 and map pixels to [-1, 1]
img_size = 64
transform = transforms.Compose([
        # transforms.Scale was deprecated and later removed; Resize is the
        # current equivalent.
        transforms.Resize(img_size),
        transforms.ToTensor(),
        # MNIST images are single-channel; a 3-tuple mean/std raises an error
        # on a 1-channel tensor in current torchvision.
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

# network: DCGAN-style generator/discriminator. The classes are defined
# elsewhere in the original post; 128 is presumably the base feature-map
# width passed to their constructors — TODO confirm against the class defs.
G = generator(128)
D = discriminator(128)
# DCGAN initialization: all conv weights drawn from N(0, 0.02)
G.weight_init(mean=0.0, std=0.02)
D.weight_init(mean=0.0, std=0.02)
# move both networks to the GPU (assumes CUDA is available; no CPU fallback)
G.cuda()
D.cuda()
# Binary Cross Entropy loss — expects D to output probabilities in (0, 1)
BCE_loss = nn.BCELoss()

# Adam optimizer; beta1=0.5 is the DCGAN-paper setting (default 0.9 is unstable here)
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

MNIST数据集已经包含在了torchvision里面:图片大小为28x28,训练样本有60000个,测试样本有10000个

transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

对于上面这个标准化,归一化公式为 $\hat{x}=\dfrac{x-\mu}{\sigma}$;若 $x\in(0,1)$,则 $\hat{x}\in(-1,1)$。

初始化:

 # weight_init
def weight_init(self, mean, std):
    """Initialize every direct submodule of this network via normal_init."""
    for layer in self._modules.values():
        normal_init(layer, mean, std)

def normal_init(m, mean, std):
    """In-place init for conv / transposed-conv layers.

    Draws the layer's weights from N(mean, std) and zeros its bias.
    Layers of any other type are left untouched.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # a layer constructed with bias=False has m.bias = None; the
        # original unconditionally called m.bias.data and would crash
        if m.bias is not None:
            m.bias.data.zero_()

猜你喜欢

转载自blog.csdn.net/zz2230633069/article/details/85280785