4.1吴恩达深度学习笔记之利用Tensorflow构建Cnn模型

给定CNN结构:conv2d--relu--MaxPool--conv2d--relu--MaxPool--FullConnected

1.数据集预处理:和普通神经网络不同的是,CNN中输入的训练集和的测试集只需进行单位化处理而不需要flatten,因为卷积过程并没有化矩阵为向量,而是对矩阵进行处理。对于多维输出Y需要进行one_hot处理将其变为对应矩阵。

2.定义相关函数:

    2.1设置占位符:由于tensorflow的模式,需要为输入X,输出Y,这里的X,Y的第一维数用None表示,因为还不清楚后续过程穿过来的数据有多少个样本。

代码如下:X=tf.placeholders(tf.float32,[None,n_H0,n_W0,n_C0])

                    Y=tf.placeholders(tf.float32,[None,n_y])

   2.2初始化参数:根据给定的条件初始化W1,W2(以两层CNN为例),b不用初始化,tensorflow会自动生成,FC层也是如此。W的维度为(f,f,n_C_prev,n_C)

   2.3前向传播:利用tf的内置函数进行前向传播以第一层为例:

      代码如下:

     卷积层---- Z1=tf.nn.conv2d(X,W1,strides=[1,1,1,1],padding='SAME')

     激活函数--A1=tf.nn.relu(Z1)

     池化层----P1=tf.nn.max_pool(A1,ksize=[1,4,4,1],strides=[1,4,4,1],padding='SAME')

     最后进过第二层的池化后需要把P2 flatten,获得一个列向量,在进如全连接层

    flatten:P=tf.contrib.layers.flatten(P2)

    FullConnected:Z3=tf.contrib.layers.fully_connected(P,6,activation_fn=None)

    至此完成了前向传播所有操作

  2.4计算cost,tf中自带cost计算函数,将Z3,Y3传入函数即可

     cost=tf.reduc_mean(tf.nn.softmax_cross_entropy_with_logits(logits=Z3,labels=Y))

   2.5反向传播与优化不用写函数,tf会自动进行

3 model:利用写好的函数搭建model

 3.1设置占位符,初始化参数,前向传播,计算cost,设置优化函数,这里写出优化函数的代码

   optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost),

   初始化所有变量:init=tf.global_variables_initializer()

   注意到此步位置完成所有需要用到的函数的定义,但均为进行数据传递也未运行。接下来定义一个session

3.2定义session:

    with tf.session() as sess:

            sess.run(init)  #运行全局初始化函数

            for epoch in range(num_epochs):

                    mini_cost=0

                    seed+=1

                    num_minibacthes=int(m/minibatch_size)

                    minibatches=random_minibacthes(X_train,Y_train,minibatch_size,seed)#设置minibatch    

                    for minibatch in minibatches:

                            minibatch_X,minibatch_Y=minibatch

                            #对每个minibatch运行以上函数

                            _,cost=sess.run([optimizer.cost],feed_dict={X:minibatch_X,Y:minibatch_Y})

                            minibatch_cost+=cost/num_minibatches


猜你喜欢

转载自blog.csdn.net/qq_40103460/article/details/80211942
今日推荐