模型训练过程小结(适合小白)

为了让自己更清楚模型的训练过程,在这里对训练过程进行一下记录,欢迎大佬补充与指正。

1、fit函数训练过程:

model.compile(loss="sparse_categorical_crossentropy",
              optimizer = keras.optimizers.SGD(0.001),
              metrics = ["accuracy"])

history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)

如上是tf2.0比较常见的模型训练代码,在这种情况下:因为epochs=10,所以会遍历10次训练集,在每遍历一次训练集后会在验证集上进行验证,因为metrics中指明“我们还要关注accuracy",所以在每一次epoch中,会输出测试集上的loss和accuracy,与验证集上的val_loss和val_accuracy。

Train on 55000 samples, validate on 5000 samples
Epoch 1/10
55000/55000 [==============================] - 4s 78us/sample - loss: 0.9102 - accuracy: 0.7004 - val_loss: 0.6062 - val_accuracy: 0.7928
Epoch 2/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.5768 - accuracy: 0.8011 - val_loss: 0.5134 - val_accuracy: 0.8244
Epoch 3/10
55000/55000 [==============================] - 3s 62us/sample - loss: 0.5108 - accuracy: 0.8220 - val_loss: 0.4736 - val_accuracy: 0.8384
Epoch 4/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.4759 - accuracy: 0.8330 - val_loss: 0.4490 - val_accuracy: 0.8438
Epoch 5/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.4527 - accuracy: 0.8412 - val_loss: 0.4328 - val_accuracy: 0.8494
Epoch 6/10
55000/55000 [==============================] - 4s 68us/sample - loss: 0.4357 - accuracy: 0.8479 - val_loss: 0.4200 - val_accuracy: 0.8528
Epoch 7/10
55000/55000 [==============================] - 4s 66us/sample - loss: 0.4219 - accuracy: 0.8518 - val_loss: 0.4082 - val_accuracy: 0.8586
Epoch 8/10
55000/55000 [==============================] - 4s 66us/sample - loss: 0.4108 - accuracy: 0.8547 - val_loss: 0.3997 - val_accuracy: 0.8614
Epoch 9/10
55000/55000 [==============================] - 4s 68us/sample - loss: 0.4012 - accuracy: 0.8586 - val_loss: 0.3963 - val_accuracy: 0.8614
Epoch 10/10
55000/55000 [==============================] - 4s 66us/sample - loss: 0.3925 - accuracy: 0.8607 - val_loss: 0.3881 - val_accuracy: 0.8630

如果最后要在测试集上进行验证,可以运行下面代码:

model.evaluate(x_test_scaled, y_test, verbose=0)
[0.34827180202007296, 0.8757]

 补充:在fit函数中也是以batch的形式遍历训练集的,默认batch_size=32,因此,在fit函数中做的事情可以总结为:1、以batch的形式遍历训练集, 然后统计训练集上的metric(其中包含自动求导,因此如果想要替换这部分,就需要实现这三部分)。2、一个epoch结束后,在验证集进行验证,统计验证集上的loss。

2、使用sklearn进行超参数搜索时的训练(cross_validation机制)

cross_validation机制: 训练集分成n份,n-1训练,最后一份验证.(默认n=4)
训练、验证、测试的关系是:每训练一次训练集(每经过一个epoch),就在验证集上做一次验证,全部训练完,最后在测试集上做测试。
在超参数搜索中,有了cross_validation机制,多了一份验证集,那它具体是怎样执行的呢?,先在前n-1份的数据上训练,每经过一个epoch,在第n份上做一个验证,当训练完100个epoch后(假设设置:epochs = 100),在x_valid_scaled(也就是验证集)在进行一次验证,在最后超参数搜索完之后再在全部训练集上用新的参数再训练一遍,每经过一个epoch,仍在验证集上进行验证。如果你感觉有点乱,看下面的训练日志,你就会清楚了。

下面是我截取的”有cross_validation机制”的训练过程的其中一个运行100次epoch的日志。它会在7740的训练数据上进行训练(这是11610的3/4),然后在剩下的1/4数据上进行验证,就是输出的val_loss,在运行100次epoch后(有early_stopping机制),在3870的验证集上进行验证。

