TensorFlow2 构造机器学习模型的步骤

  1. 输入训练数据
  2. 定义训练参数
  3. 构建目标模型
  4. 定义损失函数
  5. 选择优化器及定义训练操作
  6. 循环迭代式优化模型参数(权重 W 与偏置 b)
import warnings
import pandas as pd
import tensorflow as tf
warnings.filterwarnings('ignore')

# Ground-truth coefficients of the synthetic linear relation y = W*x + b,
# and the number of samples to draw.
TRUE_W = 3.0
TRUE_b = 2.0
num = 1000

# Draw the inputs X from a standard normal distribution.
inputs = tf.random.normal(shape=[num])
# Additive Gaussian noise so the fit is non-trivial.
noise = tf.random.normal(shape=[num])
# Synthesize the targets Y from the true line plus noise.
outputs = TRUE_W * inputs + TRUE_b + noise

# Quick look at the generated dataset.
data = pd.DataFrame(data={'x': inputs, 'y': outputs})
print(data.shape)
data.head()

x y
0 1.229162 6.171450
1 -0.977991 0.969046
2 -2.086571 -2.965389
3 0.230531 1.701138
4 0.737569 5.143273
# y=wx+b
# y = w*x + b
class Model(object):
    """A minimal linear model y = W * x + b with two trainable scalars."""

    def __init__(self):
        # Deliberately wrong starting values; training should drive
        # them toward the true coefficients (3.0 and 2.0).
        self.W = tf.Variable(5.0)
        self.b = tf.Variable(0.0)

    def __call__(self, x):
        # Broadcasting handles both scalar and 1-D inputs.
        return x * self.W + self.b
# Sanity check: forward pass with the untrained initial weights.
model = Model()
predictions = model([2, 5])
print(predictions)

import matplotlib.pyplot as plt
%matplotlib inline

# 没有进行优化的效果图,此时的 w、b 是初始自定义的值
plt.scatter(inputs, outputs, c='b')  # 原始
plt.scatter(inputs, model(inputs), c='r')  # 预测
loss_value = loss(outputs, model(inputs))
print('Out loss value(MSE):{lossV}'.format(lossV=loss_value))

# 进行优化,更新 w、b
# w:w-梯度*学习率
# 其中梯度是对loss函数求导所得到的值
# Optimization step: update W and b.
# w <- w - learning_rate * gradient, where the gradient is d(loss)/dw.
def train(model, inputs, outputs, learning_rate=0.1):
    """Run one gradient-descent step on model.W and model.b.

    Args:
        model: callable with trainable `W` and `b` tf.Variable attributes.
        inputs: input tensor fed to the model.
        outputs: target tensor compared against the model's predictions.
        learning_rate: SGD step size (default 0.1).

    Returns:
        The updated (W, b) values as Python floats.
    """
    # Record only the forward pass and the loss on the tape.
    with tf.GradientTape() as t:
        current_loss = loss(outputs, model(inputs))
    # FIX: compute the gradients and apply the updates *outside* the tape
    # context — the original did both inside `with`, needlessly recording
    # the gradient computation and the assign_sub updates on the tape.
    dW, db = t.gradient(current_loss, [model.W, model.b])
    # Vanilla SGD: param <- param - learning_rate * gradient.
    model.W.assign_sub(dW * learning_rate)
    model.b.assign_sub(db * learning_rate)
    return model.W.numpy(), model.b.numpy()

# Track the most recent parameter estimates across epochs.
best_W, best_b = 0, 0
for epoch in range(50):
    # Loss measured *before* this epoch's update, matching the log below.
    epoch_loss = loss(outputs, model(inputs))
    w_est, b_est = train(model, inputs, outputs, learning_rate=0.1)
    print('epoch {i}:W_true = {WT}, W_pred = {WV}, b_true = {bT}, b_pred = {bV}, L2损失 = {lossV}'
          .format(i=epoch + 1, WT=3, WV=w_est, bT=2, bV=b_est, lossV=epoch_loss))
    best_W, best_b = w_est, b_est

