TensorFlow --卷积神经网络之 tf.nn.conv2d与tf.nn.max_pool

最近在研究学习TensorFlow,在做识别手写数字时,遇到了tf.nn.conv2d这个方法,其中有些方法还不是很清楚,于是网上搜索后,记录如下:

卷积神经网络的核心是对图像的“卷积”操作

tf.nn.conv2d方法定义

tf.nn.conv2d (input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)(官网参数)

参数:

  • input : 输入的要做卷积的图片,要求为一个张量,shape为 [ batch, in_height, in_weight, in_channel ],其中batch为图片的数量,in_height 为图片高度,in_weight 为图片宽度,in_channel 为图片的通道数,灰度图该值为1,彩色图为3。
  • filter: 卷积核,要求也是一个张量,shape为 [ filter_height, filter_weight, in_channel, out_channels ],其中 filter_height 为卷积核高度,filter_weight 为卷积核宽度,in_channel 是图像通道数 ,和 input 的 in_channel 要保持一致,out_channel 是卷积核个数
  • strides: 卷积时在图像每一维的步长,这是一个一维的向量,[ 1, strides, strides, 1],第一位和最后一位固定必须是1。(代表步长,其值可以直接默认一个数,也可以是一个四维数如[1,2,1,1],则其意思是水平方向卷积步长为第二个参数2,垂直方向步长为1.)
  • padding: string类型,值为“SAME” 和 “VALID”,表示的是卷积的形式,是否考虑边界。”SAME”是考虑边界,不足的时候用0去填充周围,”VALID”则不考虑
  • use_cudnn_on_gpu: bool类型,是否使用cudnn加速,默认为true

这个op(conv2d)执行了以下操作

1 将filter转为二维矩阵
它的shape是[filter_height * filter_width * in_channels, output_channels].

2 从input tensor中提取image patches(小块),形成一个virtual tensor,
它的shape是[batch, out_height, out_width, filter_height * filter_width * in_channels].

3 filter矩阵和image patch向量相乘

一般要求 strides的参数,strides[0] = strides[3] = 1

  1. strides[0] 控制 batch
  2. strides[1] 控制 height
  3. strides[2] 控制 width
  4. strides[3] 控制 channels
  • strides[0] = 1,在 batch 维度上的移动 1
  • strides[3] = 1,在 channels 维度上的移动 1

图片卷积后的尺寸计算公式

  • 输入图片大小 W×W
  • Filter大小 F×F
  • 步长 S
  • padding的像素数 P

N = (W − F + 2P )/S+1

输出图片大小为 N×N

接下来通过实例来看通道、卷积核数目不同时,数据是如何变化的。

案例一(单通道、单核):

