深度学习之PyTorch---- Logistic回归(二分类问题)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_36022260/article/details/83547259
"""
Logistic 回归
"""


class LogisticRegression(nn.Module):
    """Binary logistic-regression classifier: a linear layer followed by a sigmoid.

    Args:
        in_features: number of input features per sample. Defaults to 2,
            which matches the two-feature data used by the training script.

    The sub-module names ``lr`` (the ``nn.Linear`` layer) and ``sm`` (the
    sigmoid) are kept stable because code reads ``model.lr.weight`` and
    ``model.lr.bias`` to draw the decision boundary.
    """

    def __init__(self, in_features=2):
        super().__init__()
        self.lr = nn.Linear(in_features, 1)
        self.sm = nn.Sigmoid()

    def forward(self, x):
        """Return the class-1 probability for each sample.

        Args:
            x: tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1, holding values in (0, 1).
        """
        return self.sm(self.lr(x))


def parse_data(lines):
    """Parse ``x1,x2,label`` text lines into an ``(N, 3)`` float32 tensor.

    Only the first three comma-separated fields of each line are used.
    Blank or whitespace-only lines (e.g. a trailing empty line at the end
    of the file) are skipped instead of crashing on ``float('')``.

    Args:
        lines: iterable of text lines, such as an open file object.

    Returns:
        Tensor of shape ``(N, 3)``: two feature columns, then the label.

    Raises:
        ValueError: if no data rows are found, or a field is not a number.
    """
    rows = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        fields = line.split(',')
        rows.append((float(fields[0]), float(fields[1]), float(fields[2])))
    if not rows:
        raise ValueError('no data rows found')
    return torch.tensor(rows, dtype=torch.float32)


def load_data(path='data.txt'):
    """Read the comma-separated file at ``path`` and parse it with ``parse_data``."""
    with open(path, 'r', encoding='utf-8') as f:
        return parse_data(f)


def train_model(model, x, y, epochs=10000, lr=1e-3, momentum=0.9, log_every=1000):
    """Train ``model`` on ``(x, y)`` with binary cross-entropy and SGD with momentum.

    Args:
        model: module mapping ``x`` to probabilities with the same shape as ``y``.
        x: feature tensor; every epoch uses the full batch.
        y: target tensor of 0/1 labels, shaped like the model output.
        epochs: number of full-batch optimisation steps.
        lr: SGD learning rate.
        momentum: SGD momentum factor.
        log_every: print loss/accuracy every this many epochs; 0 disables logging.

    Returns:
        ``(loss, accuracy)`` from the last epoch. Both are measured before
        that epoch's parameter update. Returns ``(nan, 0.0)`` when
        ``epochs`` is 0.
    """
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    loss_value, accuracy = float('nan'), 0.0
    for epoch in range(epochs):
        # forward
        output = model(x)
        loss = criterion(output, y)
        loss_value = loss.item()

        # Threshold the probabilities at 0.5. Detach first so the accuracy
        # bookkeeping stays out of the autograd graph.
        predictions = output.detach().ge(0.5).float()
        accuracy = (predictions == y).sum().item() / x.size(0)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and (epoch + 1) % log_every == 0:
            print('epoch {} loss is {:.4f} accuracy is {:.4f}\n'.format(epoch + 1, loss_value, accuracy))
            print('*' * 53)
    return loss_value, accuracy


def plot_decision_boundary(model, data):
    """Plot the 0.5-probability decision line of ``model`` together with the samples.

    Args:
        model: a fitted model exposing a 2-input linear layer as ``model.lr``.
        data: ``(N, 3)`` tensor of ``x1, x2, label`` rows; labels are
            expected to be 0.0 or 1.0.
    """
    w0, w1 = model.lr.weight.detach()[0].tolist()
    b = model.lr.bias.item()

    # The boundary is where w0*x1 + w1*x2 + b == 0. Span the x-range of the
    # data instead of a range hard-coded for one particular dataset.
    x_min = data[:, 0].min().item()
    x_max = data[:, 0].max().item()
    plot_x = np.linspace(x_min, x_max, 200)
    if w1 != 0:
        plt.plot(plot_x, (-w0 * plot_x - b) / w1)
    elif w0 != 0:
        # A zero weight on x2 makes the boundary vertical (x1 = -b / w0).
        plt.axvline(-b / w0)

    labels = data[:, 2]
    negatives = data[labels == 0.0]
    positives = data[labels == 1.0]
    plt.plot(negatives[:, 0].numpy(), negatives[:, 1].numpy(), 'ro', label='x_0')
    plt.plot(positives[:, 0].numpy(), positives[:, 1].numpy(), 'bo', label='x_1')
    plt.legend(loc='upper right')
    plt.show()


if __name__ == '__main__':
    data = load_data('data.txt')
    # Features and labels do not change during training, so build them once
    # rather than on every epoch.
    x = data[:, 0:2]
    y = data[:, 2].unsqueeze(1)

    logistic_model = LogisticRegression()

    start = time.time()
    train_model(logistic_model, x, y)
    during = time.time() - start
    print('During time {:.2f}'.format(during))

    plot_decision_boundary(logistic_model, data)

输出:

epoch 1000 loss is 0.6216 accuracy is 0.6000

*****************************************************
epoch 2000 loss is 0.5781 accuracy is 0.6100

*****************************************************
epoch 3000 loss is 0.5414 accuracy is 0.6600

*****************************************************
epoch 4000 loss is 0.5104 accuracy is 0.6700

*****************************************************
epoch 5000 loss is 0.4840 accuracy is 0.7700

*****************************************************
epoch 6000 loss is 0.4614 accuracy is 0.8000

*****************************************************
epoch 7000 loss is 0.4419 accuracy is 0.8300

*****************************************************
epoch 8000 loss is 0.4250 accuracy is 0.8600

*****************************************************
epoch 9000 loss is 0.4101 accuracy is 0.8900

*****************************************************
epoch 10000 loss is 0.3970 accuracy is 0.9000

*****************************************************
During time 6.23

猜你喜欢

转载自blog.csdn.net/qq_36022260/article/details/83547259