Deep Learning with PyTorch: One-Dimensional Linear Regression

Copyright notice: this is the author's original article and may not be reproduced without permission. https://blog.csdn.net/qq_36022260/article/details/83547180
#  Implementing one-dimensional linear regression in code
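For reference, the model being fitted is a straight line, and training minimizes the mean squared error between predictions and targets (standard definitions, restated here only to make the code easier to follow):

$$\hat{y}_i = w x_i + b, \qquad \mathcal{L}(w, b) = \frac{1}{n}\sum_{i=1}^{n}\left(\hat{y}_i - y_i\right)^2$$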
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Toy training data: 15 (x, y) pairs with a roughly linear relationship
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],
                    [9.779],[6.182],[7.59],[2.167],[7.042],
                    [10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],
                    [3.366],[2.596],[2.53],[1.221],[2.827],
                    [3.465],[1.65],[2.904],[1.3]],dtype=np.float32)

# Convert the NumPy arrays to PyTorch tensors
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        # One input feature, one output feature: y = w * x + b
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        out = self.linear(x)
        return out

model = LinearRegression()
criterion = nn.MSELoss()                            # mean squared error loss
optimizer = optim.SGD(model.parameters(), lr=1e-3)  # plain SGD

num_epochs = 2000
for epoch in range(num_epochs):
    inputs = x_train
    target = y_train

    # forward
    out = model(inputs)
    loss = criterion(out, target)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print('Epoch [{} / {}],loss {}'.format(epoch + 1, num_epochs, loss.item()))
        
# Evaluate the trained model and plot the fitted line against the data
model.eval()
with torch.no_grad():
    predict = model(x_train).numpy()

plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label="Original data")
plt.plot(x_train.numpy(), predict, 'b', label="Fitting data")
plt.legend()
plt.show()

Output:

Epoch [100 / 2000],loss 0.23131796717643738
Epoch [200 / 2000],loss 0.22819313406944275
Epoch [300 / 2000],loss 0.22522489726543427
Epoch [400 / 2000],loss 0.22240526974201202
Epoch [500 / 2000],loss 0.2197268307209015
Epoch [600 / 2000],loss 0.2171824872493744
Epoch [700 / 2000],loss 0.21476560831069946
Epoch [800 / 2000],loss 0.2124696969985962
Epoch [900 / 2000],loss 0.21028876304626465
Epoch [1000 / 2000],loss 0.2082170695066452
Epoch [1100 / 2000],loss 0.20624907314777374
Epoch [1200 / 2000],loss 0.20437967777252197
Epoch [1300 / 2000],loss 0.20260383188724518
Epoch [1400 / 2000],loss 0.20091693103313446
Epoch [1500 / 2000],loss 0.19931448996067047
Epoch [1600 / 2000],loss 0.19779227674007416
Epoch [1700 / 2000],loss 0.19634634256362915
Epoch [1800 / 2000],loss 0.19497279822826385
Epoch [1900 / 2000],loss 0.19366800785064697
Epoch [2000 / 2000],loss 0.1924285888671875
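After training, the learned slope and intercept can be read directly from the nn.Linear layer, and the model can be applied to new inputs. The following is a minimal sketch, assuming the model, x_train, and y_train defined above; the input value 8.0 is just an arbitrary illustration:

# Inspect the fitted parameters: nn.Linear(1, 1) stores a 1x1 weight and a 1-element bias
w = model.linear.weight.item()
b = model.linear.bias.item()
print('fitted line: y = {:.4f} * x + {:.4f}'.format(w, b))

# Predict for a new, unseen x value (8.0 chosen arbitrarily for illustration)
new_x = torch.tensor([[8.0]])
model.eval()
with torch.no_grad():
    print('prediction for x = 8.0: {:.4f}'.format(model(new_x).item()))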

      
