Conquer TensorFlow 2.0 in 30 Days - Day 10: High-Level API Demonstration

3-3 High-Level API Demonstration

The examples below use TensorFlow's high-level API to implement a linear regression model.

TensorFlow's high-level API consists mainly of the model class interfaces provided by tf.keras.models.

There are three ways to build a model with the Keras interface: building it layer by layer with Sequential, building a model of arbitrary structure with the functional API, and subclassing the Model base class to build a custom model.

Here we demonstrate two of them: building a model layer by layer with Sequential, and subclassing the Model base class to build a custom model.
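
For reference, the functional API (not demonstrated in this post) could express the same linear model roughly as follows; this is just an illustrative sketch:

import tensorflow as tf

# Illustrative sketch: the same linear model built with the functional API
inputs = tf.keras.Input(shape=(2,))
outputs = tf.keras.layers.Dense(1)(inputs)
linear_fn = tf.keras.Model(inputs=inputs, outputs=outputs)
linear_fn.summary()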

1. Building a Model Layer by Layer with Sequential [for beginners]

import tensorflow as tf
from tensorflow.keras import models,layers,optimizers

# Number of samples
n = 800

# Generate a toy dataset
X = tf.random.uniform([n,2],minval=-10,maxval=10)
w0 = tf.constant([[2.0],[-1.0]])
b0 = tf.constant(3.0)

Y = X@w0 + b0 + tf.random.normal([n,1],mean = 0.0,stddev= 2.0)  # @ is matrix multiplication; add Gaussian noise

tf.keras.backend.clear_session()

linear = models.Sequential()
linear.add(layers.Dense(1,input_shape =(2,)))
linear.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 1)                 3         
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
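The 3 parameters reported above are the Dense layer's two kernel weights plus one bias, matching the shapes of w0 and b0 used to generate the data.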
### Training with the fit method

linear.compile(optimizer="adam",loss="mse",metrics=["mae"])
linear.fit(X,Y,batch_size = 20,epochs = 200)  

tf.print("w = ",linear.layers[0].kernel)
tf.print("b = ",linear.layers[0].bias)

Train on 800 samples
Epoch 1/200
800/800 [==============================] - 0s 431us/sample - loss: 59.5923 - mae: 6.4031
Epoch 2/200
800/800 [==============================] - 0s 74us/sample - loss: 55.7390 - mae: 6.1912
Epoch 3/200
800/800 [==============================] - 0s 75us/sample - loss: 52.1282 - mae: 5.9840
...
Epoch 200/200
800/800 [==============================] - 0s 61us/sample - loss: 4.0649 - mae: 1.6138
w =  [[1.99322605]
 [-1.00142694]]
b =  [2.89923096]
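
As a quick sanity check (a sketch, not part of the original post), the fitted model can predict on new points; with the recovered weights the outputs should be close to 2*x1 - x2 + 3:

# Illustrative only: predict on a couple of new points
X_new = tf.constant([[5.0,1.0],[-3.0,4.0]])
tf.print(linear(X_new))  # expect roughly [[12.],[-7.]]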

2. Building a Custom Model by Subclassing the Model Base Class [for experts]

import tensorflow as tf
from tensorflow.keras import models,layers,optimizers,losses,metrics


# Print a timestamped separator line
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts%(24*60*60)

    # +8 converts UTC to Beijing time (UTC+8)
    hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)
    minute = tf.cast((today_ts%3600)//60,tf.int32)
    second = tf.cast(tf.floor(today_ts%60),tf.int32)

    def timeformat(m):
        # Zero-pad single-digit values
        if tf.strings.length(tf.strings.format("{}",m))==1:
            return(tf.strings.format("0{}",m))
        else:
            return(tf.strings.format("{}",m))

    timestring = tf.strings.join([timeformat(hour),timeformat(minute),
                timeformat(second)],separator = ":")
    tf.print("=========="*8,end = "")
    tf.print(timestring)

# Number of samples
n = 800

# Generate a toy dataset
X = tf.random.uniform([n,2],minval=-10,maxval=10)
w0 = tf.constant([[2.0],[-1.0]])
b0 = tf.constant(3.0)

Y = X@w0 + b0 + tf.random.normal([n,1],mean = 0.0,stddev= 2.0)  # @ is matrix multiplication; add Gaussian noise

# First 3/4 of the samples for training, remaining 1/4 for validation
ds_train = tf.data.Dataset.from_tensor_slices((X[0:n*3//4,:],Y[0:n*3//4,:])) \
     .shuffle(buffer_size = 1000).batch(20) \
     .prefetch(tf.data.experimental.AUTOTUNE) \
     .cache()

ds_valid = tf.data.Dataset.from_tensor_slices((X[n*3//4:,:],Y[n*3//4:,:])) \
     .shuffle(buffer_size = 1000).batch(20) \
     .prefetch(tf.data.experimental.AUTOTUNE) \
     .cache()
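
Note that this pipeline applies cache() last, so the shuffled batches of the first epoch are cached and replayed in later epochs. A more common ordering (an illustrative sketch, following the usual tf.data guidance) caches the raw elements first so each epoch gets a fresh shuffle:

# Illustrative alternative: cache raw elements, then shuffle/batch/prefetch
ds_train_alt = tf.data.Dataset.from_tensor_slices((X[0:n*3//4,:],Y[0:n*3//4,:])) \
     .cache() \
     .shuffle(buffer_size = 1000).batch(20) \
     .prefetch(tf.data.experimental.AUTOTUNE)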

tf.keras.backend.clear_session()

class MyModel(models.Model):
    def __init__(self):
        super(MyModel, self).__init__()

    def build(self,input_shape):
        # Create the layer lazily in build, once the input shape is known
        self.dense1 = layers.Dense(1)
        super(MyModel,self).build(input_shape)

    def call(self, x):
        y = self.dense1(x)
        return y

model = MyModel()
model.build(input_shape =(None,2))
model.summary()
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  3         
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
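Since a subclassed model is an ordinary callable, it can be sanity-checked on a small batch before training (a sketch, not in the original; the untrained outputs are arbitrary):

# Illustrative only: call the still-untrained model on a small batch
sample = tf.constant([[1.0,2.0],[3.0,-4.0]])
tf.print(model(sample))  # shape (2,1); values are arbitrary before training
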
### Custom training loop (expert tutorial)


optimizer = optimizers.Adam()
loss_func = losses.MeanSquaredError()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_metric = tf.keras.metrics.MeanAbsoluteError(name='train_mae')

valid_loss = tf.keras.metrics.Mean(name='valid_loss')
valid_metric = tf.keras.metrics.MeanAbsoluteError(name='valid_mae')


@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)

@tf.function
def valid_step(model, features, labels):
    predictions = model(features)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)
    

@tf.function
def train_model(model,ds_train,ds_valid,epochs):
    for epoch in tf.range(1,epochs+1):
        for features, labels in ds_train:
            train_step(model,features,labels)

        for features, labels in ds_valid:
            valid_step(model,features,labels)

        logs = 'Epoch={},Loss:{},MAE:{},Valid Loss:{},Valid MAE:{}'
        
        if epoch % 100 == 0:
            printbar()
            tf.print(tf.strings.format(logs,
            (epoch,train_loss.result(),train_metric.result(),valid_loss.result(),valid_metric.result())))
            tf.print("w=",model.layers[0].kernel)
            tf.print("b=",model.layers[0].bias)
            tf.print("")
        
        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()

train_model(model,ds_train,ds_valid,400)

================================================================================16:27:08
Epoch=100,Loss:104.5952,MAE:8.19842052,Valid Loss:118.657539,Valid MAE:8.79173183
w= [[1.36984992]
 [-0.972368121]]
b= [1.38565648]

================================================================================16:27:15
Epoch=200,Loss:55.5088463,MAE:5.11994123,Valid Loss:62.7634163,Valid MAE:5.44431162
w= [[1.98663604]
 [-0.996249795]]
b= [2.77318311]

================================================================================16:27:21
Epoch=300,Loss:38.301754,MAE:3.95671964,Valid Loss:43.0225525,Valid MAE:4.14339161
w= [[1.99494064]
 [-0.997208714]]
b= [2.97018743]

================================================================================16:27:28
Epoch=400,Loss:29.7181396,MAE:3.37643766,Valid Loss:33.1717262,Valid MAE:3.49341393
w= [[1.99479365]
 [-0.997226357]]
b= [2.97108555]
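
After training, the weights can be saved and restored with the standard Keras checkpoint API (a sketch; the file path is hypothetical):

# Illustrative only: save and restore the trained weights (hypothetical path)
model.save_weights("./linear_model_ckpt")

model2 = MyModel()
model2.build(input_shape = (None,2))
model2.load_weights("./linear_model_ckpt")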

Reposted from blog.csdn.net/Elenstone/article/details/105405945