Train on 7740 samples, validate on 3870 samples
Epoch 1/100
7740/7740 [==============================] - 1s 112us/sample - loss: 5.2672 - val_loss: 4.9883
Epoch 2/100
7740/7740 [==============================] - 1s 80us/sample - loss: 4.3754 - val_loss: 4.1726
Epoch 3/100
7740/7740 [==============================] - 1s 95us/sample - loss: 3.6575 - val_loss: 3.4975
Epoch 4/100
7740/7740 [==============================] - 1s 87us/sample - loss: 3.0589 - val_loss: 2.9331
Epoch 5/100
7740/7740 [==============================] - 1s 102us/sample - loss: 2.5663 - val_loss: 2.4767
Epoch 6/100
7740/7740 [==============================] - 1s 87us/sample - loss: 2.1773 - val_loss: 2.1219
Epoch 7/100
7740/7740 [==============================] - 1s 80us/sample - loss: 1.8800 - val_loss: 1.8536
Epoch 8/100
7740/7740 [==============================] - 1s 79us/sample - loss: 1.6559 - val_loss: 1.6512
Epoch 9/100
7740/7740 [==============================] - 1s 81us/sample - loss: 1.4865 - val_loss: 1.4947
Epoch 10/100
7740/7740 [==============================] - 1s 80us/sample - loss: 1.3585 - val_loss: 1.3753
Epoch 11/100
7740/7740 [==============================] - 1s 81us/sample - loss: 1.2567 - val_loss: 1.2781
Epoch 12/100
7740/7740 [==============================] - 1s 80us/sample - loss: 1.1739 - val_loss: 1.1986
Epoch 13/100
7740/7740 [==============================] - 1s 95us/sample - loss: 1.1043 - val_loss: 1.1311
Epoch 14/100
7740/7740 [==============================] - 1s 81us/sample - loss: 1.0435 - val_loss: 1.0722
Epoch 15/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.9901 - val_loss: 1.0206
Epoch 16/100
7740/7740 [==============================] - 1s 82us/sample - loss: 0.9431 - val_loss: 0.9755
Epoch 17/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.9022 - val_loss: 0.9364
Epoch 18/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.8660 - val_loss: 0.9022
Epoch 19/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.8341 - val_loss: 0.8723
Epoch 20/100
7740/7740 [==============================] - 1s 89us/sample - loss: 0.8067 - val_loss: 0.8463
Epoch 21/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.7832 - val_loss: 0.8239
Epoch 22/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.7632 - val_loss: 0.8043
Epoch 23/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.7458 - val_loss: 0.7873
Epoch 24/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.7307 - val_loss: 0.7726
Epoch 25/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.7175 - val_loss: 0.7598
Epoch 26/100
7740/7740 [==============================] - 1s 92us/sample - loss: 0.7060 - val_loss: 0.7486
Epoch 27/100
7740/7740 [==============================] - 1s 84us/sample - loss: 0.6958 - val_loss: 0.7387
Epoch 28/100
7740/7740 [==============================] - 1s 83us/sample - loss: 0.6866 - val_loss: 0.7300
Epoch 29/100
7740/7740 [==============================] - 1s 88us/sample - loss: 0.6786 - val_loss: 0.7223
Epoch 30/100
7740/7740 [==============================] - 1s 105us/sample - loss: 0.6716 - val_loss: 0.7155
Epoch 31/100
7740/7740 [==============================] - 1s 88us/sample - loss: 0.6653 - val_loss: 0.7093
Epoch 32/100
7740/7740 [==============================] - 1s 92us/sample - loss: 0.6597 - val_loss: 0.7037
Epoch 33/100
7740/7740 [==============================] - 1s 105us/sample - loss: 0.6547 - val_loss: 0.6987
Epoch 34/100
7740/7740 [==============================] - 1s 92us/sample - loss: 0.6503 - val_loss: 0.6942
Epoch 35/100
7740/7740 [==============================] - 1s 91us/sample - loss: 0.6463 - val_loss: 0.6900
Epoch 36/100
7740/7740 [==============================] - 1s 84us/sample - loss: 0.6425 - val_loss: 0.6860
Epoch 37/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.6391 - val_loss: 0.6823
Epoch 38/100
7740/7740 [==============================] - 1s 83us/sample - loss: 0.6359 - val_loss: 0.6788
Epoch 39/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.6329 - val_loss: 0.6756
Epoch 40/100
7740/7740 [==============================] - 1s 79us/sample - loss: 0.6301 - val_loss: 0.6725
Epoch 41/100
7740/7740 [==============================] - 1s 76us/sample - loss: 0.6274 - val_loss: 0.6696
Epoch 42/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.6248 - val_loss: 0.6667
Epoch 43/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6224 - val_loss: 0.6640
Epoch 44/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6201 - val_loss: 0.6614
Epoch 45/100
7740/7740 [==============================] - 1s 79us/sample - loss: 0.6178 - val_loss: 0.6588
Epoch 46/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.6157 - val_loss: 0.6563
Epoch 47/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.6136 - val_loss: 0.6540
Epoch 48/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6115 - val_loss: 0.6516
Epoch 49/100
7740/7740 [==============================] - 1s 76us/sample - loss: 0.6095 - val_loss: 0.6494
Epoch 50/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.6076 - val_loss: 0.6472
Epoch 51/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6057 - val_loss: 0.6450
Epoch 52/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.6038 - val_loss: 0.6429
Epoch 53/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6020 - val_loss: 0.6409
Epoch 54/100
7740/7740 [==============================] - 1s 82us/sample - loss: 0.6002 - val_loss: 0.6388
Epoch 55/100
7740/7740 [==============================] - 1s 87us/sample - loss: 0.5985 - val_loss: 0.6369
Epoch 56/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.5968 - val_loss: 0.6349
Epoch 57/100
7740/7740 [==============================] - 1s 76us/sample - loss: 0.5951 - val_loss: 0.6330
Epoch 58/100
7740/7740 [==============================] - 1s 76us/sample - loss: 0.5934 - val_loss: 0.6312
Epoch 59/100
7740/7740 [==============================] - 1s 75us/sample - loss: 0.5918 - val_loss: 0.6293
Epoch 60/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.5902 - val_loss: 0.6275
Epoch 61/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.5887 - val_loss: 0.6258
3870/1 [==================================================================] - 0s 40us/sample - loss: 0.5191

