利用TensorFlow训练自己的图片数据——2、网络搭建(AlexNet)

得到数据之后，接下来就是搭建网络。这里我把模型单独定义在一个文件（DR_alexnet.py）中，方便后期修改网络结构。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

"""
Spyder Editor
This is a temporary script file.
filename:DR_alexnet.py
creat time:2018年1月16日
author:huangxudong
"""
import tensorflow as tf
import numpy as np
#define different layer function
def maxPoolLayer(x,kHeight,kWidth,strideX,strideY,name,padding="SAME"):
    """Max-pooling layer.

    Args:
        x: 4-D input tensor; the [1, kHeight, kWidth, 1] window layout used
            here assumes NHWC (batch, height, width, channels), the same
            layout convLayer uses.
        kHeight: pooling window height.
        kWidth: pooling window width.
        strideX: stride along the width axis.
        strideY: stride along the height axis.
        name: name of the pooling op.
        padding: "SAME" or "VALID".

    Returns:
        The pooled tensor.
    """
    # NHWC strides are ordered [batch, height, width, channels], so the
    # height slot takes strideY and the width slot takes strideX -- the same
    # order convLayer uses for its convolution strides.
    return tf.nn.max_pool(x, ksize=[1, kHeight, kWidth, 1],
                          strides=[1, strideY, strideX, 1],
                          padding=padding, name=name)

def dropout(x,keepPro,name=None):
    """Apply dropout to ``x``.

    Args:
        x: input tensor.
        keepPro: probability of keeping each element (TF1 ``keep_prob``).
        name: optional name for the dropout op.

    Returns:
        The tensor returned by tf.nn.dropout.
    """
    # ``name`` must be passed by keyword: the third positional parameter of
    # tf.nn.dropout is ``noise_shape``, so passing it positionally would make
    # any name string be interpreted as a noise shape.
    return tf.nn.dropout(x, keepPro, name=name)

def LRN(x,R,alpha,beta,name=None,bias=1.0):
    """Local response normalization across channels.

    Thin wrapper over tf.nn.local_response_normalization.

    Args:
        x: input tensor.
        R: depth radius of the normalization window.
        alpha: scale factor.
        beta: exponent.
        name: optional op name.
        bias: additive offset, 1.0 by default.

    Returns:
        The normalized tensor.
    """
    lrnArgs = dict(depth_radius=R, alpha=alpha, beta=beta,
                   bias=bias, name=name)
    return tf.nn.local_response_normalization(x, **lrnArgs)
def fcLayer(x,inputD,outputD,reluFlag,name):
    """Fully connected layer: x @ w + b, optionally followed by ReLU.

    Creates the variables ``w`` (shape [inputD, outputD]) and ``b``
    (shape [outputD]) inside the variable scope ``name``.

    Args:
        x: 2-D input tensor with inputD columns.
        inputD: input feature size.
        outputD: output feature size.
        reluFlag: if true, apply ReLU to the affine output.
        name: variable scope name; also used as the name of the affine op.

    Returns:
        The affine output, passed through ReLU when reluFlag is true.
    """
    with tf.variable_scope(name) as scope:
        weights = tf.get_variable("w", shape=[inputD, outputD])
        biases = tf.get_variable("b", [outputD])
        preAct = tf.nn.xw_plus_b(x, weights, biases, name=scope.name)
        return tf.nn.relu(preAct) if reluFlag else preAct
def convLayer(x,kHeight,kWidth,strideX,strideY,featureNum,name,padding="SAME",groups=1):
    """Grouped 2-D convolution followed by bias add and ReLU.

    The input and the filter bank are split into ``groups`` equal parts
    along axis 3. Each input part is convolved with its matching filter part,
    and the results are concatenated back along axis 3. With ``groups=1``
    this is an ordinary convolution.

    Args:
        x: 4-D input tensor; the stride layout [1, strideY, strideX, 1] and
            the splits on axis 3 assume NHWC (batch, height, width, channels).
        kHeight: filter height.
        kWidth: filter width.
        strideX: stride along the width axis.
        strideY: stride along the height axis.
        featureNum: number of output channels.
        name: variable scope name; the variables are created as ``w`` and
            ``b`` and the ReLU output op takes the scope's name.
        padding: "SAME" or "VALID".
        groups: number of channel groups.

    Returns:
        relu(conv(x, w) + b); its spatial size depends on padding/strides.

    Raises:
        ValueError: if ``groups`` is < 1 or does not evenly divide both the
            input channel count and ``featureNum``.
    """
    channel = int(x.get_shape()[-1])   # static size of the last (channel) axis
    # Check up front: with floor division a non-divisible channel count would
    # otherwise yield a silently truncated filter shape.
    if groups < 1 or channel % groups or featureNum % groups:
        raise ValueError(
            "convLayer %s: groups=%d must evenly divide input channels (%d) "
            "and featureNum (%d)" % (name, groups, channel, featureNum))

    def conv(inputs, kernel):
        return tf.nn.conv2d(inputs, kernel, strides=[1, strideY, strideX, 1],
                            padding=padding)

    with tf.variable_scope(name) as scope:
        # Each group only sees channel // groups input channels. Integer
        # division keeps the shape integral under Python 3 as well.
        w = tf.get_variable("w", shape=[kHeight, kWidth, channel // groups, featureNum])
        b = tf.get_variable("b", shape=[featureNum])
        if groups == 1:
            mergeFeatureMap = conv(x, w)
        else:
            xNew = tf.split(value=x, num_or_size_splits=groups, axis=3)
            wNew = tf.split(value=w, num_or_size_splits=groups, axis=3)
            featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)]
            mergeFeatureMap = tf.concat(axis=3, values=featureMap)
        out = tf.nn.bias_add(mergeFeatureMap, b)
        # bias_add keeps the shape of mergeFeatureMap, so no reshape is needed.
        return tf.nn.relu(out, name=scope.name)
