LeNet 네트워크 소개

1. 배경

이 글에서는 CIFAR-10 이미지 데이터 세트를 사용하여 LeNet 네트워크를 훈련하고, 훈련된 모델로 예측하는 방법을 소개합니다.

2. CIFAR-10 이미지 데이터 세트 소개

CIFAR-10은 32×32 픽셀 크기의 3채널 컬러 이미지 60,000장으로 구성된 데이터 세트입니다. 이미지는 10개의 카테고리로 구분되며, 각 카테고리에는 6,000장의 이미지가 포함되어 있습니다. 그중 훈련 세트는 50,000장, 테스트 세트는 10,000장입니다.

데이터 로딩 및 전처리:

def load_and_proc_data():
    """Load CIFAR-10, scale pixels to [0, 1] and one-hot encode the labels.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``; images are float32 and
        labels are one-hot matrices with ``NB_CLASSES`` columns.
    """
    (train_imgs, train_lbls), (test_imgs, test_lbls) = cifar10.load_data()
    print('X_train shape', train_imgs.shape)  # e.g. (50000, 32, 32, 3)
    print(train_imgs.shape[0], 'train samples')
    print(test_imgs.shape[0], 'test samples')

    # Cast to float32 before dividing so the scaling is floating-point.
    train_imgs = train_imgs.astype('float32') / 255
    test_imgs = test_imgs.astype('float32') / 255

    # Turn integer class vectors into binary (one-hot) class matrices.
    train_lbls = np_utils.to_categorical(train_lbls, NB_CLASSES)
    test_lbls = np_utils.to_categorical(test_lbls, NB_CLASSES)
    return train_imgs, test_imgs, train_lbls, test_lbls

3. LeNet 네트워크 모델 정의

3.1 단일 레이어 컨볼루션 네트워크

from keras.models import Sequential
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Activation, Flatten, Dense, Dropout
from keras.datasets import cifar10
from keras.utils import np_utils
from keras.optimizers import RMSprop

class LeNet:
    """Shallow LeNet-style CNN: one conv/pool stage followed by a dense head."""

    @staticmethod
    def build(input_shape, classes):
        """Build the single-convolution-stage network.

        Args:
            input_shape: per-sample input shape, e.g. ``(32, 32, 3)``.
            classes: number of output classes (size of the softmax layer).

        Returns:
            An uncompiled ``Sequential`` model; its summary is printed.
        """
        stack = [
            # Feature extractor: 3x3 conv -> ReLU -> 2x2 max-pool -> dropout.
            Conv2D(32, kernel_size=3, padding='same', input_shape=input_shape),
            Activation('relu'),
            MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
            Dropout(0.25),
            # Classifier head: flatten -> dense(512) -> dropout -> softmax.
            Flatten(),
            Dense(512),
            Activation('relu'),
            Dropout(0.5),
            Dense(classes),
            Activation('softmax'),
        ]
        model = Sequential()
        for layer in stack:
            model.add(layer)
        model.summary()  # print an overview of the network
        return model

3.2 모델 구조 및 관련 매개변수

3.3 모델 깊이 증가 (다층 컨볼루션)

class LeNet:
    """Deeper LeNet-style CNN: two conv-conv-pool stages plus a dense head."""

    @staticmethod
    def build(input_shape, classes):
        """Build the two-stage convolutional network.

        Each stage is two 3x3 'same' convolutions with ReLU, then a 2x2
        max-pool and 25% dropout; stage widths are 32 and 64 filters.

        Args:
            input_shape: per-sample input shape, e.g. ``(32, 32, 3)``.
            classes: number of output classes (size of the softmax layer).

        Returns:
            An uncompiled ``Sequential`` model; its summary is printed.
        """
        model = Sequential()
        # Only the very first layer of the model declares the input shape.
        extra = {'input_shape': input_shape}
        for filters in (32, 64):
            for _ in range(2):
                model.add(Conv2D(filters, kernel_size=3, padding='same', **extra))
                extra = {}
                model.add(Activation('relu'))
            model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
            model.add(Dropout(0.25))

        # Classifier head: flatten -> dense(512) -> dropout -> softmax.
        model.add(Flatten())
        model.add(Dense(512))
        model.add(Activation('relu'))
        model.add(Dropout(0.5))
        model.add(Dense(classes))
        model.add(Activation('softmax'))

        model.summary()  # print an overview of the network
        return model

