TensorFlow 1.x Practical Tutorial (5): Nonlinear Regression Model

Target

This article introduces basic TensorFlow concepts through a practical example. After working through it, beginners should be comfortable with the TensorFlow operations it uses.

Building a Simple Nonlinear Regression Model

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Synthetic training data: 200 evenly spaced points in [-0.5, 0.5], reshaped
# into a column vector, with targets y = x^2 plus Gaussian noise
# (mean 0, standard deviation 0.02).
x = np.linspace(-0.5, 0.5, 200)
x = x[:, np.newaxis]  # insert a new axis: shape (200,) -> (200, 1)
noise = np.random.normal(0, 0.02, x.shape)
y = np.square(x) + noise  # ground-truth targets

# Graph inputs: a batch of scalar features and the matching targets.
# The Python variable is named `inputs` (not `input`) so the builtin
# `input` is not shadowed; the graph op keeps the name 'input'.
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='input')
label = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='label')

# Hidden layer with 10 units: output1 = tanh(inputs * w1 + b1)
w1 = tf.Variable(tf.random_normal([1, 10]))
b1 = tf.Variable(tf.zeros([10]))
r1 = tf.matmul(inputs, w1) + b1
output1 = tf.nn.tanh(r1)

# Output layer with a single unit: output2 = tanh(output1 * w2 + b2)
# tanh bounds predictions to (-1, 1), which covers the target range here
# (x^2 <= 0.25 on [-0.5, 0.5], plus small noise).
w2 = tf.Variable(tf.random_normal([10, 1]))
b2 = tf.Variable(tf.zeros([1]))
r2 = tf.matmul(output1, w2) + b2
output2 = tf.nn.tanh(r2)

# Mean squared error against the fed-in `label` placeholder. Referencing the
# numpy array `y` directly here would bake it into the graph as a constant
# and silently ignore whatever labels are passed through feed_dict.
loss = tf.reduce_mean(tf.square(output2 - label))
opt = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
init = tf.global_variables_initializer()

# The same training data is fed on every step.
train_feed = {inputs: x, label: y}

with tf.Session() as sess:
    sess.run(init)
    # Initial parameter values, before any training.
    print('w1:', sess.run(w1))
    print('b1:', sess.run(b1))
    print('w2:', sess.run(w2))
    print('b2:', sess.run(b2))
    print('loss_value:', sess.run(loss, feed_dict=train_feed))
    print('----------')
    # Full-batch gradient descent; report the loss every 100 epochs.
    for epoch in range(10000):
        loss_value, _ = sess.run([loss, opt], feed_dict=train_feed)
        if epoch % 100 == 0:
            print(epoch, loss_value)
    # Computing predictions only needs the inputs, not the labels.
    predict_value = sess.run(output2, feed_dict={inputs: x})

# Plot the noisy samples and the model's predictions together.
plt.figure()
plt.scatter(x, y)
plt.plot(x, predict_value, "r-", lw=3)  # the red line is the prediction
plt.show()

Output

w1: [[ 0.30529675 -1.1716157   0.14365277 -1.0401442   1.4598304  -1.1412098
   0.00497767 -1.998712   -0.4879      0.9146647 ]]
b1: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
w2: [[ 1.0205648 ]
 [ 0.585101  ]
 [ 1.549602  ]
 [ 0.08759285]
 [-1.4181321 ]
 [-2.0843756 ]
 [-2.0614386 ]
 [ 0.7541795 ]
 [-0.20926061]
 [-0.8546703 ]]
b2: [0.]
loss_value: 0.21946013
----------
0 0.21946013
100 0.05402036
200 0.005248788
300 0.00079003634
400 0.000587407
500 0.0005625391
600 0.0005441152
700 0.0005278867
800 0.00051349285
900 0.00050072506
1000 0.0004894079
1100 0.00047938176
1200 0.00047050777
1300 0.00046265605
1400 0.00045571418
1500 0.00044957816
1600 0.00044415714
1700 0.00043936947
1800 0.00043514176
1900 0.0004314091
2000 0.00042811292
2100 0.00042520318
2200 0.0004226335
2300 0.00042036377
2400 0.00041835755
2500 0.00041658382
2600 0.0004150144
2700 0.00041362445
2800 0.00041239194
2900 0.00041129842
3000 0.00041032647
3100 0.00040946086
3200 0.000408689
3300 0.0004079996
3400 0.0004073815
3500 0.0004068272
3600 0.0004063281
3700 0.0004058769
3800 0.00040546828
3900 0.00040509747
4000 0.00040475905
4100 0.00040444918
4200 0.00040416396
4300 0.0004039006
4400 0.00040365654
4500 0.00040342903
4600 0.00040321655
4700 0.00040301704
4800 0.0004028289
4900 0.0004026504
5000 0.0004024807
5100 0.00040231895
5200 0.00040216369
5300 0.00040201354
5400 0.0004018695
5500 0.00040172943
5600 0.00040159366
5700 0.00040146164
5800 0.0004013326
5900 0.00040120617
6000 0.0004010826
6100 0.00040096045
6200 0.00040084048
6300 0.00040072255
6400 0.0004006055
6500 0.00040049048
6600 0.00040037613
6700 0.00040026326
6800 0.00040015127
6900 0.00040003992
7000 0.0003999297
7100 0.0003998201
7200 0.00039971116
7300 0.0003996028
7400 0.00039949492
7500 0.00039938747
7600 0.00039928075
7700 0.00039917394
7800 0.00039906823
7900 0.00039896235
8000 0.00039885714
8100 0.00039875193
8200 0.00039864716
8300 0.00039854273
8400 0.00039843933
8500 0.00039833543
8600 0.0003982313
8700 0.00039812792
8800 0.00039802503
8900 0.00039792218
9000 0.00039781947
9100 0.0003977169
9200 0.0003976149
9300 0.00039751336
9400 0.0003974113
9500 0.0003973098
9600 0.0003972085
9700 0.00039710721
9800 0.00039700675
9900 0.00039690587

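(Figure: nonlinear model training results — the noisy samples shown as a scatter plot, with the model's prediction drawn as a red line.)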

References

Reference for this article: blog.csdn.net/qq_19672707…

Guess you like

Source: juejin.im/post/7086260736120324133