python3 5.用tensorflow进行简单的线性回归 学习笔记

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/mcyJacky/article/details/85979960

前言

     计算机视觉系列之学习笔记主要是本人进行学习人工智能(计算机视觉方向)的代码整理。本系列所有代码是用python3编写,在平台Anaconda中运行实现,在使用代码时,默认你已经安装相关的python库,这方面不做多余的说明。本系列所涉及的所有代码和资料可在我的github上下载到,gitbub地址:https://github.com/mcyJacky/DeepLearning-CV,如有问题,欢迎指出。

一、构建线性模型数据

     下面通过随机构造一些噪声点来构建线性模型数据:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

x_data = np.random.rand(100)
# 噪声点
noise = np.random.normal(0,0.01,x_data.shape)
# 模型值
y_data = x_data*0.1 + 0.2 + noise

# 显示
plt.scatter(x_data, y_data)
plt.show()

     模型结果显示如下图1.1所示:

图1.1 线性模型数据

二、构建线性回归线

     下面通过tensorflow进行数据训练来构建以上线性模拟数据的线性回归线:

# 回归线的截距
d = tf.Variable(np.random.rand(1))
# 回归线的斜率
k = tf.Variable(np.random.rand(1))
# 构建一个线性模型
y = k*x_data + d

# 二次代价函数
loss = tf.losses.mean_squared_error(y_data, y)
# 定义一个梯度下降法优化器
optimizer = tf.train.GradientDescentOptimizer(0.3)
# 最小化代价函数
train = optimizer.minimize(loss)

# 初始化变量
init= tf.global_variables_initializer()

# 定义会话
with tf.Session() as sess:
    sess.run(init)
    # 模型训练
    for i in range(201):
        sess.run(train)
        if i%20==0:
            print(i, sess.run([k,d]))
        # 预测值
        y_pred = sess.run(y)
        # 训练结果绘图
        plt.scatter(x_data,y_data)
        plt.plot(x_data,y_pred,'r-',lw=3)
    	plt.show()

# 打印结果:
# 0 [array([0.42558686]), array([0.07772181])]
# 20 [array([0.24686251]), array([0.1212207])]
# 40 [array([0.17103131]), array([0.16282419])]
# 60 [array([0.13410329]), array([0.18308412])]
# 80 [array([0.1161202]), array([0.19295024])]
# 100 [array([0.10736286]), array([0.1977548])]
# 120 [array([0.10309823]), array([0.20009452])]
# 140 [array([0.10102146]), array([0.2012339])]
# 160 [array([0.10001012]), array([0.20178875])]
# 180 [array([0.09951763]), array([0.20205895])]
# 200 [array([0.09927779]), array([0.20219054])]

     由训练打印结果可知,通过模型训练回归线的斜率和截距近似于我们定义线性模型数据的斜率和截距。模型训练结果显示如下图2.1所示,由线性回归结果红色线可直观看出,回归拟合效果较好。

图2.1 线性回归结果

【参考】:
     1. 城市数据团课程《AI工程师》计算机视觉方向
     2. deeplearning.ai 吴恩达《深度学习工程师》
     3. 《机器学习》作者:周志华
     4. 《深度学习》作者:Ian Goodfellow


转载声明:
版权声明:非商用自由转载-保持署名-注明出处
署名 :mcyJacky
文章出处:https://blog.csdn.net/mcyJacky

猜你喜欢

转载自blog.csdn.net/mcyJacky/article/details/85979960