A detailed walkthrough of training your own image classifier (Xception, CIFAR-10)

The pneumonia outbreak has spread especially fast over the last two days, and it is causing panic... Will there be hardly anyone left by the time school starts?

In what follows, keras means tf.keras by default.

from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt

Dataset processing

Object classification dataset: CIFAR-10

Load it with:

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

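If the dataset has not been downloaded yet, this call downloads it. If the download is slow from inside the IDE, copy the link, download the file with Thunder (Xunlei), and then copy it into the .keras folder under your user directory.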
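As a rough sketch of where the manually downloaded file goes: tf.keras keeps its dataset cache in a datasets folder inside ~/.keras. The archive name cifar-10-batches-py.tar.gz used below is an assumption based on TensorFlow 2.x and may differ in other versions, so the downloaded file may need renaming. The helper name expected_cifar10_archive is made up for illustration.

```python
import os

def expected_cifar10_archive(home=None):
    """Path where tf.keras (TF 2.x, assumed) looks for the CIFAR-10 archive."""
    home = home or os.path.expanduser('~')
    return os.path.join(home, '.keras', 'datasets', 'cifar-10-batches-py.tar.gz')

path = expected_cifar10_archive()
status = 'found' if os.path.exists(path) else 'missing, copy the downloaded archive here'
print(path, '(' + status + ')')
```

If the file is reported as found, load_data() should use it instead of downloading again.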

Display part of the data:

# Show some images in a grid
def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols * row + col
            plt.subplot(n_rows, n_cols, index + 1)
            plt.imshow(x_data[index],
                       interpolation='nearest')  # interpolation used when the image is scaled
            plt.axis('off')
            plt.title(class_names[int(y_data[index])])
    plt.show()
class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog',
               'horse', 'ship', 'truck']
show_imgs(5, 8, x_train, y_train, class_names)


The feature-extraction network needs 3 input channels and an input size of at least 71 * 71, but CIFAR-10 images are only 32 * 32, so resize them:

x_train=tf.image.resize(x_train,[71,71])
x_test=tf.image.resize(x_test,[71,71])
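One thing to keep in mind: with its default bilinear method, tf.image.resize returns float32 values, so resizing the whole dataset at once takes a fair amount of memory. The arithmetic below is only a back-of-the-envelope estimate, assuming the standard CIFAR-10 split of 50,000 training and 10,000 test images; resized_bytes is an illustrative helper, not part of any library.

```python
def resized_bytes(num_images, height=71, width=71, channels=3, bytes_per_value=4):
    """Bytes needed to hold num_images resized images as float32 (4 bytes per value)."""
    return num_images * height * width * channels * bytes_per_value

train_bytes = resized_bytes(50000)  # CIFAR-10 training set
test_bytes = resized_bytes(10000)   # CIFAR-10 test set
print('train: {:.2f} GiB'.format(train_bytes / 2**30))
print('test:  {:.2f} GiB'.format(test_bytes / 2**30))
```

If that is too much for your machine, resizing one part of the data at a time is a possible alternative.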

Network architecture

Official documentation: https://keras.io/zh/applications/

Load the Xception network:

xception = keras.applications.xception.Xception(
    include_top=False, weights='imagenet',
    input_tensor=None, input_shape=[71, 71, 3],
    pooling='avg')

Loaded this way, Xception has 133 layers, and its final output for each image is a one-dimensional tensor of length 2048.
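Since the weights come from ImageNet pre-training, input scaling may matter. The code in this post feeds raw pixel values in the range 0 to 255; Keras also ships keras.applications.xception.preprocess_input, which rescales pixels to the range -1 to 1. Whether adding that step helps here was not tested in this post. The sketch below only reproduces the rescaling arithmetic in plain Python; scale_like_xception is an illustrative name.

```python
def scale_like_xception(pixel):
    """Map a pixel value from [0, 255] to [-1, 1] using x / 127.5 - 1."""
    return pixel / 127.5 - 1.0

for value in (0, 127.5, 255):
    print(value, '->', scale_like_xception(value))
```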

We loaded weights pre-trained on ImageNet. Recalling what Andrew Ng taught, freeze the front layers and train the layers behind them:

    freeze_layer = 100
    for layer in xception.layers[:freeze_layer]:
        layer.trainable = False

Build your own network:

    model = keras.Sequential([
        xception,
        keras.layers.Dropout(0.5),
        keras.layers.Dense(1024, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(10, activation='softmax')
    ])
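To get a feel for the size of this head, the parameters of the three Dense layers can be counted by hand. This is a minimal sketch that assumes the 2048-length Xception output mentioned earlier; Dropout layers add no parameters. dense_params is an illustrative helper.

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a fully connected layer: weights plus biases."""
    return n_in * n_out + n_out

# Layer widths of the head: 2048 -> 1024 -> 512 -> 10
widths = [2048, 1024, 512, 10]
counts = [dense_params(a, b) for a, b in zip(widths, widths[1:])]
print(counts, sum(counts))
```

Most of the head's parameters sit in the first Dense layer, which is one reason the heavy Dropout in front of it seems reasonable.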

You can print detailed information about the model:

def print_model_info(model):
    for layer in model.layers:
        print('layer name:',layer.name,'output:',layer.output_shape,'trainable:',layer.trainable)
        if layer.name=='xception':
            for ll in layer.layers:
                print('layer name:',ll.name,'output:',ll.output_shape,'trainable:',ll.trainable)
        else:
            for weight in layer.weights:
                print('--weight name:',weight.name,'shape:',weight.shape)

Compile:

model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
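The sparse variant of the loss is used because the CIFAR-10 labels are integer class indices rather than one-hot vectors. For a single sample, the loss is the negative log of the probability the model assigns to the true class. The toy example below uses three classes and made-up probabilities purely for illustration.

```python
import math

def sparse_categorical_crossentropy(label, probs):
    """Loss for one sample: negative log of the probability given to the true class."""
    return -math.log(probs[label])

probs = [0.1, 0.7, 0.2]  # made-up softmax output for a 3-class toy problem
print(sparse_categorical_crossentropy(1, probs))  # true class got high probability
print(sparse_categorical_crossentropy(2, probs))  # true class got low probability
```

The second value is larger, since the model put less probability on the true class.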

Training process

Training on the entire dataset, a single epoch takes several minutes, and it may bring no progress at all: for example, the test-set accuracy may not improve because the learning rate is too small or too large.

Therefore, split the dataset into several parts and save the model locally after training on each part.

n = x_train.shape[0] // 10  # size of one training part
m = x_test.shape[0] // 10   # size of one validation part

# Train part by part and save right away, so the learning rate and other hyperparameters can be adjusted along the way
for epoch in range(10):
    for i in range(10):
        # Train on one part
        ns = i * n
        ms = i * m
        history = model.fit(x_train[ns:ns + n], y_train[ns:ns + n],
                      batch_size=64, epochs=1, 
                      validation_data=(x_test[ms:ms + m], y_test[ms:ms + m]))
        model.save('my_model.h5')
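The slicing in this loop can be checked without running any training. The sketch below assumes the standard CIFAR-10 sizes (50,000 training and 10,000 test images); chunk_ranges is an illustrative helper that mirrors the ns and ms arithmetic.

```python
def chunk_ranges(total, parts=10):
    """Return (start, end) index pairs for equal slices, like ns:ns + n in the loop."""
    size = total // parts
    return [(i * size, i * size + size) for i in range(parts)]

train_ranges = chunk_ranges(50000)
test_ranges = chunk_ranges(10000)
print(train_ranges[0], train_ranges[-1])
print(test_ranges[0], test_ranges[-1])
```

Each call to fit therefore sees 5,000 training images and validates on the matching 1,000 test images. If the total were not divisible by the number of parts, the leftover samples at the end would simply be skipped.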

Finally, evaluate on the whole test set:

loss, acc = model.evaluate(x_test, y_test)
print("model loss:{:.2}, acc:{:.2%}".format(loss, acc))
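For reference, here is what the format string in that print call produces. The numbers are made up and are not results from this post.

```python
loss, acc = 0.8321, 0.7154  # made-up values, not results from this post
line = "model loss:{:.2}, acc:{:.2%}".format(loss, acc)
print(line)
```

The {:.2} part keeps two significant digits of the loss, and {:.2%} turns the accuracy into a percentage with two decimals.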

Afterwards, training can resume from the previously saved result:

    import os
    # Reuse the saved model if one exists
    if os.path.exists('my_model.h5'):
        model = tf.keras.models.load_model('my_model.h5')
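The same load-if-present idea can be written as a small helper. This is only a sketch: load_or_build is a hypothetical name, and stand-in callables are used so that it runs without TensorFlow. In the real script, load_fn would be tf.keras.models.load_model and build_fn would be a function that builds the Xception-based model.

```python
import os

def load_or_build(path, load_fn, build_fn):
    """Return load_fn(path) when a saved file exists, otherwise build_fn()."""
    if os.path.exists(path):
        return load_fn(path)
    return build_fn()

# Stand-in callables so the sketch runs without TensorFlow installed.
model = load_or_build('no_such_file.h5',
                      load_fn=lambda p: 'loaded from ' + p,
                      build_fn=lambda: 'freshly built')
print(model)
```

Note that calling compile again after loading creates a new optimizer, so the optimizer state saved in the file is not reused.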

Adjust the network structure

With the structure above, the test-set accuracy would not go up, so I reasoned as follows:

CIFAR images are small, so the 71 * 71 pictures obtained by resizing them probably look quite different from ImageNet pictures. The features extracted by the front layers of Xception may therefore not apply to this training set.

The back layers of the network, on the other hand, can be understood as a way of organizing the features produced by the front layers. So here I try freezing the last 100 layers and training the ones in front:

    freeze_layer = 100
    for layer in xception.layers[-freeze_layer:]:
        layer.trainable = False
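The two freezing strategies differ only in the slice. The sketch below uses a tiny stand-in class instead of real Keras layers, and the 133-layer count quoted earlier, to show which layer indices stay trainable in each case. The names Layer and trainable_indices are made up for this illustration.

```python
class Layer:
    """Tiny stand-in for a Keras layer: only the trainable flag matters here."""
    def __init__(self):
        self.trainable = True

def trainable_indices(num_layers, freeze_layer, freeze_front):
    """Freeze the first or last freeze_layer layers, return indices still trainable."""
    layers = [Layer() for _ in range(num_layers)]
    frozen = layers[:freeze_layer] if freeze_front else layers[-freeze_layer:]
    for layer in frozen:
        layer.trainable = False
    return [i for i, layer in enumerate(layers) if layer.trainable]

front_frozen = trainable_indices(133, 100, freeze_front=True)
back_frozen = trainable_indices(133, 100, freeze_front=False)
print(front_frozen[0], front_frozen[-1], len(front_frozen))
print(back_frozen[0], back_frozen[-1], len(back_frozen))
```

Either way 33 layers remain trainable; what changes is whether they sit at the end of the network or at the start.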

The complete code

from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt


# Plot the learning curves
def show_learning_curves(history):
    import pandas as pd
    pd.DataFrame(history.history).plot(figsize=(8, 5))
    plt.grid(True)
    plt.gca().set_ylim(0, 1)
    plt.show()


# Get the model: load the saved one if it exists, otherwise build a new one
def getModel():
    import os
    if os.path.exists('my_model.h5'):
        return tf.keras.models.load_model('my_model.h5')
    # Feature-extraction network
    xception = keras.applications.xception.Xception(include_top=False, weights='imagenet',
                                                    input_tensor=None, input_shape=[71, 71, 3],
                                                    pooling='avg')

    # xception.trainable=False
    freeze_layer = 100
    for layer in xception.layers[-freeze_layer:]:
        layer.trainable = False

    # Build the model
    model = keras.Sequential([
        xception,
        keras.layers.Dropout(0.5),
        keras.layers.Dense(1024, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(10, activation='softmax')
    ])
    return model


def print_model_info(model):
    for layer in model.layers:
        print('layer name:', layer.name, 'output:', layer.output_shape, 'trainable:', layer.trainable)
        if layer.name == 'xception':
            for ll in layer.layers:
                print('layer name:', ll.name, 'output:', ll.output_shape, 'trainable:', ll.trainable)
        else:
            for weight in layer.weights:
                print('--weight name:', weight.name, 'shape:', weight.shape)


model = getModel()
# print_model_info(model)

model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0005),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Load the data
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print('image original size:', x_train.shape, y_train.shape, x_test.shape, y_test.shape)


# Show some images in a grid
def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols * row + col
            plt.subplot(n_rows, n_cols, index + 1)
            plt.imshow(x_data[index],
                       interpolation='nearest')  # interpolation used when the image is scaled
            plt.axis('off')
            plt.title(class_names[int(y_data[index])])
    plt.show()


class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog',
               'horse', 'ship', 'truck']
# show_imgs(5, 8, x_train, y_train, class_names)

x_train = tf.image.resize(x_train, [71, 71])
x_test = tf.image.resize(x_test, [71, 71])
print('image processed',x_train.shape, y_train.shape, x_test.shape, y_test.shape)

n = x_train.shape[0] // 10  # size of one training part
m = x_test.shape[0] // 10   # size of one validation part

# Train part by part and save right away, so the learning rate and other hyperparameters can be adjusted along the way
for epoch in range(10):
    for i in range(10):
        # Train on one part
        ns = i * n
        ms = i * m
        history = model.fit(x_train[ns:ns + n], y_train[ns:ns + n],
                            batch_size=64, epochs=1, validation_data=(x_test[ms:ms + m], y_test[ms:ms + m]))
        model.save('my_model.h5')
        # show_learning_curves(history)

loss, acc = model.evaluate(x_test, y_test)
print("model loss:{:.2}, acc:{:.2%}".format(loss, acc))


Origin blog.csdn.net/jk_chen_acmer/article/details/104099357