4. 모델 훈련 및 예측

def model_train(X_train, y_train):
    """Build, compile and fit the LeNet model on the training data.

    Args:
        X_train: training images (as produced by ``load_and_proc_data``:
            float32, scaled to [0, 1]).
        y_train: one-hot encoded training labels.

    Returns:
        The trained Keras model.
    """
    # Local import keeps this fix self-contained within the function.
    from keras.optimizers.schedules import InverseTimeDecay

    # Same schedule as the legacy ``RMSprop(lr=1e-4, decay=1e-6)``:
    #   lr_t = 1e-4 / (1 + 1e-6 * step)
    # The ``lr``/``decay`` keyword arguments are deprecated, and newer Keras
    # optimizers reject or silently ignore them; a schedule works everywhere.
    lr_schedule = InverseTimeDecay(initial_learning_rate=0.0001, decay_steps=1, decay_rate=1e-6)
    OPTIMIZER = RMSprop(learning_rate=lr_schedule)
    model = LeNet.build(input_shape=INPUT_SHAPE, classes=NB_CLASSES)
    model.compile(loss='categorical_crossentropy', optimizer=OPTIMIZER, metrics=['accuracy'])
    # validation_split holds out the last VALIDATION_SPLIT fraction of the
    # training samples as validation data (the test set is never seen here).
    history = model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=NB_EPOCH, verbose=1, validation_split=VALIDATION_SPLIT)
    # Uncomment to plot the learning curves:
    # plot_picture(history)
    return model

def model_evaluate(model, X_test, y_test):
    """Evaluate ``model`` on the test split and print loss and accuracy.

    Assumes the model was compiled with a single metric (accuracy), so
    ``evaluate`` returns ``[loss, accuracy]``.
    """
    results = model.evaluate(X_test, y_test, batch_size=BATCH_SIZE, verbose=1)
    test_loss, test_acc = results[0], results[1]
    print('Test score: ', test_loss)
    print('Test acc: ', test_acc)

5. 정확도 및 손실 곡선 그리기

import matplotlib.pyplot as plt

def plot_picture(history):
    """Plot training vs. validation accuracy and loss curves.

    Args:
        history: the object returned by ``model.fit``; its ``history`` dict
            maps metric names to per-epoch values.
    """
    print(history.history.keys())
    _plot_metric(history.history, ('accuracy', 'acc'), 'model acc', 'acc')
    _plot_metric(history.history, ('loss',), 'model loss', 'loss')


def _plot_metric(logs, names, title, ylabel):
    """Plot one metric's training curve and, when logged, its validation curve."""
    # Older Keras releases log accuracy as 'acc'/'val_acc' instead of
    # 'accuracy'/'val_accuracy'; use the first name present in the logs.
    name = next((n for n in names if n in logs), names[0])
    plt.plot(logs[name])
    legend = ['train']
    val_name = 'val_' + name
    if val_name in logs:  # absent when fit() ran without validation data
        plt.plot(logs[val_name])
        legend.append('validation')
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend(legend, loc='upper left')
    plt.show()

6. 모델 저장

def model_save(model, json_path='cifar10_architecture.json',
               weights_path='cifar10_weights.h5'):
    """Persist the model's architecture and weights to disk.

    Args:
        model: trained Keras model.
        json_path: destination for the architecture JSON from ``model.to_json()``.
        weights_path: destination for ``model.save_weights`` (``.h5``);
            an existing file is overwritten.
    """
    # Save the network structure.
    with open(json_path, 'w', encoding='utf-8') as f:
        f.write(model.to_json())
    # Save the network weights.
    model.save_weights(weights_path, overwrite=True)