import tensorflow as tf
#case 1
input = tf.Variable(tf.random_normal([1,3,3,1]))
filter = tf.Variable(tf.random_normal([1,1,1,1]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print("case 1")
    print(sess.run(input))

case 1
[[[[ 0.00850411]
   [ 0.00713599]
   [-0.1402842 ]]

  [[-1.4874302 ]
   [ 1.1501638 ]
   [-0.27221245]]

  [[-1.8692739 ]
   [-1.0514828 ]
   [-0.22669399]]]]

Process finished with exit code 0

案例二(多通道、单核):

import tensorflow as tf
#case 2
input = tf.Variable(tf.random_normal([1,3,3,5]))
filter = tf.Variable(tf.random_normal([1,1,5,1]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print("case 2")
    print(sess.run(input))

case 2
[[[[-0.7135731 ]
   [ 0.33953804]
   [ 2.2308116 ]]

  [[ 0.99762535]
   [-1.1370671 ]
   [ 1.024965  ]]

  [[ 1.1837609 ]
   [-0.01205832]
   [ 0.7789178 ]]]]

Process finished with exit code 0

从这个案例可以看出,多通道好像不体现在数据的形状上,从RGB中图片中,我们也联想到是sum(R+G+B)组合一个数据,也是符合情况的。

            

案例三(单通道、多核):

import tensorflow as tf
#case 3
input = tf.Variable(tf.random_normal([1,3,3,1]))
filter = tf.Variable(tf.random_normal([1,1,1,5]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print("case 3")
    print(sess.run(input))

case 3
[[[[-0.5227298  -0.66182023  1.0833362   0.72718143 -0.17323467]
   [-0.24752371 -0.313386    0.5129827   0.34433588 -0.0820303 ]
   [-0.34527072 -0.43714198  0.7155594   0.480314   -0.11442404]]

  [[-0.23402141 -0.29629093  0.4849998   0.32555252 -0.0775556 ]
   [ 0.25400987  0.321598   -0.52642506 -0.35335892  0.08417984]
   [-0.5669042  -0.71774876  1.1748857   0.7886334  -0.18787423]]

  [[ 0.9459997   1.1977158  -1.9605457  -1.3160018   0.31350794]
   [ 0.26897132  0.34054047 -0.55743206 -0.37417215  0.08913813]
   [ 0.4315578   0.54638875 -0.8943859  -0.6003499   0.14301991]]]]

Process finished with exit code 0

从这里好像看出,多核的数目反应在行上

案例四(多通道、多核):

import tensorflow as tf
#case 4
input = tf.Variable(tf.random_normal([1,3,3,5]))
filter = tf.Variable(tf.random_normal([1,1,5,5]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print("case 4")
    print(sess.run(input))

case 4
[[[[-2.4651055  -0.3509455  -4.686282   -2.3212051  -0.710496  ]
   [ 0.7768164   0.57705945  1.318051    0.5697315   0.07564557]
   [-1.0820991   1.9157095   0.23172522  2.5040576   2.321178  ]]

  [[ 0.4635254   0.05687328  0.81018573 -0.14844202 -1.0212026 ]
   [ 2.6249065  -0.76527166  4.834874    2.1821246   0.7100087 ]
   [ 0.6223385  -0.1788317   4.873171    3.3820043  -1.1162739 ]]

  [[-1.0766833   0.43714532 -0.9590479  -0.6471283  -1.8756338 ]
   [-3.936176   -0.02507877 -7.966227   -3.140205    0.98988634]
   [-4.6462646  -0.8348821  -6.4950304  -1.1865668   0.99817204]]]]

Process finished with exit code 0

接下来的案例继续说明问题。他们是如何计算的,以及计算结果是如何排列的

import tensorflow as tf
sess = tf.InteractiveSession()
input_batch = tf.constant([
        [  # First Input (6x6x1)
            [[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]],
            [[0.1], [1.1], [2.1], [3.1], [4.1], [5.1]],
            [[0.2], [1.2], [2.2], [3.2], [4.2], [5.2]],
            [[0.3], [1.3], [2.3], [3.3], [4.3], [5.3]],
            [[0.4], [1.4], [2.4], [3.4], [4.4], [5.4]],
            [[0.5], [1.5], [2.5], [3.5], [4.5], [5.5]],
        ],
    ])
kernel = tf.constant([  # Kernel (3x3x1)
        [[[0.0]], [[0.5]], [[0.0]]],
        [[[0.0]], [[1.0]], [[0.0]]],
        [[[0.0]], [[0.5]], [[0.0]]]
    ])
# NOTE: the change in the size of the strides parameter.
conv2d = tf.nn.conv2d(input_batch, kernel, strides=[1, 3, 3, 1], padding='SAME')
print(sess.run(conv2d))

输出结果

[[[[ 2.20000005] 
[ 8.19999981]]

[[ 2.79999995] 
[ 8.80000019]]]]

如果将上述例子更改为 strides=[1, 2, 2, 1]

结果输出 
[[[[  2.20000005] 
   [  6.19999981] 
   [ 10.19999981]]

[[  2.5999999 ] 
   [  6.60000038] 
   [ 10.60000038]]

[[  2.20000005] 
   [  5.19999981] 
   [  8.19999981]]]]

案例六(多通道、多核):

#case 6
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,7]))
op6 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')

case 6
[[[[ 12.02504349   4.35077286   2.67207813   5.77893162   6.98221684
     -0.96858567  -8.1147871 ]
   [ -0.02988982  -2.52141953  15.24755192   6.39476395  -4.36355495
     -2.34515095   5.55743504]
   [ -2.74448752  -1.62703776  -6.84849405  10.12248802   3.7408421
      4.71439075   6.13722801]
   [  0.82365227  -1.00546622  -3.29460764   5.12690163  -0.75699937
     -2.60097408  -8.33882809]
   [  0.76171923  -0.86230004  -6.30558443  -5.58426857   2.70478535
      8.98232937  -2.45504045]]

  [[  3.13419819 -13.96483231   0.42031103   2.97559547   6.86646557
     -3.44916964  -0.10199898]
   [ 11.65359879  -5.2145977    4.28352737   2.68335319   3.21993709
     -6.77338028   8.08918095]
   [  0.91533852  -0.31835344  -1.06122255  -9.11237717   5.05267143
      5.6913228   -5.23855162]
   [ -0.58775592  -5.03531456  14.70254898   9.78966522 -11.00562763
     -4.08925819  -3.29650426]
   [ -2.23447251  -0.18028721  -4.80610704  11.2093544   -6.72472
     -2.67547607   1.68422937]]

  [[ -3.40548897  -9.70355129  -1.05640507  -2.55293012  -2.78455877
    -15.05377483  -4.16571808]
   [ 13.66925812   2.87588191   8.29056358   6.71941566   2.56558466
     10.10329056   2.88392687]
   [ -6.30473804  -3.3073864   12.43273926  -0.66088223   2.94875336
      0.06056046  -2.78857946]
   [ -7.14735603  -1.44281793   3.3629775   -7.87305021   2.00383091
     -2.50426936  -6.93097973]
   [ -3.15817571   1.85821593   0.60049552  -0.43315536  -4.43284273
      0.54264796   1.54882073]]

  [[  2.19440389  -0.21308756  -4.35629082  -3.62100363  -0.08513772
     -0.80940366   7.57606506]
   [ -2.65713739   0.45524287 -16.04298019  -5.19629049  -0.63200498
      1.13256514  -6.70045137]
   [  8.00792599   4.09538221  -6.16250181   8.35843849  -4.25959206
     -1.5945878   -7.60996151]
   [  8.56787586   5.85663748  -4.38656425   0.12728286  -6.53928804
      2.3200655    9.47253895]
   [ -6.62967777   2.88872099  -2.76913023  -0.86287498  -1.4262073
     -6.59967232   5.97229099]]

  [[ -3.59423327   4.60458899  -5.08300591   1.32078576   3.27156973
      0.5302844   -5.27635145]
   [ -0.87793881   1.79624665   1.66793108  -4.70763969  -2.87593603
     -1.26820421  -7.72825718]
   [ -1.49699068  -3.40959787  -1.21225107  -1.11641395  -8.50123024
     -0.59399474   3.18010235]
   [ -4.4249506   -0.73349547  -1.49064219  -6.09967899   5.18624878
     -3.80284953  -0.55285597]
   [ -1.42934585   2.76053572  -5.19795799   0.83952439  -0.15203482
      0.28564462   2.66513705]]]]

以下是区别padding='VALID',不考虑边界的情况

input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,7]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print("case 7")
    print(sess.run(op2))

case 7
[[[[-12.277221    -3.6852837   -7.5415998   -1.4435571    4.5265384
      5.6599164    0.03461173]
   [  4.2244606   -1.8150828   -2.9772494   11.986962     1.5673934
     -5.33732     -6.576837  ]
   [  2.792845    -1.1091218   -8.66483     12.438319    -1.8882469
     -3.9440742   -6.3208795 ]]

  [[ -2.3882375    9.021189    -7.999711    18.31005      4.852937
     -5.7791305    5.0236855 ]
   [  1.0881239   -5.179409     0.15859601   6.445263     8.557671
    -16.044416     3.657256  ]
   [  2.795134     4.8999724   -9.92672      3.9908109    6.207695
     -6.553004     9.258662  ]]

  [[ -5.4560223    6.153165     6.02847      6.907523    -5.5059247
     -2.2264066    1.7103047 ]
   [ -1.0343044   -5.2060676    0.98752177  -4.918023     0.17576812
     -1.5359226    1.663869  ]
   [ -7.092221     1.1528535   -1.7145716    3.2233562   -4.150458
      0.8865322   14.828557  ]]]]

Process finished with exit code 0

tf.nn.max_pool方法定义

tf.nn.max_pool(value, ksize, strides, padding, name=None)

value:池化的输入,一般池化层接在卷积层的后面,所以输出通常为feature map。feature map依旧是[batch, in_height, in_width, in_channels]这样的参数。

ksize池化窗口的大小,参数为四维向量,通常取[1, height, width, 1],因为我们不想在batch和channels上做池化,所以这两个维度设为了1。ps:估计面tf.nn.conv2d中stries的四个取值也有相同的意思。

stries:步长,同样是一个四维向量。

padding:填充方式同样只有两种不重复了。

               

参考资料:

https://blog.csdn.net/flyfish1986/article/details/77508783

猜你喜欢

转载自blog.csdn.net/qq_20412595/article/details/82855728