tf2 training

train

Here is a simple example of training a model with TensorFlow.

Model

Call the Keras MobileNetV2 API, remove its top layer, keep only the feature-extraction part, and then build a new top on it. Using pretrained weights reduces the training difficulty.
One more detail is that normalization is integrated directly into the network. The advantage is twofold: there is no need to configure it again on the art side, and no normalization is needed when preprocessing the dataset (if you forget to normalize, the model will not converge, and the pretrained weights also expect normalized inputs, so building it into the model is one less thing to troubleshoot when problems appear). Because normalization is already done inside the model, the classification parameters on the art side should be set to scale=1, offset=0.
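For reference, the Rescaling layer used below maps 8-bit pixel values from [0, 255] into [-1, 1], which is the input range the MobileNetV2 pretrained weights expect: 0 * (1/127.5) - 1 = -1 and 255 * (1/127.5) - 1 = 1.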

from tensorflow import keras

def Mobilenet_v2(input_size, weights, Dropout_rate, Trainable, alpha=0.35):
    # MobileNetV2 backbone without its classification head; with weights="imagenet"
    # the pretrained feature extractor is reused.
    base_model = keras.applications.MobileNetV2(
        input_shape=(input_size, input_size, 3),
        alpha=alpha,
        weights=weights,
        include_top=False
    )
    inputs = keras.Input(shape=(input_size, input_size, 3))

    # Normalization is built into the model: map [0, 255] pixel values to [-1, 1].
    scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
    x = scale_layer(inputs)
    # training=False keeps the backbone's BatchNorm layers in inference mode.
    x = base_model(x, training=False)
    if Trainable:
        base_model.trainable = True
    else:
        base_model.trainable = False
        print("Feature extraction layers are frozen")
    # New top: global pooling, dropout, and a 15-class softmax classifier.
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(Dropout_rate, name='Dropout')(x)
    outputs = keras.layers.Dense(15, activation='softmax')(x)
    model = keras.Model(inputs, outputs)
    return model
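
A quick usage sketch; the argument values here are illustrative, not prescribed by the post:

model = Mobilenet_v2(input_size=96, weights="imagenet", Dropout_rate=0.2, Trainable=False)
model.summary()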

Dataset

The dataset here uses the official dataset directly. There is no independent validation set; it is split directly from the training set.
Augmentation uses TensorFlow's built-in ImageDataGenerator, and only rotation, zoom, and a few other augmentations are applied here.
This level of augmentation is clearly not enough; augmentation methods will be discussed separately later.

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_root = "./train/"
input_size = 96   # example value; must match the model's input size
batch_size = 32   # example value

# Augmentation only; no rescaling here because normalization happens inside the model.
# validation_split reserves 20% of the training images as a validation subset.
train_generator = ImageDataGenerator(rotation_range=360,
                                     zoom_range=0.2,
                                     horizontal_flip=True,
                                     validation_split=0.2)
train_dataset = train_generator.flow_from_directory(batch_size=batch_size,
                                                    directory=train_root,
                                                    shuffle=True,
                                                    target_size=(input_size, input_size),
                                                    subset='training')
valid_dataset = train_generator.flow_from_directory(batch_size=batch_size,
                                                    directory=train_root,
                                                    shuffle=True,
                                                    target_size=(input_size, input_size),
                                                    subset='validation')
print(train_dataset.class_indices)

It should be noted that the dataset directory structure needs to look like the following, with one subdirectory per class; do not nest a second level of directories that splits the classes into major and minor categories:
[image: expected dataset directory structure]
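
A sketch of the layout that flow_from_directory expects; the class folder and file names here are placeholders (the real dataset has 15 class folders):

train/
    class_a/
        0001.jpg
        ...
    class_b/
        0001.jpg
        ...
    ...   (one folder per class)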

Callbacks

Here I set up 3 callbacks, of which the learning-rate reduction callback is the most important. The initial learning rate may no longer be appropriate in the later stage of training, which makes the loss oscillate. Reducing the learning rate avoids this and usually also gives a small improvement in model accuracy.
You can also use the LearningRateScheduler callback to adjust the learning rate dynamically, which is another common approach.

save_path = "./checkpoints"   # example checkpoint directory; adjust as needed
# Multiply the learning rate by 0.2 when val_loss has not improved for 10 epochs.
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=10, verbose=1)
# Stop training when val_accuracy has not improved for 10 epochs.
early_stop = keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10, verbose=1)
# Keep only the checkpoint with the best validation accuracy.
save_weights = keras.callbacks.ModelCheckpoint(save_path + "/model_{epoch:02d}_{val_accuracy:.4f}.h5",
                                               save_best_only=True, monitor='val_accuracy')
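
Using a model instance built with Mobilenet_v2 (as in the earlier sketch) together with the datasets and callbacks above, training can be wired together roughly as follows. This is a minimal sketch: the optimizer, learning rate, and epoch count are assumptions, not values from the post.

model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              loss='categorical_crossentropy',   # flow_from_directory yields one-hot labels by default
              metrics=['accuracy'])
model.fit(train_dataset,
          validation_data=valid_dataset,
          epochs=100,
          callbacks=[reduce_lr, early_stop, save_weights])

If you prefer the LearningRateScheduler approach mentioned earlier, a schedule function such as lambda epoch, lr: lr * 0.95 can be wrapped in keras.callbacks.LearningRateScheduler and added to the callbacks list.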

reference documents

When learning a deep learning framework, reading the official documentation is very useful. Here are two resources that are used often; in other words, these two are enough.
Keras API reference
TensorFlow also provides some introductory tutorials, and following along will give you a basic understanding of the framework.


Origin blog.csdn.net/qtzbxjg/article/details/128619227