ドロップアウトの判断は、オーバーフィットの方法に抵抗することができます

ここに写真の説明を挿入
ここに写真の説明を挿入
ここに写真の説明を挿入

脱落

LR=0.5
model = Net()
mse_loss = nn.CrossEntropyLoss()
#定义优化器,设置正则化L2
optimizer=optim.SGD(model.parameters(),LR,weight_decay=0.001)

def train(): #调用一次,训练一个周期
    model.train()  # dropout 起作用
    for i,data in enumerate(train_loader):
        #获得一个批次的数据和标签
        inputs, labels = data
        #(64,10)
        out = model(inputs)
#         #把数据标签变成独热编码
#         #[64]-->[64,1]
#         labels = labels.reshape(-1,1)
#         #tensor.scatter(dim,index,src)
#         #dim:对哪个维度进行独热编码
#         #index:要将src中对用的值放到tensor中的哪个位置
#         #src:插入index的数值
#         #将label转为one_hot编码1-->[1,0,0,0,0,0,0,0,0,0]
#         one_hot = torch.zeros(inputs.shape[0],10).scatter(1,labels,1) #将1放到labels中的哪个位置
        #计算loss
        loss = mse_loss(out,labels)
        #梯度清零
        optimizer.zero_grad()
        #计算梯度
        loss.backward()
        #更新权值
        optimizer.step()
def test():
    #测试集准确率
    model.eval() #dropout不工作
    correct=0
    for i,data in enumerate(test_loader):
        inputs,labels = data
        out = model(inputs)
        _,predicted = torch.max(out,1)
        correct += (predicted==labels).sum()
    print("Test acc:{0}".format(correct.item()/len(test_dataset)))
    print(correct.item())
    print(len(test_dataset))
    #训练集准确率
    correct1=0
    for i,data in enumerate(train_loader):
        inputs,labels = data
        out = model(inputs)
        _,predicted = torch.max(out,1)
        correct1 += (predicted==labels).sum()
    print("Train acc:{0}".format(correct1.item()/len(train_dataset)))

ドロップアウト:トレーニングセットとテストセットの精度を比較し、差を取ります。比較では、ドロップアウト= 0の差は使用されません。オーバーフィット(差が小さいほど良い)の
正規化に抵抗できると説明できます。オーバーフィットに抵抗できます。

おすすめ

転載: blog.csdn.net/BigData_Mining/article/details/108419062