这是全部超参数搜索结束后,从sample出的10个超参组合中的最好的一组,它的选择就是按照它在训练集(11610样本)内切分出来的验证集(3870样本)上的val_loss,然后在全部训练集(11610)上进行训练。

Train on 11610 samples, validate on 3870 samples
Epoch 1/100
11610/11610 [==============================] - 1s 89us/sample - loss: 0.6970 - val_loss: 0.5471
Epoch 2/100
11610/11610 [==============================] - 1s 70us/sample - loss: 0.4846 - val_loss: 0.4733
Epoch 3/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.4354 - val_loss: 0.4388
Epoch 4/100
11610/11610 [==============================] - 1s 72us/sample - loss: 0.4078 - val_loss: 0.4094
Epoch 5/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3932 - val_loss: 0.4081
Epoch 6/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3837 - val_loss: 0.3907
Epoch 7/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3764 - val_loss: 0.3780
Epoch 8/100
11610/11610 [==============================] - 1s 70us/sample - loss: 0.3675 - val_loss: 0.3833
Epoch 9/100
11610/11610 [==============================] - 1s 74us/sample - loss: 0.3629 - val_loss: 0.3767
Epoch 10/100
11610/11610 [==============================] - 1s 73us/sample - loss: 0.3576 - val_loss: 0.3696
Epoch 11/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3520 - val_loss: 0.3596
Epoch 12/100
11610/11610 [==============================] - 1s 78us/sample - loss: 0.3478 - val_loss: 0.3549
Epoch 13/100
11610/11610 [==============================] - 1s 81us/sample - loss: 0.3439 - val_loss: 0.3583
Epoch 14/100
11610/11610 [==============================] - 1s 80us/sample - loss: 0.3411 - val_loss: 0.3544
Epoch 15/100
11610/11610 [==============================] - 1s 81us/sample - loss: 0.3374 - val_loss: 0.3472
Epoch 16/100
11610/11610 [==============================] - 1s 76us/sample - loss: 0.3335 - val_loss: 0.3424
Epoch 17/100
11610/11610 [==============================] - 1s 76us/sample - loss: 0.3318 - val_loss: 0.3452
Epoch 18/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3287 - val_loss: 0.3410
Epoch 19/100
11610/11610 [==============================] - 1s 69us/sample - loss: 0.3278 - val_loss: 0.3427
Epoch 20/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3249 - val_loss: 0.3404
原创文章 46 获赞 49 访问量 2182

猜你喜欢

转载自blog.csdn.net/qq_41660119/article/details/105831752