# 版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_36022260/article/details/83547259
# (Copyright notice: original article by the blog author; do not repost without permission.)
"""
Logistic 回归
"""
class LogisticRegression(nn.Module):
    """Binary logistic-regression model: one linear layer followed by a sigmoid.

    Args:
        in_features: Number of input features per sample. Defaults to 2,
            matching the original two-feature setup.

    The forward pass maps a batch of shape ``(N, in_features)`` to
    probabilities of shape ``(N, 1)`` in [0, 1].
    """

    def __init__(self, in_features: int = 2):
        super().__init__()
        # The attribute names ``lr`` and ``sm`` are part of the interface:
        # the training script reads ``model.lr.weight`` / ``model.lr.bias``.
        self.lr = nn.Linear(in_features, 1)
        self.sm = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(x @ W.T + b)`` for a batch ``x`` of shape (N, in_features)."""
        return self.sm(self.lr(x))
def _load_data(path='data.txt'):
    """Read comma-separated ``feature0,feature1,label`` rows into a float tensor.

    Blank lines (e.g. a trailing empty line) are skipped instead of crashing
    on ``float('')``.

    Returns:
        A float32 tensor of shape (num_rows, 3).

    Raises:
        ValueError: if the file contains no data rows.
    """
    rows = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            fields = line.split(',')
            rows.append((float(fields[0]), float(fields[1]), float(fields[2])))
    if not rows:
        # An empty tensor would otherwise fail later with an opaque indexing error.
        raise ValueError('no data rows found in {!r}'.format(path))
    return torch.tensor(rows, dtype=torch.float32)


def _train(model, data, epochs=10000, lr=1e-3, momentum=0.9, report_every=1000):
    """Fit ``model`` on ``data`` with BCE loss and SGD with momentum.

    Columns 0-1 of ``data`` are the features and column 2 is the 0/1 label.
    Every ``report_every`` epochs, print the loss and the training accuracy
    (threshold 0.5) from that epoch's forward pass, which runs before the
    parameter update.

    Returns:
        Elapsed wall-clock time of the training loop, in seconds.
    """
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    # Inputs and targets never change, so build them once rather than every
    # epoch. Plain tensors carry autograd; the old ``Variable`` wrapper is
    # deprecated.
    x = data[:, 0:2]
    y = data[:, 2].unsqueeze(1)
    start = time.time()
    for epoch in range(epochs):
        output = model(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (epoch + 1) % report_every == 0:
            # Metrics are computed only when they are printed. ``output`` and
            # ``loss`` still hold the values from before optimizer.step().
            accuracy = (output.ge(0.5).float() == y).sum().item() / x.size(0)
            print('epoch {} loss is {:.4f} accuracy is {:.4f}\n'.format(
                epoch + 1, loss.item(), accuracy))
            print('*' * 53)
    return time.time() - start


def _plot_decision_boundary(model, data, x_range=(30, 100), step=0.1):
    """Scatter both classes of ``data`` and draw the model's decision boundary.

    The model has two input weights (w0, w1) and a bias b. The boundary is
    where the sigmoid input is zero (probability 0.5):
    ``w0*x0 + w1*x1 + b == 0``, i.e. ``x1 = (-w0*x0 - b) / w1``. It is drawn
    for x0 in ``np.arange(x_range[0], x_range[1], step)``.
    """
    w0, w1 = model.lr.weight.detach()[0].tolist()
    b = model.lr.bias.item()
    if w1 != 0:
        plot_x = np.arange(x_range[0], x_range[1], step)
        plt.plot(plot_x, (-w0 * plot_x - b) / w1)
    elif w0 != 0:
        # Vertical boundary x0 = -b / w0; the slope form would divide by zero.
        plt.axvline(-b / w0)
    labels = data[:, -1]
    class0 = data[labels == 0.0]
    class1 = data[labels == 1.0]
    plt.plot(class0[:, 0].tolist(), class0[:, 1].tolist(), 'ro', label='x_0')
    plt.plot(class1[:, 0].tolist(), class1[:, 1].tolist(), 'bo', label='x_1')
    plt.legend(loc='upper right')
    plt.show()


if __name__ == '__main__':
    data = _load_data('data.txt')
    logistic_model = LogisticRegression()
    during = _train(logistic_model, data)
    print('During time {:.2f}'.format(during))
    _plot_decision_boundary(logistic_model, data)
# 输出 (sample output from one training run):
# epoch 1000 loss is 0.6216 accuracy is 0.6000
# *****************************************************
# epoch 2000 loss is 0.5781 accuracy is 0.6100
# *****************************************************
# epoch 3000 loss is 0.5414 accuracy is 0.6600
# *****************************************************
# epoch 4000 loss is 0.5104 accuracy is 0.6700
# *****************************************************
# epoch 5000 loss is 0.4840 accuracy is 0.7700
# *****************************************************
# epoch 6000 loss is 0.4614 accuracy is 0.8000
# *****************************************************
# epoch 7000 loss is 0.4419 accuracy is 0.8300
# *****************************************************
# epoch 8000 loss is 0.4250 accuracy is 0.8600
# *****************************************************
# epoch 9000 loss is 0.4101 accuracy is 0.8900
# *****************************************************
# epoch 10000 loss is 0.3970 accuracy is 0.9000
# *****************************************************
# During time 6.23