トレーニングの詳細な処理独自画像分類器(Xception、cifar10)

肺炎2日間の特に急速な普及、および作らパニック。放課後の人間はほとんど消えています

以下kerasデフォルトtf.keras

from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt

データ処理の設定

オブジェクトの分類データセット: cifar10

参照:(x_train、y_train)、(x_test、android.permission.FACTOR。)= Keras.datasets.cifar10.load_data()

ダウンロードにダウンロードされていない、IDEのダウンロードが遅いのダウンロードサンダーは、リンクをコピーして、ユーザー名/.keras内部にコピーすることができます。

データの一部を印刷:

# 展示一些图片
def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols * row + col
            plt.subplot(n_rows, n_cols, index + 1)
            plt.imshow(x_data[index],
                       interpolation='nearest')  # 缩放图片时的方法
            plt.axis('off')
            plt.title(class_names[int(y_data[index])])
    plt.show()
class_names = ['plane', 'car' ,' bird ', 'cat','deer', 'dog', 'frog',
               'horse', 'boat', 'truck']
show_imgs(5, 8, x_train, y_train, class_names)

ここに画像を挿入説明

特徴ネットワーク入力チャンネル3、少なくとも71 * 71の大きさの抽出が、唯一それを再構築するために、* 28 28 cifarデータセットを見逃します。

x_train=tf.image.resize(x_train,[71,71])
x_test=tf.image.resize(x_test,[71,71])

ネットワークアーキテクチャ

公式文書ます。https://keras.io/zh/applications/
ここに画像を挿入説明

Xecptionネットワークのロード:

xception=keras.applications.xception.Xception(
					include_top=False, weights='imagenet',
                   	input_tensor=None, input_shape=[71,71,3],
                  	pooling='avg')

そのようなXception層133は、最終的な出力層である:一次元テンソル長さ2048

私たちは訓練の後ろの層、層の前に凍結し、その後、事前研修、リコールNGの重みimagenetをロード:

    freeze_layer = 100
    for layer in xception.layers[:freeze_layer]:
        layer.trainable = False

