[TensorFlow 2.x Series, Part 2] TensorFlow 2.0: The Standard Steps for Building a Machine Learning Model

The standard steps for building a machine learning model in TensorFlow:

  1. Load the training data
  2. Define the training parameters
  3. Build the target model
  4. Define the loss function
  5. Choose an optimizer and define the training operation
  6. Iteratively optimize the parameters

1. Load the Training Data

import pandas as pd
import tensorflow as tf

TRUE_W = 3.0
TRUE_b = 2.0
num = 1000

# Random inputs X
inputs = tf.random.normal(shape=[num])
# Random noise
noise = tf.random.normal(shape=[num])
# Construct the targets Y
outputs = TRUE_W * inputs + TRUE_b + noise

# Inspect the data (convert the tensors to NumPy arrays for pandas)
data = pd.DataFrame(data={'x': inputs.numpy(), 'y': outputs.numpy()})
print(data.shape)
data.head()
(1000, 2)

          x         y
0  0.161683  2.967363
1  0.381507  1.185349
2 -0.775100 -0.544815
3 -1.010228 -0.546506
4 -1.888162 -3.443309
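For larger datasets, the tensors above could also be wrapped in a `tf.data.Dataset` to get shuffling and mini-batching; this is an optional sketch and is not used in the rest of this article:

```python
import tensorflow as tf

num = 1000
inputs = tf.random.normal(shape=[num])
outputs = 3.0 * inputs + 2.0 + tf.random.normal(shape=[num])

# Wrap the tensors in a Dataset, then shuffle and batch them
dataset = tf.data.Dataset.from_tensor_slices((inputs, outputs))
dataset = dataset.shuffle(buffer_size=num).batch(100)

for x_batch, y_batch in dataset.take(1):
    print(x_batch.shape, y_batch.shape)  # (100,) (100,)
```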

2. Define the Training Parameters

For example, define the initial values of the trainable parameters W and b, as well as hyperparameters such as learning_rate (the learning rate itself is not trained).
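A minimal sketch of this step (in this article the model class below initializes W and b itself, so this is illustrative only):

```python
import tensorflow as tf

# Trainable parameters, deliberately initialized away from the true values (3.0 and 2.0)
W = tf.Variable(5.0)
b = tf.Variable(0.0)
# The learning rate is a hyperparameter; it is fixed, not trained
learning_rate = 0.1

print(W.numpy(), b.numpy())  # 5.0 0.0
```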

3. Build the Target Model

# y = w*x + b
class Model(object):
    def __init__(self):
        # Initialize the trainable variables
        self.W = tf.Variable(5.0)
        self.b = tf.Variable(0.0)

    def __call__(self, x):
        # x is 1-D here; broadcasting handles the elementwise computation
        return self.W * x + self.b

# Quick test
model = Model()
print(model([2, 5]))
tf.Tensor([10. 25.], shape=(2,), dtype=float32)
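The same linear model could also be expressed with a built-in Keras layer; this is an equivalent sketch for comparison and is not used in the rest of the article:

```python
import tensorflow as tf

# A Dense layer with one unit implements y = w*x + b
layer = tf.keras.layers.Dense(units=1)

# Dense expects 2-D input of shape (batch, features)
x = tf.constant([[2.0], [5.0]])
y = layer(x)  # the first call builds the layer's kernel and bias
print(y.shape)  # (2, 1)
```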

4. Define the Loss Function

def loss(y_true, y_pred):
    # L2 (mean squared error) loss
    return tf.reduce_mean(tf.square(y_pred - y_true))
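This hand-written L2 loss is the same quantity computed by `tf.keras.losses.MeanSquaredError`; a quick sketch to verify the two agree on a small example:

```python
import tensorflow as tf

def loss(y_true, y_pred):
    # L2 (mean squared error) loss
    return tf.reduce_mean(tf.square(y_pred - y_true))

y_true = tf.constant([1.0, 2.0, 3.0])
y_pred = tf.constant([1.5, 2.0, 2.0])

mse = tf.keras.losses.MeanSquaredError()
print(loss(y_true, y_pred).numpy())  # mean of [0.25, 0.0, 1.0] ≈ 0.4167
print(mse(y_true, y_pred).numpy())   # same value
```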

Fit before optimization

import matplotlib.pyplot as plt
%matplotlib inline

# Fit before any optimization; W and b still hold their initial values
plt.scatter(inputs, outputs, c='b')  # ground truth
plt.scatter(inputs, model(inputs), c='r')  # predictions
loss_value = loss(outputs, model(inputs))
print('Out loss value(MSE):{lossV}'.format(lossV=loss_value))

Out loss value(MSE):8.654911041259766

5. Define the Optimization Step

# One optimization step: update W and b
# w := w - learning_rate * gradient
# where the gradient is the derivative of the loss w.r.t. each parameter
def train(model, inputs, outputs, learning_rate=0.1):
    with tf.GradientTape() as t:
        lossV = loss(outputs, model(inputs))
    # Compute the gradients (outside the tape context)
    dW, db = t.gradient(lossV, [model.W, model.b])
    # Update W and b: w := w - learning_rate * gradient
    model.W.assign_sub(dW * learning_rate)
    model.b.assign_sub(db * learning_rate)
    return model.W.numpy(), model.b.numpy()
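The manual `assign_sub` update above is exactly one step of plain gradient descent; with a built-in optimizer the same step could be sketched as follows (the small x/y values here are illustrative, not the article's dataset):

```python
import tensorflow as tf

W = tf.Variable(5.0)
b = tf.Variable(0.0)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

x = tf.constant([1.0, 2.0])
y = tf.constant([5.0, 8.0])  # samples from y = 3x + 2

with tf.GradientTape() as tape:
    y_pred = W * x + b
    loss_value = tf.reduce_mean(tf.square(y_pred - y))
grads = tape.gradient(loss_value, [W, b])
# Equivalent to W.assign_sub(0.1 * dW); b.assign_sub(0.1 * db)
optimizer.apply_gradients(zip(grads, [W, b]))
print(W.numpy(), b.numpy())  # 4.6 -0.2
```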

6. Iterate to the Best Parameters

best_W, best_b = 0, 0
for epoch in range(50):
    lossV = loss(outputs, model(inputs))
    WV, bV = train(model, inputs, outputs, learning_rate=0.1)
    print('epoch {i}: W_true = {WT}, W_pred = {WV}, b_true = {bT}, b_pred = {bV}, L2 loss = {lossV}'
          .format(i=epoch + 1, WT=3, WV=WV, bT=2, bV=bV, lossV=lossV))
    best_W, best_b = WV, bV
epoch 1: W_true = 3, W_pred = 4.606305122375488, b_true = 2, b_pred = 0.3836219608783722, L2 loss = 8.654911041259766
epoch 2: W_true = 3, W_pred = 4.291611194610596, b_true = 2, b_pred = 0.6929514408111572, L2 loss = 5.931309223175049
epoch 3: W_true = 3, W_pred = 4.040049076080322, b_true = 2, b_pred = 0.9423588514328003, L2 loss = 4.176153182983398
epoch 4: W_true = 3, W_pred = 3.838940143585205, b_true = 2, b_pred = 1.143438696861267, L2 loss = 3.045004367828369
epoch 5: W_true = 3, W_pred = 3.6781554222106934, b_true = 2, b_pred = 1.3055448532104492, L2 loss = 2.31595778465271
...
epoch 49: W_true = 3, W_pred = 3.036088705062866, b_true = 2, b_pred = 1.9787179231643677, L2 loss = 0.9934381246566772
epoch 50: W_true = 3, W_pred = 3.036081552505493, b_true = 2, b_pred = 1.9787274599075317, L2 loss = 0.9934381246566772
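As a sanity check, W and b for this one-dimensional linear model can also be recovered in closed form by ordinary least squares. A sketch with NumPy on freshly generated synthetic data (exact numbers depend on the random noise, but both estimates land close to 3.0 and 2.0):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
y = 3.0 * x + 2.0 + rng.normal(size=1000)

# Least-squares fit of y = W*x + b: columns are [x, 1]
A = np.stack([x, np.ones_like(x)], axis=1)
(W_hat, b_hat), *_ = np.linalg.lstsq(A, y, rcond=None)
print(W_hat, b_hat)  # close to 3.0 and 2.0
```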

Fit after optimization

Having obtained the best parameters best_W and best_b, plot the fitted predictions. Note that in this example W and b are updated in place on the model during training, so no manual copying is needed.

import matplotlib.pyplot as plt
%matplotlib inline

# Fit after optimization; W and b now hold the optimized values
plt.scatter(inputs, outputs, c='b')  # ground truth
plt.scatter(inputs, model(inputs), c='r')  # predictions
loss_value = loss(outputs, model(inputs))
print('Out loss value(MSE):{lossV}'.format(lossV=loss_value))
print('Out best_W = {best_W}, best_b = {best_b}'.format(best_W=best_W, best_b=best_b))

Out loss value(MSE):0.9934381246566772
Out best_W = 3.036081552505493, best_b = 1.9787274599075317

References:
https://github.com/czy36mengfei/tensorflow2_tutorials_chinese

Source: blog.csdn.net/qq_41731978/article/details/103808979