la pérdida del centro implementa el conjunto de datos MNIST (pytorch)

pérdida del centro

import torch as t
import torch.nn as nn
import torch.nn.functional as F

class CenterLoss(nn.Module):
    def __init__(self,cls_num,featur_num):
        super().__init__()

        self.cls_num = cls_num
        self.featur_num=featur_num
        self.center = nn.Parameter(t.rand(cls_num,featur_num))

    def forward(self, xs,ys):   #xs=feature,ys=target
        # xs= F.normalize(xs)
        self.center_exp = self.center.index_select(dim=0,index=ys.long())
        count = t.histc(ys,bins=self.cls_num,min=0,max=self.cls_num-1)
        self.count_dis = count.index_select(dim=0,index=ys.long())+1
        loss = t.sum(t.sum((xs-self.center_exp)**2,dim=1)/2.0/self.count_dis.float())

        return loss

Red

import torch as t
import torchvision as tv
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
import torch.optim.lr_scheduler as lr_scheduler
import os

Batch_Size = 128
train_data = tv.datasets.MNIST(
    root="MNIST_data",
    train=True,
    download=False,
    transform=tv.transforms.Compose([tv.transforms.ToTensor(),
                                     tv.transforms.Normalize((0.1307,), (0.3081,))]))

test_data = tv.datasets.MNIST(
    root="MNIST_data",
    train=False,
    download=False,
    transform=tv.transforms.Compose([tv.transforms.ToTensor(),
                                     tv.transforms.Normalize((0.1307,), (0.3081,))]))

train_loader = data.DataLoader(train_data, batch_size=Batch_Size, shuffle=True, drop_last=True,num_workers=8)
test_loader = data.DataLoader(test_data, Batch_Size, True, drop_last=True,num_workers=8)

class TrainNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.hidden_layer = nn.Sequential(
            nn.Conv2d(1, 32, 3, 2, 1),
            nn.PReLU(),
            # nn.BatchNorm2d(32),
            nn.Conv2d(32, 128, 3, 2, 1),
            nn.PReLU(),
            # nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.PReLU(),
            # nn.BatchNorm2d(128),
            nn.Conv2d(128, 16,3, 2, 1),
            nn.PReLU())
        self.linear_layer = nn.Linear(16*4*4,2)
        self.output_layer = nn.Linear(2,10)

    def forward(self, xs):
        feat = self.hidden_layer(xs)
        # print(feature.shape)
        fc = feat.reshape(-1,16*4*4)
        # print(fc.data.size())
        feature = self.linear_layer(fc)
        output = self.output_layer(feature)
        return feature, F.log_softmax(output,dim=1)

def decet(feature,targets,epoch,save_path):
    color = ["red", "black", "yellow", "green", "pink", "gray", "lightgreen", "orange", "blue", "teal"]
    cls = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    plt.ion()
    plt.clf()
    for j in cls:
        mask = [targets == j]
        feature_ = feature[mask].numpy()
        x = feature_[:, 1]
        y = feature_[:, 0]
        label = cls
        plt.plot(x, y, ".", color=color[j])
        plt.legend(label, loc="upper right")     #如果写在plot上面,则标签内容不能显示完整
        plt.title("epoch={}".format(str(epoch)))

    plt.savefig('{}/{}.jpg'.format(save_path,epoch+1))
    plt.draw()
    plt.pause(0.001)







Entrenar

from Net import *
from centerloss import CenterLoss

save_path = r"{}\train{}.pt"
if __name__ == '__main__':
    net = TrainNet()
    device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
    centerloss = CenterLoss(10, 2).to(device)
    # crossloss = nn.CrossEntropyLoss().to(device)
    nllloss = nn.NLLLoss().to(device)
    # optmizer = t.optim.Adam(net.parameters())
    optmizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
    scheduler = lr_scheduler.StepLR(optmizer, 20, gamma=0.8)
    optmizercenter = t.optim.SGD(centerloss.parameters(), lr=0.5)

    # if os.path.exists(save_path):
    #     net.load_state_dict(t.load(save_path))
    net = net.to(device)
    # write = SummaryWriter("log")
    count = 0
    for epoch in range(1000):
        scheduler.step()
        feat = []
        target = []
        for i, (x, y) in enumerate(train_loader):
            x,y = x.to(device),y.to(device)
            xs,ys = net(x)
            value = t.argmax(ys, dim=1)
            center_loss = centerloss(xs,y)
            nll_loss = nllloss(ys,y)
            # cross_loss = crossloss(ys,y)
            # loss = center_loss+cross_loss
            loss = nll_loss+center_loss
            optmizer.zero_grad()
            optmizercenter.zero_grad()
            loss.backward()
            optmizer.step()
            optmizercenter.step()
            count+=1
            feat.append(xs)    
            target.append(y)
            if i % 100 == 0:
                print(epoch, i, loss.item())
                print(value[0].item(), "========>", y[0].item())
            # if i %500==0:
            #     t.save(net.state_dict(),save_path.format(r"D:\PycharmProjects\center_loss\data",str(count)))
        features = t.cat(feat,0)
        targets = t.cat(target,0)
        decet(features.data.cpu(),targets.data.cpu(), epoch,)
        #     write.add_histogram("loss",loss.item(),count)
        # write.close()

Show de efectos

Utilice NLLloss, SGD para aumentar el impulso y actualizar la tasa de aprendizaje
La pérdida del cenlter normaliza la función de entrada
Optimizador de Adam

 

La red principal usa BN y el sesgo de la capa de salida = Falso

Resumen del proceso de optimización:

  1. El efecto de elegir NLLloss es mejor que CrossEntropyLoss, nllloss = log () + nllloss ()
  2. La pérdida central y la red se optimizan por separado, el efecto será mejor y la velocidad será más rápida (tasa de aprendizaje de pérdida central = 0,5)
  3. Cuando se utiliza la optimización SGD, si no se agrega impulso, comenzará a fallar (difícil) converger en unas treinta rondas. Si el impulso solo se aumenta sin actualizar artificialmente la tasa de aprendizaje, la velocidad de convergencia será muy lenta;
  4. Cuando se utiliza la optimización de Adam, la velocidad es más rápida que SGD, pero el efecto no es bueno;
  5. Partido final: NLLLOSS + SGD optmizer (momentum + lr updated)
  6. Con respecto a la red, el efecto de la convolución es ligeramente mejor que el de la conexión completa y el efecto de un diseño de red más grande es mejor.
  7. En el proceso de dibujar puntos, si los datos no se cargan por adelantado, tomará mucho tiempo dibujar los puntos; si los datos no se borran, el dibujo se volverá cada vez más lento; demasiados puntos pueden producir efectos insignificantes (hazaña = [en el código ], target = [] fuera de lugar)

Supongo que te gusta

Origin blog.csdn.net/weixin_45191152/article/details/97762005
Recomendado
Clasificación