epoch 1:W_true = 3, W_pred = 4.581630229949951, b_true = 2, b_pred = 0.392181396484375, L2损失 = 9.104086875915527
epoch 2:W_true = 3, W_pred = 4.24971342086792, b_true = 2, b_pred = 0.7063625454902649, L2损失 = 6.1494903564453125
epoch 3:W_true = 3, W_pred = 3.986381769180298, b_true = 2, b_pred = 0.9580533504486084, L2损失 = 4.272693157196045
epoch 4:W_true = 3, W_pred = 3.7774605751037598, b_true = 2, b_pred = 1.1596803665161133, L2损失 = 3.080418348312378
epoch 5:W_true = 3, W_pred = 3.6117053031921387, b_true = 2, b_pred = 1.3211997747421265, L2损失 = 2.3229281902313232
epoch 6:W_true = 3, W_pred = 3.480195999145508, b_true = 2, b_pred = 1.4505879878997803, L2损失 = 1.8416262865066528
epoch 7:W_true = 3, W_pred = 3.3758556842803955, b_true = 2, b_pred = 1.554235577583313, L2损失 = 1.5357838869094849
epoch 8:W_true = 3, W_pred = 3.2930703163146973, b_true = 2, b_pred = 1.6372624635696411, L2损失 = 1.3414183855056763
epoch 9:W_true = 3, W_pred = 3.227386474609375, b_true = 2, b_pred = 1.7037701606750488, L2损失 = 1.217886209487915
epoch 10:W_true = 3, W_pred = 3.1752705574035645, b_true = 2, b_pred = 1.757044792175293, L2损失 = 1.1393660306930542
epoch 11:W_true = 3, W_pred = 3.1339194774627686, b_true = 2, b_pred = 1.7997188568115234, L2损失 = 1.0894523859024048
epoch 12:W_true = 3, W_pred = 3.101109027862549, b_true = 2, b_pred = 1.8339011669158936, L2损失 = 1.0577201843261719
epoch 13:W_true = 3, W_pred = 3.075075149536133, b_true = 2, b_pred = 1.861281156539917, L2损失 = 1.0375452041625977
epoch 14:W_true = 3, W_pred = 3.054417848587036, b_true = 2, b_pred = 1.8832123279571533, L2损失 = 1.0247166156768799
epoch 15:W_true = 3, W_pred = 3.0380265712738037, b_true = 2, b_pred = 1.9007787704467773, L2损失 = 1.0165590047836304
epoch 16:W_true = 3, W_pred = 3.025020122528076, b_true = 2, b_pred = 1.914849042892456, L2损失 = 1.0113706588745117
epoch 17:W_true = 3, W_pred = 3.0146994590759277, b_true = 2, b_pred = 1.9261188507080078, L2损失 = 1.0080710649490356
epoch 18:W_true = 3, W_pred = 3.006509780883789, b_true = 2, b_pred = 1.935145378112793, L2损失 = 1.005972146987915
epoch 19:W_true = 3, W_pred = 3.0000109672546387, b_true = 2, b_pred = 1.9423751831054688, L2损失 = 1.0046370029449463
epoch 20:W_true = 3, W_pred = 2.994853973388672, b_true = 2, b_pred = 1.948165774345398, L2损失 = 1.003787636756897
epoch 21:W_true = 3, W_pred = 2.9907617568969727, b_true = 2, b_pred = 1.952803611755371, L2损失 = 1.0032471418380737
epoch 22:W_true = 3, W_pred = 2.9875142574310303, b_true = 2, b_pred = 1.9565181732177734, L2损失 = 1.0029033422470093
epoch 23:W_true = 3, W_pred = 2.9849371910095215, b_true = 2, b_pred = 1.9594931602478027, L2损失 = 1.0026843547821045
epoch 24:W_true = 3, W_pred = 2.9828920364379883, b_true = 2, b_pred = 1.9618759155273438, L2损失 = 1.0025452375411987
epoch 25:W_true = 3, W_pred = 2.981268882751465, b_true = 2, b_pred = 1.9637842178344727, L2损失 = 1.002456545829773
epoch 26:W_true = 3, W_pred = 2.979980945587158, b_true = 2, b_pred = 1.9653124809265137, L2损失 = 1.0024001598358154
epoch 27:W_true = 3, W_pred = 2.9789586067199707, b_true = 2, b_pred = 1.966536521911621, L2损失 = 1.0023642778396606
epoch 28:W_true = 3, W_pred = 2.978147268295288, b_true = 2, b_pred = 1.9675167798995972, L2损失 = 1.0023412704467773
epoch 29:W_true = 3, W_pred = 2.977503538131714, b_true = 2, b_pred = 1.968301773071289, L2损失 = 1.0023267269134521
epoch 30:W_true = 3, W_pred = 2.976992607116699, b_true = 2, b_pred = 1.9689304828643799, L2损失 = 1.0023175477981567
epoch 31:W_true = 3, W_pred = 2.9765870571136475, b_true = 2, b_pred = 1.9694340229034424, L2损失 = 1.0023115873336792
epoch 32:W_true = 3, W_pred = 2.9762651920318604, b_true = 2, b_pred = 1.9698373079299927, L2损失 = 1.0023078918457031
epoch 33:W_true = 3, W_pred = 2.9760096073150635, b_true = 2, b_pred = 1.9701602458953857, L2损失 = 1.002305507659912
epoch 34:W_true = 3, W_pred = 2.975806951522827, b_true = 2, b_pred = 1.9704188108444214, L2损失 = 1.002303957939148
epoch 35:W_true = 3, W_pred = 2.9756460189819336, b_true = 2, b_pred = 1.970625877380371, L2损失 = 1.0023030042648315
epoch 36:W_true = 3, W_pred = 2.975518226623535, b_true = 2, b_pred = 1.9707916975021362, L2损失 = 1.0023024082183838
epoch 37:W_true = 3, W_pred = 2.975416898727417, b_true = 2, b_pred = 1.9709244966506958, L2损失 = 1.0023019313812256
epoch 38:W_true = 3, W_pred = 2.9753363132476807, b_true = 2, b_pred = 1.971030831336975, L2损失 = 1.002301812171936
epoch 39:W_true = 3, W_pred = 2.9752724170684814, b_true = 2, b_pred = 1.971116065979004, L2损失 = 1.0023016929626465
epoch 40:W_true = 3, W_pred = 2.975221633911133, b_true = 2, b_pred = 1.971184253692627, L2损失 = 1.0023013353347778
epoch 41:W_true = 3, W_pred = 2.9751813411712646, b_true = 2, b_pred = 1.9712388515472412, L2损失 = 1.0023014545440674
epoch 42:W_true = 3, W_pred = 2.975149393081665, b_true = 2, b_pred = 1.9712826013565063, L2损失 = 1.0023013353347778
epoch 43:W_true = 3, W_pred = 2.9751241207122803, b_true = 2, b_pred = 1.9713176488876343, L2损失 = 1.0023013353347778
epoch 44:W_true = 3, W_pred = 2.9751040935516357, b_true = 2, b_pred = 1.9713456630706787, L2损失 = 1.0023012161254883
epoch 45:W_true = 3, W_pred = 2.975088119506836, b_true = 2, b_pred = 1.9713680744171143, L2损失 = 1.0023013353347778
epoch 46:W_true = 3, W_pred = 2.9750754833221436, b_true = 2, b_pred = 1.9713860750198364, L2损失 = 1.0023013353347778
epoch 47:W_true = 3, W_pred = 2.9750654697418213, b_true = 2, b_pred = 1.971400499343872, L2损失 = 1.0023014545440674
epoch 48:W_true = 3, W_pred = 2.975057363510132, b_true = 2, b_pred = 1.9714120626449585, L2损失 = 1.0023012161254883
epoch 49:W_true = 3, W_pred = 2.975050926208496, b_true = 2, b_pred = 1.971421241760254, L2损失 = 1.0023013353347778
epoch 50:W_true = 3, W_pred = 2.975045919418335, b_true = 2, b_pred = 1.971428632736206, L2损失 = 1.0023012161254883

import matplotlib.pyplot as plt
%matplotlib inline

# 进行优化后的效果图,此时的 w、b 是最优化系数值
plt.scatter(inputs, outputs, c='b')  # 原始
plt.scatter(inputs, model(inputs), c='r')  # 预测
loss_value = loss(outputs, model(inputs))
print('Out loss value(MSE):{lossV}'.format(lossV=loss_value))
print('Out best_W = {best_W}, best_b = {best_b}'.format(best_W=best_W, best_b=best_b))

猜你喜欢

转载自blog.csdn.net/weixin_42452716/article/details/129083651