CNN basics: monitoring and controlling training with Keras callbacks and TensorBoard

When training a model, many things cannot be predicted in advance. For example, to find out how many epochs give the best validation loss, we might first train for 100 epochs, plot the results once training finishes, discover that overfitting set in partway through, and then train all over again.

There are many similar situations, so we want to monitor training in real time and act promptly based on the state of the model. Keras callbacks and TensorFlow's TensorBoard exist for exactly this purpose.

Keras callbacks

A callback is an object passed to the model in the call to fit, and the model calls it at various points during training. It has access to all available data about the state and performance of the model, and it can take action: interrupt training, save the model, load a different set of weights, or otherwise change the state of the model. In other words, we cannot know the real-time state of the model before training starts, so to monitor and control training better we send in a correspondent, the callback, which can record, report back, or take action as the situation requires. The familiar training progress bar and the History object returned by fit are callbacks too, but the two are so common that they are handled separately instead of being passed in by hand.
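To make the idea of "being called at different points during training" concrete, here is a toy sketch in plain Python. It is not Keras code: ToyCallback, StopAtLoss, LossRecorder, toy_fit and the state dictionary are all hypothetical names invented for illustration, and the "training" is just a loss that halves each epoch. Only the hook names on_train_begin and on_epoch_end are borrowed from Keras.

```python
class ToyCallback:
    """Base class with no-op hooks, loosely mirroring Keras hook names."""

    def on_train_begin(self, state):
        pass

    def on_epoch_end(self, epoch, state):
        pass


class LossRecorder(ToyCallback):
    """Observes training: records the loss after every epoch."""

    def on_train_begin(self, state):
        self.losses = []

    def on_epoch_end(self, epoch, state):
        self.losses.append(state["loss"])


class StopAtLoss(ToyCallback):
    """Acts on training: asks the loop to stop once the loss is low enough."""

    def __init__(self, threshold):
        self.threshold = threshold

    def on_epoch_end(self, epoch, state):
        if state["loss"] <= self.threshold:
            state["stop_training"] = True


def toy_fit(epochs, callbacks):
    """Stand-in for a training loop that calls its callbacks at fixed points."""
    state = {"loss": 1.0, "stop_training": False}
    for cb in callbacks:
        cb.on_train_begin(state)
    for epoch in range(epochs):
        state["loss"] *= 0.5  # pretend one epoch of training halves the loss
        for cb in callbacks:
            cb.on_epoch_end(epoch, state)
        if state["stop_training"]:
            break
    return state


recorder = LossRecorder()
final_state = toy_fit(epochs=10, callbacks=[recorder, StopAtLoss(threshold=0.1)])
print(recorder.losses)  # losses seen before the stop request took effect
```

The loop itself knows nothing about recording or stopping; it only promises to call the hooks, which is what lets callbacks be mixed and matched freely.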

Both fit and fit_generator accept callbacks through the callbacks argument. Commonly used callbacks include:

  • ModelCheckpoint (saves the current model after each epoch);
  • EarlyStopping (interrupts training when the monitored metric stops improving);
  • LearningRateScheduler (adjusts the learning rate dynamically during training);
  • ReduceLROnPlateau (lowers the learning rate when validation performance stops improving, which can help escape a local minimum);
  • CSVLogger (writes the results of each epoch to a CSV file);
  • other callbacks; you can also write your own if necessary.
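As an illustration of the LearningRateScheduler entry in the list above, the callback takes a schedule function that maps the epoch index to a learning rate. A minimal sketch of such a function follows; the constants and the name step_decay are assumptions chosen for the example, not values from the original text.

```python
# Hypothetical settings for illustration only
INITIAL_LR = 0.01       # learning rate at epoch 0
DROP_FACTOR = 0.5       # multiply the learning rate by this at every drop
EPOCHS_PER_DROP = 10    # number of epochs between two drops


def step_decay(epoch):
    """Return the learning rate to use for a 0-based epoch index."""
    return INITIAL_LR * DROP_FACTOR ** (epoch // EPOCHS_PER_DROP)


# Preview the schedule at a few epochs without training anything
schedule_preview = [step_decay(e) for e in (0, 9, 10, 25)]
print(schedule_preview)
```

With Keras installed, a function like this could be wrapped as LearningRateScheduler(step_decay) and added to the callback list passed to fit.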

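Example usage:

from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# fit accepts a list of callbacks through its callbacks argument, so any number of callbacks can be passed to the model
callback_lists = []

callback_lists.append(EarlyStopping(monitor = 'val_acc', # monitor the model's validation accuracy ('val_accuracy' in some Keras versions)
                                    patience = 1)) # interrupt training once accuracy has gone `patience` epochs without improving

callback_lists.append(ModelCheckpoint(filepath = 'my_model.h5', # path of the file the model is saved to
                                      monitor = 'val_loss',  # monitor the validation loss
                                      save_best_only = True)) # keep only the best model

callback_lists.append(ReduceLROnPlateau(monitor = 'val_loss',  # monitor the model's validation loss
                                        factor = 0.1,  # multiply the learning rate by 0.1 when triggered
                                        patience = 10)) # triggered when the validation loss has not improved for 10 epochs

# The callbacks monitor validation loss and validation accuracy, so validation_data must be passed to fit
model.fit(x, y, epochs = 10, batch_size = 32,
          callbacks = callback_lists,
          validation_data = (x_val, y_val))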
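To see how factor and patience interact in ReduceLROnPlateau, the following plain-Python sketch replays the rule on a made-up list of validation losses. It is a simplified imitation, not the Keras implementation: the function name and the loss values are hypothetical, and details such as min_delta, cooldown and min_lr are left out.

```python
def simulate_reduce_on_plateau(val_losses, lr, factor=0.1, patience=10):
    """Return the learning rate in effect after each epoch."""
    best = float("inf")  # best validation loss seen so far
    wait = 0             # consecutive epochs without improvement
    lr_history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor  # plateau detected: shrink the learning rate
                wait = 0      # start counting again from scratch
        lr_history.append(lr)
    return lr_history


# Made-up validation losses with two plateaus; patience shortened to 2 for brevity
lr_history = simulate_reduce_on_plateau(
    [1.0, 0.8, 0.8, 0.9, 0.7, 0.7, 0.7], lr=0.01, factor=0.1, patience=2)
print(lr_history)
```

In this run the learning rate is cut twice, once per plateau. With patience = 10 and only 10 epochs, as in the example above, the rule would have little chance to trigger, so in practice patience is usually chosen to be well below the total number of epochs.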

TensorBoard: real-time visualization tool

TensorBoard is a browser-based visualization tool built into TensorFlow; it is installed automatically when you install TensorFlow. Put simply, data from the training process is written to log files, and the tool then displays it in a browser. In Keras, it is also wrapped as a callback.

An example:

# Import TensorBoard
from keras.callbacks import TensorBoard

# Define the callback list; for now it holds just a simple TensorBoard callback
log_path = './logs' # directory for the log files that TensorBoard reads; you can create a new one
callback_lists = [TensorBoard(log_dir=log_path, histogram_freq=1)]

# Pass the list in through the callbacks argument when the model calls fit
model.fit(...inputs and parameters..., callbacks=callback_lists)

To visualize the metrics during training, you need to start TensorBoard from a terminal.

There are two terminals to choose from: the system's built-in cmd, or Anaconda Prompt. Which one to use depends on the terminal that was used when installing TensorFlow. I tried cmd and it always failed, but TensorBoard started properly from Anaconda Prompt.

To start it, enter tensorboard --logdir=C:\Users...\logs (use the path to your own log directory) in the terminal. It prints a line of information that includes an http URL. This address generally does not change; enter it in a browser to view the training process and the state of the model.

Reference

Book: Deep Learning with Python

Origin www.cnblogs.com/inchbyinch/p/11986629.html