classification(分类)

ps:上一篇博客主要讨论了关于回归的问题,这一节课主要讨论分类问题,其实网络的结构大同小异只是在数据集 上有一些差异。

代码

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
#make fake data
#创建数据集
"""首先创建一个100行两列的全是1 的数据,
然后使用normal正态分布函数,第一个参数代表均值,第二个参数代表方差"""
n_data=torch.ones(100,2)
x0=torch.normal(2*n_data,1)
y0=torch.zeros(100)
x1=torch.normal(-2*n_data,1)
y1=torch.ones(100)
#cat(seq,dim,out=None) 其中表示要连接的两个序列,以元组的形式给出
#dim 表示在哪个维度连接,dim=0,横向连接,dim=1,纵向连接
#https://blog.csdn.net/xrinosvip/article/details/81164697详细请参考
x=torch.cat((x0,x1),0).type(torch.FloatTensor)
y=torch.cat((y0,y1),0).type(torch.LongTensor)
x,y=Variable(x),Variable(y)
# plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],
#             c=y.data.numpy(),s=100,lw=0,cmap='RdYlGn')
# plt.show()

#build neural network
#method1  跟上一个博客中的方法相同
class Net(torch.nn.Module):
    def __init__(self,n_feature,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden=torch.nn.Linear(n_feature,n_hidden)
        self.predict=torch.nn.Linear(n_hidden,n_output)
    def forward(self,x):
        x=F.relu(self.hidden(x))
        x=self.predict(x)
        return x
#输入的是两个特征,输出的类别也是两个
net=Net(2,10,2)
#method 2 这个是快速搭建神经网络的方法
net2=torch.nn.Sequential(
    torch.nn.Linear(2,10),
    torch.nn.ReLU(),
    torch.nn.Linear(10,2)

)
print net
print net2
#train net
optimizer=torch.optim.SGD(net.parameters(),lr=0.02)
#对于回归问题我们一般使用MSEloss()作为损失函数,在分类问题中使用CrossEntropyLoss()
loss_func=torch.nn.CrossEntropyLoss()
#开始画图
plt.ion()
plt.show()
for t in range(100):
    out=net(x)
    loss=loss_func(out,y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 2 == 0:
        plt.cla()
        #因为out输出来并不是概率,所以需要使用激励函数softmax把他转换成概率的形式
        #比如说对于三分类问题可能返回[0.1,0.2,0.7]
        #三者的和为1这个就表示被分为第三类的可能性最大
        #torch .max()详细参考https://blog.csdn.net/Z_lbj/article/details/79766690
        prediction=torch.max(F.softmax(out),1)[1]
        pred_y=prediction.data.numpy().squeeze()
        target_y=y.data.numpy()
        plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1],
                    c=pred_y,s=100,lw=0,cmap='RdYlGn')


        # accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
        accuracy=float(sum(pred_y==target_y))/200
        plt.text(1.5,-4,'Accuracy=%.2f'% accuracy,fontdict={'size':20,'color':'red'})
        plt.pause(0.5)
plt.ioff()
plt.show()


猜你喜欢

转载自blog.csdn.net/xs_211314/article/details/82498608