TensorFlow CNN image recognition

Following on from the previous post on reading the data, this post covers the CNN recognition code. Much of the code you find online can be run as-is; I just typed it out myself and worked through it to make sure I understood it.

# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf

# scaling factors for the randomly initialized weights and biases
w_alpha = 0.01
b_alpha = 0.1

IMAGE_HEIGHT = 240
IMAGE_WIDTH = 320
MAX_CAPTCHA = 1
# number of image classes
CHAR_SET_LEN = 37
dropout = 0.7


conv_dict = {
    # first conv layer: 3x3 kernels; the input is a color image, so 3 input channels, 32 output channels
    "w_1": tf.Variable(w_alpha * tf.random_normal([3, 3, 3, 32]), name='w_1'),
    "b_1": tf.Variable(b_alpha * tf.random_normal([32]), name='b_1'),
    # second conv layer
    "w_2": tf.Variable(w_alpha * tf.random_normal([3, 3, 32, 64]), name='w_2'),
    "b_2": tf.Variable(b_alpha * tf.random_normal([64]), name='b_2'),
    # third conv layer
    "w_3": tf.Variable(w_alpha * tf.random_normal([3, 3, 64, 128]), name='w_3'),
    "b_3": tf.Variable(b_alpha * tf.random_normal([128]), name='b_3'),
    # fourth conv layer
    "w_4": tf.Variable(w_alpha * tf.random_normal([3, 3, 128, 128]), name='w_4'),
    "b_4": tf.Variable(b_alpha * tf.random_normal([128]), name='b_4'),

    # output layer of the fully connected head
    'out': tf.Variable(tf.random_normal([1024, CHAR_SET_LEN])),
    'out_add': tf.Variable(tf.random_normal([CHAR_SET_LEN]))
}

# Batch normalization - guards against vanishing gradients.
# wx_plus_b: the tensor to normalize
# out_size:  number of channels
def batch_normal(wx_plus_b, out_size):
    fc_mean, fc_var = tf.nn.moments(
        wx_plus_b,
        axes=[0, 1, 2],  # dimensions to normalize over; [0] alone would be just the batch dimension
        # for image data pass [0, 1, 2], i.e. the mean/variance over [batch, height, width],
        # leaving the channel dimension out
    )
    # out_size must equal the number of output channels of wx_plus_b
    scale = tf.Variable(tf.ones([out_size]))
    shift = tf.Variable(tf.zeros([out_size]))
    epsilon = 0.001
    wx_plus_b = tf.nn.batch_normalization(wx_plus_b, fc_mean, fc_var, shift, scale, epsilon)
    return wx_plus_b



X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
NOR = tf.placeholder(tf.float32)        # flag meant to switch batch normalization on/off
keep_prob = tf.placeholder(tf.float32)  # dropout keep probability

# Turn integer labels into one-hot vectors
def one_hot_n(x, n):
    x = np.array(x)
    return np.eye(n)[x]
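# For illustration: one_hot_n([2, 5], 37) returns a (2, 37) array with a 1.0 at
# column 2 of the first row and column 5 of the second row; in the training loop
# below this turns the integer label batch into the one-hot matrix that Y expects.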


def conv2d(conv, cd1, cd2, out_size, nor):
    conv = tf.nn.bias_add(tf.nn.conv2d(conv, cd1, strides=[1, 1, 1, 1], padding='SAME'), cd2)
    # batch normalization (disabled for now; see note 1 at the end of the post)
#    if nor > 1:
#        conv = batch_normal(conv, out_size)
    conv = tf.nn.relu(conv)
    conv = tf.nn.max_pool(conv, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    # dropout to reduce overfitting
    conv = tf.nn.dropout(conv, keep_prob)
    return conv


# Define the CNN
def crack_captcha_cnn():

    # four convolution + pooling layers
    conv1 = conv2d(X, conv_dict['w_1'], conv_dict['b_1'], 32, NOR)
    conv2 = conv2d(conv1, conv_dict['w_2'], conv_dict['b_2'], 64, NOR)
    conv3 = conv2d(conv2, conv_dict['w_3'], conv_dict['b_3'], 128, NOR)
    conv4 = conv2d(conv3, conv_dict['w_4'], conv_dict['b_4'], 128, NOR)

    # fully connected layer
    # after four 2x2 poolings: 240/16 = 15, 320/16 = 20
    w_d = tf.Variable(w_alpha * tf.random_normal([15 * 20 * 128, 1024]))
    b_d = tf.Variable(b_alpha * tf.random_normal([1024]))

    dense = tf.reshape(conv4, [-1, w_d.get_shape().as_list()[0]])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))

    out = tf.add(tf.matmul(dense, conv_dict['out']), conv_dict['out_add'])

    return out

# Read and decode the TFRecords data
def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
       features={
           'label': tf.FixedLenFeature([], tf.int64),
           'img_raw' : tf.FixedLenFeature([], tf.string),
       })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
    # normalize
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)
    return img, label

# Training
def train_crack_captcha_cnn():

    output = crack_captcha_cnn()
    # softmax vs. sigmoid cross-entropy: softmax for a single label per image,
    # sigmoid for multi-label outputs (e.g. several captcha characters at once)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=Y))
    # loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=Y))
    # optimizer: to speed up training, the learning rate should start large and then decay gradually
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
    # Evaluate model
    # argmax along axis 1 returns the index of the largest logit in each row;
    # comparing predicted and true indices gives a boolean tensor
    correct_pred = tf.equal(tf.argmax(output, 1), tf.argmax(Y, 1))
    # cast the booleans to floats and average them to get the accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    img, label = read_and_decode("anm_pic_train.tfrecords")
    img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=30, capacity=7000,min_after_dequeue=1000)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # saver.restore(sess, tf.train.latest_checkpoint('/home/root/wtf/yzm/code/'))
        step = 0
#        img, label = read_and_decode("anm_pic_train.tfrecords")
#       img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=64, capacity=70000,min_after_dequeue=1000)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        while True:
#       for i in range(3000):
            imgs, labs = sess.run([img_batch, label_batch])
#           print (labs)
            # convert labels to one-hot in numpy; running tf.cast here would add new ops to the graph on every step
            one_hot_labs = one_hot_n(labs, CHAR_SET_LEN).astype(np.float32)
            sess.run(optimizer, feed_dict={X: imgs, Y: one_hot_labs, keep_prob: dropout, NOR: 1.})
            if step % 50 == 0:
                acc = sess.run( accuracy, feed_dict={X: imgs, Y: one_hot_labs, keep_prob: 1., NOR: 1.})
                print(step, acc)
                if acc > 0.5:
                    saver.save(sess, "crack_capcha.model", global_step=step)
                    print("Complete!!")
                    coord.request_stop()
                    coord.join(threads)
                    sess.close()
                    break
            step += 1
#       print("Complete!!")
#       coord.request_stop()
#        coord.join(threads)
#        sess.close()

train_crack_captcha_cnn()

 

 1. batch_normal is there to prevent vanishing gradients, but I'm still not sure whether it should go before or after the activation, or how to gate it with an if statement on a placeholder; a sketch of one way to do it follows these notes.

 2. This code was run on a CPU. Don't ask why a CPU: I'm broke. If you have the hardware, run it on a GPU; it saves a lot of trouble.
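
For note 1, here is a minimal sketch (not from the original post) of how the conv2d helper could apply batch_normal before the ReLU and toggle it through the NOR placeholder. A placeholder has no Python value when the graph is built, so a plain "if nor > 1:" cannot branch on it; tf.cond does the selection inside the graph instead. The name conv2d_bn and the threshold of 1 are assumptions for illustration.

def conv2d_bn(x, cd1, cd2, out_size, nor):
    pre_act = tf.nn.bias_add(tf.nn.conv2d(x, cd1, strides=[1, 1, 1, 1], padding='SAME'), cd2)
    # build the normalized tensor once, outside tf.cond, so that batch_normal's
    # scale/shift variables are created unconditionally
    normalized = batch_normal(pre_act, out_size)
    # tf.cond picks between the normalized and the raw pre-activation at run time,
    # depending on the value fed into the NOR placeholder
    conv = tf.cond(nor > 1, lambda: normalized, lambda: pre_act)
    conv = tf.nn.relu(conv)  # normalization applied before the activation
    conv = tf.nn.max_pool(conv, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv = tf.nn.dropout(conv, keep_prob)
    return conv

The four conv2d calls in crack_captcha_cnn could then call conv2d_bn instead, and feeding NOR: 2. (or NOR: 1. to keep normalization off) in feed_dict would switch the behavior at run time.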

Results after running for about half a day:



 

Recommended blogs:

http://blog.topspeedsnail.com

https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-13-BN/

Reposted from j-sun.iteye.com/blog/2361156