CNN练手——手写体识别

# 卷积层的实现函数

def convolutional_layer(input, num_input_channels, filter_size, num_filters, use_pooling=True):
    # 前两个参数是过滤器的尺寸,第三个参数是输入的通道,第四个参数是输出的通道,也就是过滤器的个数
    shape = [filter_size, filter_size, num_input_channels, num_filters]
    weights = tf.Variable(tf.truncated_normal(shape, stddev=0.05))
    # 1*1*num_filters
    biases = tf.Variable(tf.constant(0.05, shape=[num_filters]))
    # 卷积层
    # input是上一层的输出
    # filter指的就是卷积核
    # strides,第一个和最后一个必须为1,中间1*1代表步长
    # padding等于SAME表示大小不变,也就是使用零填充
    layer = tf.nn.conv2d(input=input, filter=weights, strides=[1, 1, 1, 1], padding='SAME')
    layer += biases

    if use_pooling:  # 如果使用池化层
        # ksize表示池化窗口的大小,第一个和最后一个必须为1,中间的2*2表示窗口大小
        # strides、padding的设置和卷积层一样
        layer = tf.nn.max_pool(value=layer, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    layer = tf.nn.relu(layer)

    return layer, weights


# 作用就是将图片转为一个一维的向量
def flaten_layer():
    layer_shape = layer.get_shape()
    num_features = layer_shape[1:4].num_elements()
    layer_flat = tf.reshape(layer, [-1, num_features])
    return layer_flat, num_features


# 定义全连接层
def full_connected_layer(input, num_inputs, num_outputs, use_relu=True):

    # Create new weights and biases.
    weights = tf.Variable(tf.truncated_normal(shape=[num_inputs, num_outputs], stddev=0.05))
    biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))

    # input [num_inputs] * weights [num_inputs, num_outputs] + biases [num_outputs] = [num_outputs]
    layer = tf.matmul(input, weights) + biases

    # Use ReLU?
    if use_relu:
        layer = tf.nn.relu(layer)

    return layer


# 将图片转化为一维向量
x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')
# 将图片转化为img_size*img_size的三维张量,并输入到卷积层中
x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])
# 
y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
# 通过argmax计算得到图片对应的数字
y_true_cls = tf.argmax(y_true, axis=1)


cov1_layer = convolutional_layer(input=x_image, num_input_channels=num_channels, filter_size=filter_size)

猜你喜欢

转载自www.cnblogs.com/flyangovoyang/p/10618562.html