用 TensorFlow(TFLearn)搭建自己的残差网络(ResNet)

废话不说,直接上代码:

首先安装 TFLearn:

pip install tflearn

训练代码

# -*- coding: utf-8 -*-  

from __future__ import division, print_function, absolute_import  

import tflearn  

# Residual blocks
# 32 layers: n=5, 56 layers: n=9, 110 layers: n=18
n = 5

# Number of classes in your classification task.
numClass = 10


def load_my_data():
    """Load your own dataset (must be implemented by the user).

    Must return ``((X, Y), (validationX, validationY))``:

    * ``X`` / ``validationX``: image arrays matching the network input
      shape ``[None, 256, 256, 3]`` declared below.
    * ``Y`` / ``validationY``: integer class labels in ``[0, numClass)``;
      they are one-hot encoded below with ``to_categorical``.
    """
    raise NotImplementedError(
        "Implement loading of (X, Y), (validationX, validationY) here")


(X, Y), (validationX, validationY) = load_my_data()
# Both the training and the validation labels must be one-hot encoded:
# the softmax output and 'categorical_crossentropy' expect targets of
# shape (num_samples, numClass).
Y = tflearn.data_utils.to_categorical(Y, numClass)
validationY = tflearn.data_utils.to_categorical(validationY, numClass)

# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([256, 256], padding=4)

# Building Residual Network
net = tflearn.input_data(shape=[None, 256, 256, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = tflearn.residual_block(net, n, 16)
net = tflearn.residual_block(net, 1, 32, downsample=True)
net = tflearn.residual_block(net, n-1, 32)
net = tflearn.residual_block(net, 1, 64, downsample=True)
net = tflearn.residual_block(net, n-1, 64)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, numClass, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
net = tflearn.regression(net, optimizer=mom,
                         loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_resnet_mymodel',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)

model.fit(X, Y, n_epoch=200, validation_set=(validationX, validationY),
          snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=128, shuffle=True,
          run_id='resnet_mymodel')

其中 (X, Y) 和 (validationX, validationY) 分别是训练集和验证集的图片数据与标签,加载代码需要自己实现:图片尺寸需为 256×256×3(与 input_data 的 shape 一致),标签为整数类别编号,代码中会用 to_categorical 转为 one-hot。numClass 表示分类的类别数。

测试代码

from __future__ import division, print_function, absolute_import  

import tflearn  
import numpy as np
from PIL import Image 
import os


def buildModel(model_path=None, numClass=10, n=5):
    """Rebuild the training-time ResNet graph and restore saved weights.

    The layer stack below must be identical to the one used for training,
    otherwise ``model.load`` cannot map the saved variables onto the graph.

    Args:
        model_path: path of the model saved during training. Required.
        numClass: number of output classes; must equal the training value.
        n: residual-block depth parameter
            (32 layers: n=5, 56 layers: n=9, 110 layers: n=18);
            must equal the training value.

    Returns:
        A ``tflearn.DNN`` with weights loaded from ``model_path``.

    Raises:
        ValueError: if ``model_path`` is not given.
    """
    if model_path is None:
        raise ValueError("model_path is required: pass the path of your saved model")

    # Real-time data preprocessing (part of the saved graph, so it must be
    # rebuilt exactly as in training).
    img_prep = tflearn.ImagePreprocessing()
    img_prep.add_featurewise_zero_center(per_channel=True)

    # Real-time data augmentation, kept so the graph mirrors the training
    # script.
    img_aug = tflearn.ImageAugmentation()
    img_aug.add_random_flip_leftright()
    img_aug.add_random_crop([256, 256], padding=4)

    # Building Residual Network
    net = tflearn.input_data(shape=[None, 256, 256, 3],
                             data_preprocessing=img_prep,
                             data_augmentation=img_aug)
    net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
    net = tflearn.residual_block(net, n, 16)
    net = tflearn.residual_block(net, 1, 32, downsample=True)
    net = tflearn.residual_block(net, n-1, 32)
    net = tflearn.residual_block(net, 1, 64, downsample=True)
    net = tflearn.residual_block(net, n-1, 64)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, 'relu')
    net = tflearn.global_avg_pool(net)
    # Regression
    net = tflearn.fully_connected(net, numClass, activation='softmax')
    mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
    net = tflearn.regression(net, optimizer=mom,
                             loss='categorical_crossentropy')

    model = tflearn.DNN(net, tensorboard_verbose=0, clip_gradients=0.)
    model.load(model_file=model_path, weights_only=False)
    return model
def predicMsk(picPath, model):
    """Run ``model`` on a single image file and return its prediction.

    Args:
        picPath: path to an image file readable by PIL.
        model: a trained ``tflearn.DNN`` whose input shape is
            ``[None, 256, 256, 3]`` (e.g. the one returned by ``buildModel``).

    Returns:
        The result of ``model.predict`` for a batch containing this one
        image. For the softmax network in this file, that is one row of
        ``numClass`` class scores.
    """
    # Close the file handle deterministically; convert to RGB so grayscale,
    # palette or RGBA images still yield the 3 channels the network expects.
    with Image.open(picPath) as raw:
        image = raw.convert("RGB").resize((256, 256))

    # NOTE(review): scaling to [0, 1] only matches training if the training
    # images X were scaled the same way -- confirm against your data loader.
    pixels = np.asarray(image, dtype=np.float32) / 255.0

    # Add the leading batch dimension: shape becomes (1, 256, 256, 3).
    batch = pixels[np.newaxis, ...]
    return model.predict(batch)

这样就可以用残差网络处理自己的数据集了。

猜你喜欢

转载自blog.csdn.net/qq_34484472/article/details/77848091