7. 메인 함수

NB_EPOCH = 50
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.2  # fraction of training samples held out for validation
IMG_ROWS, IMG_COLS = 32, 32
IMG_CHANNELS = 3
# Channels-last order (rows, cols, channels), matching the CIFAR-10 arrays.
INPUT_SHAPE = (IMG_ROWS, IMG_COLS, IMG_CHANNELS)
NB_CLASSES = 10

if __name__ == '__main__':
    X_train, X_test, y_train, y_test = load_and_proc_data()
    model = model_train(X_train, y_train)
    # Write cifar10_architecture.json / cifar10_weights.h5 so the inference
    # script can load the trained model afterwards.
    model_save(model)
    model_evaluate(model, X_test, y_test)

모델 출력 결과:

테스트 점수(손실): 1.3542113304138184
테스트 정확도: 0.6733999848365784

8. 모델 로딩 및 온라인 추론

모델이 훈련된 후 모델 파일에서 모델을 로드하고 예측합니다.

import numpy as np
from keras.models import model_from_json
from keras.optimizers import SGD
from skimage.transform import resize
import imageio

def input_data_proc(img_names=('cat.png', 'dog.png')):
    """Load images from disk into a batch shaped like the training input.

    Each image is reduced to 3 (RGB) channels, resized to 32x32 and divided
    by 255 once, mirroring the ``/255`` scaling used for training data.
    Assumes 8-bit images (pixel values 0-255) — TODO confirm for other inputs.

    Args:
        img_names: paths of the images to load.

    Returns:
        float32 array of shape ``(len(img_names), 32, 32, 3)``.
    """
    img_list = []
    for img_name in img_names:
        img = np.asarray(imageio.imread(img_name))
        if img.ndim == 2:  # grayscale: replicate into 3 channels
            img = np.stack([img] * 3, axis=-1)
        img = img[..., :3]  # drop the alpha channel of RGBA PNGs
        # preserve_range=True keeps the 0-255 scale so the single /255 below
        # matches training; by default resize already rescales to [0, 1],
        # and dividing by 255 again would shrink the inputs 255x.
        img = resize(img, output_shape=(32, 32, 3), preserve_range=True).astype('float32')
        print('size: ', img.shape)
        img_list.append(img)
    return np.array(img_list, dtype='float32') / 255

def model_load(model_json, model_weight):
    """Rebuild a model from its saved JSON architecture and load its weights.

    Args:
        model_json: path to the architecture JSON written from ``model.to_json()``.
        model_weight: path to the weights file written by ``model.save_weights``.

    Returns:
        The reconstructed Keras model with weights loaded.
    """
    # Use a context manager so the JSON file handle is closed promptly.
    with open(model_json, encoding='utf-8') as f:
        model = model_from_json(f.read())
    model.load_weights(model_weight)
    return model

def model_predict(model, optim, img_list):
    """Predict class indices for a batch of preprocessed images.

    Args:
        model: Keras model, e.g. as returned by ``model_load``.
        optim: optimizer used to compile the model (``predict`` itself does
            not require compilation; kept for API compatibility).
        img_list: batch of images, presumably as produced by
            ``input_data_proc`` — (N, 32, 32, 3) scaled to [0, 1].

    Returns:
        Array of the predicted class index per image (also printed).
    """
    model.compile(loss='categorical_crossentropy', optimizer=optim, metrics=['accuracy'])
    probs = model.predict(img_list)
    preds = np.argmax(probs, axis=1)
    print(preds)
    return preds
if __name__ == '__main__':
    # Files written by model_save() after training.
    saved_paths = ('cifar10_architecture.json', 'cifar10_weights.h5')
    trained = model_load(*saved_paths)

    sgd = SGD()
    batch = input_data_proc()
    model_predict(trained, sgd, batch)

Supongo que te gusta

출처: blog.csdn.net/MusicDancing/article/details/130261584
Recomendado
Clasificación