class alexNet(object):
    """AlexNet-style CNN for 512x512 RGB inputs.

    The graph is built in the constructor. The output of the last layer
    (fc8, unnormalized class scores) is exposed as ``self.fc3``.

    Args:
        x: input tensor; it is reshaped to [-1, 512, 512, 3], so its element
            count must be a multiple of 512*512*3.
        keepPro: dropout keep probability used after fc6 and fc7.
        classNum: number of output classes (width of fc8).
        skip: layer names that loadModel leaves untouched.
        modelPath: path of the .npy weight file read by loadModel.
    """

    def __init__(self,x,keepPro,classNum,skip,modelPath="bvlc_alexnet.npy"):
        self.X=x
        self.KEEPPRO=keepPro      # dropout keep probability
        self.CLASSNUM=classNum    # number of output classes
        self.SKIP=skip            # layer names not overwritten by loadModel
        self.MODELPATH=modelPath  # .npy weight file
        # build CNN
        self.buildCNN()

    def buildCNN(self):
        """Build the network graph and store the output layer in self.fc3.

        The spatial sizes in the comments follow from the 512x512 input and
        the VALID paddings/strides used below.
        NOTE(review): an earlier comment mentioned "2800*2100", presumably
        the raw image size before resizing to 512x512. Confirm.
        """
        x1 = tf.reshape(self.X, shape=[-1, 512, 512, 3])

        conv1 = convLayer(x1, 7, 7, 3, 3, 256, "conv1", "VALID")        # 169x169x256
        lrn1 = LRN(conv1, 2, 2e-05, 0.75, "norm1")
        pool1 = maxPoolLayer(lrn1, 3, 3, 2, 2, "pool1", "VALID")        # 84x84x256

        conv2 = convLayer(pool1, 3, 3, 1, 1, 512, "conv2", "VALID")     # 82x82x512
        lrn2 = LRN(conv2, 2, 2e-05, 0.75, "norm2")
        pool2 = maxPoolLayer(lrn2, 3, 3, 2, 2, "pool2", "VALID")        # 40x40x512

        conv3 = convLayer(pool2, 3, 3, 1, 1, 1024, "conv3", "VALID")    # 38x38x1024
        conv4 = convLayer(conv3, 3, 3, 1, 1, 1024, "conv4", "VALID")    # 36x36x1024

        conv5 = convLayer(conv4, 3, 3, 2, 2, 512, "conv5", "VALID")     # 17x17x512
        pool5 = maxPoolLayer(conv5, 3, 3, 2, 2, "pool5", "VALID")       # 8x8x512

        fcIn = tf.reshape(pool5, [-1, 512 * 8 * 8])
        fc1 = fcLayer(fcIn, 512 * 8 * 8, 4096, True, "fc6")
        dropout1 = dropout(fc1, self.KEEPPRO)

        fc2 = fcLayer(dropout1, 4096, 4096, True, "fc7")
        dropout2 = dropout(fc2, self.KEEPPRO)

        # fc8 is the output layer, so it stays linear (logits). A ReLU here
        # would clamp every negative class score to 0, which breaks the usual
        # softmax / cross-entropy training.
        self.fc3 = fcLayer(dropout2, 4096, self.CLASSNUM, False, "fc8")
以上就是网络的搭建。网络搭好之后，还需要把 .npy 文件中保存的模型参数（权重 w 和偏置 b）加载到对应的变量中（SKIP 中列出的层除外）：

    def loadModel(self,sess):
        """Assign weights from the .npy file at self.MODELPATH to the graph.

        The file is loaded as a pickled dict mapping a layer name to an
        iterable of arrays. For each layer not listed in self.SKIP:
        - 1-D arrays are assigned to that layer's ``b`` variable;
        - all other arrays are assigned to its ``w`` variable.
        Variables are fetched with reuse=True, so they must already exist
        (i.e. the graph must have been built).

        Args:
            sess: session in which the assign ops are run.
        """
        # The file holds a pickled object, which numpy >= 1.16.3 refuses to
        # load unless allow_pickle=True. SECURITY: unpickling can execute
        # arbitrary code, so only load weight files from a trusted source.
        wDict = np.load(self.MODELPATH, encoding="bytes", allow_pickle=True).item()
        for key in wDict:
            # With encoding="bytes", keys of a Python-2 pickle come back as
            # bytes under Python 3. Decode them so they compare equal to the
            # str names in self.SKIP and are usable as variable-scope names.
            name = key.decode("utf-8") if isinstance(key, bytes) else key
            if name in self.SKIP:
                continue
            with tf.variable_scope(name, reuse=True):
                for p in wDict[key]:
                    varName = "b" if len(p.shape) == 1 else "w"   # bias vs weights
                    sess.run(tf.get_variable(varName, trainable=False).assign(p))


猜你喜欢

转载自blog.csdn.net/qq_36631272/article/details/79173224
今日推荐