独自のネットワークを構築します。

    model = keras.Sequential([
        xception,
        keras.layers.Dropout(0.5),
        keras.layers.Dense(1024, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(10, activation='softmax')
    ])

あなたは、モデルの詳細情報を印刷することができます。

def print_model_info(model):
    for layer in model.layers:
        print('layer name:',layer.name,'output:',layer.output_shape,'trainable:',layer.trainable)
        if layer.name=='xception':
            for ll in layer.layers:
                print('layer name:',ll.name,'output:',ll.output_shape,'trainable:',ll.trainable)
        else:
            for weight in layer.weights:
                print('--weight name:',weight.name,'shape:',weight.shape)

conpile:

model.compile(optimizer=keras.optimizers.Adam(lr=0.0001),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

トレーニング方法

データセット全体がそのエポックを発見した場合は、数分を実行します。そして、進行できない場合があり、例えば、テスト・セット、小さな学習率、大きな学習率の精度を増加させません。

したがって、複数の部分に設定されたデータは、各トレーニングの後にローカルに保存されました。

n = x_train.shape[0]
m = x_test.shape[0]
n = (int)(n / 10)
m = (int)(m / 10)

# 分批训练并及时保存,以实时调整学习速率等其他超参数
for epoch in range(10):
    for i in range(10):
        # 训练
        ns = i * n
        ms = i * m
        history = model.fit(x_train[ns:ns + n], y_train[ns:ns + n],
                      batch_size=64, epochs=1, 
                      validation_data=(x_test[ms:ms + m], y_test[ms:ms + m]))
        model.save('my_model.h5')

loss, acc = model.evaluate(x_test, y_test)
print("model loss:{:.2}, acc:{:.2%}".format(loss, acc))

訓練の後、以前の結果にトレーニングを続けることができます

    import os
    if os.path.exists('my_model.h5'):
        return tf.keras.models.load_model('my_model.h5')

ネットワーク構造を調整します

こうした特定の構造のテストセットの精度は思っその後、上がることができない検索します。

サイズはcifarセット小型、抽出した特徴のフロントXception層は、このトレーニングセットには適用されませんので、写真とimagenet 71 * 71を取得する画像は、はるかに悪化する必要がありリサイズ。

ネットワークアーキテクチャの裏層は、組織特性の前の層のための方法が理解されるであろうが、そうここに列車の前に、凍結層100の背後を試みます。

    freeze_layer = 100
    for layer in xception.layers[-freeze_layer:]:
        layer.trainable = False

完全なコード

from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt


# 展示学习曲线
def show_learning_curves(history):
    import pandas as pd
    pd.DataFrame(history.history).plot(figsize=(8, 5))
    plt.grid(True)
    plt.gca().set_ylim(0, 1)
    plt.show()


# 获取模型
def getModel():
    import os
    if os.path.exists('my_model.h5'):
        return tf.keras.models.load_model('my_model.h5')
    # 特征提取网络
    xception = keras.applications.xception.Xception(include_top=False, weights='imagenet',
                                                    input_tensor=None, input_shape=[71, 71, 3],
                                                    pooling='avg')

    # xception.trainable=False
    freeze_layer = 100
    for layer in xception.layers[-freeze_layer:]:
        layer.trainable = False

    # 模型创建和训练
    model = keras.Sequential([
        xception,
        keras.layers.Dropout(0.5),
        keras.layers.Dense(1024, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(10, activation='softmax')
    ])
    return model


def print_model_info(model):
    for layer in model.layers:
        print('layer name:', layer.name, 'output:', layer.output_shape, 'trainable:', layer.trainable)
        if layer.name == 'xception':
            for ll in layer.layers:
                print('layer name:', ll.name, 'output:', ll.output_shape, 'trainable:', ll.trainable)
        else:
            for weight in layer.weights:
                print('--weight name:', weight.name, 'shape:', weight.shape)


model = getModel()
# print_model_info(model)

model.compile(optimizer=keras.optimizers.Adam(lr=0.0005),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 获取数据
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print('iamge ori size:',x_train.shape, y_train.shape, x_test.shape, y_test.shape)


# 展示一些图片
def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index = n_cols * row + col
            plt.subplot(n_rows, n_cols, index + 1)
            plt.imshow(x_data[index],
                       interpolation='nearest')  # 缩放图片时的方法
            plt.axis('off')
            plt.title(class_names[int(y_data[index])])
    plt.show()


class_names = ['plane', 'car', ' bird ', 'cat', 'deer', 'dog', 'frog',
               'horse', 'boat', 'truck']
# show_imgs(5, 8, x_train, y_train, class_names)

x_train = tf.image.resize(x_train, [71, 71])
x_test = tf.image.resize(x_test, [71, 71])
print('image processed',x_train.shape, y_train.shape, x_test.shape, y_test.shape)

n = x_train.shape[0]
m = x_test.shape[0]
n = (int)(n / 10)
m = (int)(m / 10)

# 分批训练并及时保存,以实时调整学习速率等其他超参数
for epoch in range(10):
    for i in range(10):
        # 训练
        ns = i * n
        ms = i * m
        history = model.fit(x_train[ns:ns + n], y_train[ns:ns + n],
                            batch_size=64, epochs=1, validation_data=(x_test[ms:ms + m], y_test[ms:ms + m]))
        model.save('my_model.h5')
        # show_learning_curves(history)

loss, acc = model.evaluate(x_test, y_test)
print("model loss:{:.2}, acc:{:.2%}".format(loss, acc))

公開された723元の記事 ウォンの賞賛314 ビュー160 000 +

おすすめ

転載: blog.csdn.net/jk_chen_acmer/article/details/104099357