Implementing linear regression with MXNet

Copyright notice: this is the blogger's original article and may not be reposted without the blogger's permission. https://blog.csdn.net/qq_27492735/article/details/82705344

This post implements linear regression from scratch with MXNet's NDArray and autograd: it generates synthetic data from y = 2*x1 - 3.4*x2 + 4.2 plus Gaussian noise, then uses mini-batch stochastic gradient descent to recover the true weights and bias.

# coding: utf-8
import mxnet.ndarray as nd
from mxnet import autograd
import random

# Generate the data: num_examples x num_inputs (1000 x 2)
num_inputs = 2
num_examples = 1000

true_w = [2, -3.4]
true_b = 4.2

x = nd.random_normal(shape=(num_examples, num_inputs))
y = true_w[0] * x[:, 0] + true_w[1] * x[:, 1] + true_b
y += .01 * nd.random_normal(shape=y.shape)  # add a little Gaussian noise
print(x[0:10], y[0:10])

# Read the data in mini-batches
batch_size = 10
def data_iter():
    # shuffle the example indices so batches are drawn in random order
    index = list(range(num_examples))
    random.shuffle(index)
    for i in range(0, num_examples, batch_size):
        j = nd.array(index[i:min(i + batch_size, num_examples)])
        yield nd.take(x, j), nd.take(y, j)

for data, label in data_iter():
    print(data, label)
    break

# Initialize the model parameters
w = nd.random_normal(shape=(num_inputs, 1))
b = nd.zeros((1,))
params = [w, b]
for param in params:
    param.attach_grad()  # allocate a gradient buffer so autograd can fill param.grad

# Define the model
def net(x):
    return nd.dot(x, w) + b

# Loss function
def square_loss(yhat, y):
    # reshape y to yhat's shape to avoid unintended broadcasting
    return (yhat - y.reshape(yhat.shape)) ** 2
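# For this squared loss, autograd computes, per example,
# dL/dw = 2 * (yhat - y) * x and dL/db = 2 * (yhat - y);
# loss.backward() on a batch sums these per-example gradients into param.grad,
# and the SGD step below scales that sum by the learning rate.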
# Optimization: mini-batch stochastic gradient descent
def SGD(params, lr):
    for param in params:
        # update in place so the NDArray and its attached gradient buffer are reused
        param[:] = param - lr * param.grad

# Training
epochs = 5
learning_rate = 0.001
for e in range(epochs):
    total_loss = 0
    for data, label in data_iter():
        with autograd.record():
            output = net(data)
            loss = square_loss(output, label)
        loss.backward()
        SGD(params, learning_rate)
        total_loss += nd.sum(loss).asscalar()
    print('Epoch %d,average loss:%f' % (e, total_loss / num_examples))


print(true_b, b)  # check the learned bias against the true value
print(true_w, w)  # check the learned weights against the true values

Output:


[[ 1.1630785   0.4838046 ]
 [ 0.29956347  0.15302546]
 [-1.1688148   1.558071  ]
 [-0.5459446  -2.3556297 ]
 [ 0.54144025  2.6785064 ]
 [ 1.2546344  -0.54877406]
 [-0.68106437 -0.1353156 ]
 [ 0.37723133  0.41016456]
 [ 0.5712682  -2.7579627 ]
 [ 1.07628    -0.6141325 ]]
<NDArray 10x2 @cpu(0)> 
[ 4.879625   4.2968144 -3.4331114 11.099875  -3.8235688  8.576558
  3.3012044  3.5644817 14.710667   8.445549 ]
<NDArray 10 @cpu(0)>

[[ 0.2508148  -0.30159083]
 [ 0.1703758   1.4159782 ]
 [ 0.60099405 -0.04193413]
 [ 0.5398333  -0.08873118]
 [ 0.5390026   0.15226784]
 [-0.92770946  0.6714997 ]
 [ 0.29670075  1.3111951 ]
 [-0.56099415 -0.9978546 ]
 [-0.64676666  1.3366979 ]
 [ 1.5275036  -1.3969318 ]]
<NDArray 10x2 @cpu(0)> 
[ 5.7448215  -0.26755202  5.538153    5.5854945   4.7548504   0.04230662
  0.34051183  6.4591236  -1.6538469  11.992856  ]
<NDArray 10 @cpu(0)>
Epoch 0,average loss:5.551557
Epoch 1,average loss:0.098503
Epoch 2,average loss:0.001866
Epoch 3,average loss:0.000130
Epoch 4,average loss:0.000097
4.2 
[4.2002177]
<NDArray 1 @cpu(0)>
[2, -3.4] 
[[ 1.9998528]
 [-3.400118 ]]
<NDArray 2x1 @cpu(0)>

Process finished with exit code 0
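
For comparison, here is a minimal sketch of the same model written with MXNet's high-level Gluon API rather than hand-rolled functions; it is not part of the original post. It assumes the x, y, batch_size and num_examples defined above are still in scope, and it swaps the hand-written data_iter, net, square_loss and SGD for gluon.data.DataLoader, gluon.nn.Dense, gluon.loss.L2Loss and gluon.Trainer. Note that L2Loss computes 0.5 * (yhat - y)^2, half of the manual loss, so the reported average loss will be roughly half as large.

from mxnet import autograd, gluon, nd

# wrap the generated data in a DataLoader that shuffles and batches it
dataset = gluon.data.ArrayDataset(x, y)
train_iter = gluon.data.DataLoader(dataset, batch_size, shuffle=True)

# a single Dense(1) layer is exactly dot(x, w) + b
net = gluon.nn.Sequential()
net.add(gluon.nn.Dense(1))
net.initialize()

loss_fn = gluon.loss.L2Loss()  # 0.5 * (yhat - y)^2
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})

for e in range(5):
    total_loss = 0
    for data, label in train_iter:
        with autograd.record():
            loss = loss_fn(net(data), label)
        loss.backward()
        trainer.step(batch_size)  # step() normalizes the summed gradient by batch_size
        total_loss += nd.sum(loss).asscalar()
    print('Epoch %d,average loss:%f' % (e, total_loss / num_examples))

dense = net[0]
print(dense.weight.data(), dense.bias.data())  # should again approach [2, -3.4] and 4.2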
