#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time: 2018/6/26
# @Author: xfLi
# LeNet implementation for MNIST digit classification (LeNet 实现mnist分类)
import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
# ---- Data / network geometry ----
INPUT_NODE = 784        # 28 * 28 pixels when flattened (not referenced in the visible code)
OUTPUT_NODE = 10        # number of output logits / one-hot label width
IMAGE_SIZE = 28         # input images are IMAGE_SIZE x IMAGE_SIZE
NUM_CHANNELS = 1        # single-channel (grayscale) input
NUM_LABELS = 10         # number of classes (not referenced in the visible code)
CONV1_SIZE = 5          # first conv kernel is CONV1_SIZE x CONV1_SIZE
CONV1_DEEP = 32         # output channels of the first conv layer
CONV2_SIZE = 5          # second conv kernel is CONV2_SIZE x CONV2_SIZE
CONV2_DEEP = 64         # output channels of the second conv layer
FC_NODE = 512           # width of the hidden fully connected layer

# ---- Training hyper-parameters ----
BATCH_SIZE = 100
MAX_STEP = 3000
# Base learning rate for plain gradient descent. This was previously written
# as np.exp(0.01) (~1.01), which looks like a typo for 0.01 and is far too
# large a step size for SGD on this network.
LEARNING_RATE = 0.01
# Exponential decay factor applied per epoch's worth of steps
# (decay_steps = num_examples // BATCH_SIZE in train(); non-staircase).
LEARNING_RATE_DECAY = 0.99
MOVING_DECAY = 0.99     # decay for the ExponentialMovingAverage of trainable variables
# L2 regularization strength for the fully connected weights.
# NOTE: the misspelled name is kept because other code refers to it.
REGULARIZETION_RATE = 0.0001

# ---- Checkpointing ----
MODEL_PATH = './lenet_path/'   # directory that receives checkpoints
MODEL_NAME = 'model.ckpt'      # checkpoint filename prefix
def inference(x, train, regularizer):  # build the network
    """Build the LeNet-style forward graph and return the output logits.

    Layers, in order:
        conv(CONV1_SIZE, CONV1_DEEP, stride 1, SAME) -> ReLU -> maxpool(2x2, stride 2)
        conv(CONV2_SIZE, CONV2_DEEP, stride 1, SAME) -> ReLU -> maxpool(2x2, stride 2)
        flatten -> FC(FC_NODE) -> ReLU [-> dropout(keep_prob=0.5) if train]
        FC(OUTPUT_NODE) -> logits

    Variables are created with tf.get_variable under fixed variable scopes
    ('layer1_conv', 'layer3_conv', 'layer5_fc', 'layer6_fc'), so calling this
    twice in one graph without variable reuse will fail.

    Args:
        x: input image tensor; presumably NHWC of shape
            (batch, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS), as fed by train().
            The batch dimension may be dynamic (None).
        train: truthy to insert dropout after the first FC layer.
        regularizer: callable mapping a weight tensor to a scalar penalty, or a
            falsy value to skip regularization. When given, the penalties for
            both FC weight matrices are added to the 'loss' collection.

    Returns:
        Unnormalized logits of shape (batch, OUTPUT_NODE).
    """
    with tf.variable_scope('layer1_conv'):
        weights = tf.get_variable(
            'weights', [CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_DEEP],
            initializer=tf.truncated_normal_initializer(stddev=0.1))
        bias = tf.get_variable('bias', [CONV1_DEEP], initializer=tf.constant_initializer(0.))
        conv1 = tf.nn.conv2d(x, weights, [1, 1, 1, 1], padding='SAME')
        relu1 = tf.nn.relu(tf.nn.bias_add(conv1, bias))
    with tf.name_scope('layer2_pool'):
        pool1 = tf.nn.max_pool(relu1, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')
    with tf.variable_scope('layer3_conv'):
        weights = tf.get_variable(
            'weights', [CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP],
            initializer=tf.truncated_normal_initializer(stddev=0.1))
        bias = tf.get_variable('bias', [CONV2_DEEP], initializer=tf.constant_initializer(0.))
        conv2 = tf.nn.conv2d(pool1, weights, [1, 1, 1, 1], padding='SAME')
        relu2 = tf.nn.relu(tf.nn.bias_add(conv2, bias))
    with tf.name_scope('layer4_pool'):
        pool2 = tf.nn.max_pool(relu2, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')
    # Flatten each example. Only the height/width/channel dims must be static
    # (e.g. 7 * 7 * CONV2_DEEP for 28x28 inputs). Using -1 for the batch dim
    # lets TensorFlow infer it, so a dynamic (None) batch size also works;
    # the previous pool2_shape[0] broke whenever the batch size was not static.
    pool2_shape = pool2.get_shape().as_list()
    node = pool2_shape[1] * pool2_shape[2] * pool2_shape[3]
    pool2_reshape = tf.reshape(pool2, [-1, node])
    with tf.variable_scope('layer5_fc'):
        weights = tf.get_variable('weights', [node, FC_NODE],
                                  initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer:
            # Only the FC weights are regularized; train() sums this collection.
            tf.add_to_collection('loss', regularizer(weights))
        bias = tf.get_variable('bias', [FC_NODE], initializer=tf.constant_initializer(0.1))
        fc1 = tf.nn.relu(tf.matmul(pool2_reshape, weights) + bias)
        if train:
            fc1 = tf.nn.dropout(fc1, keep_prob=0.5)
    with tf.variable_scope('layer6_fc'):
        weights = tf.get_variable('weights', [FC_NODE, OUTPUT_NODE],
                                  initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer:
            tf.add_to_collection('loss', regularizer(weights))
        bias = tf.get_variable('bias', [OUTPUT_NODE],
                               initializer=tf.constant_initializer(0.1))
        logit = tf.matmul(fc1, weights) + bias
    return logit
def train(mnist):  # training loop
    """Train LeNet on MNIST with decaying-LR SGD, EMA shadow weights and checkpoints.

    Args:
        mnist: dataset object as returned by
            input_data.read_data_sets(..., one_hot=True). This function uses
            mnist.train.num_examples and mnist.train.next_batch(BATCH_SIZE).

    Side effects:
        Creates MODEL_PATH if it does not exist. Prints the loss and writes a
        checkpoint (MODEL_PATH/MODEL_NAME-<global_step>) every 300 steps and
        after the final step.
    """
    x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS])
    y_ = tf.placeholder(tf.float32, [BATCH_SIZE, OUTPUT_NODE])
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZETION_RATE)
    # Explicit True enables dropout. Previously the function object `train`
    # itself was passed here, which only worked because functions are truthy.
    y = inference(x, True, regularizer)
    global_step = tf.Variable(0, trainable=False)
    # Decay by LEARNING_RATE_DECAY per epoch's worth of steps (smooth, not staircase).
    learning_rate = tf.train.exponential_decay(learning_rate=LEARNING_RATE,
                                               global_step=global_step,
                                               decay_steps=mnist.train.num_examples // BATCH_SIZE,
                                               decay_rate=LEARNING_RATE_DECAY)
    cross_entroy = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
    cross_entroy_mean = tf.reduce_mean(cross_entroy)
    # Total loss = data loss + L2 penalties that inference() put in 'loss'.
    losses = cross_entroy_mean + tf.add_n(tf.get_collection('loss'))
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss=losses, global_step=global_step)
    ema = tf.train.ExponentialMovingAverage(MOVING_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # One op that runs both the gradient step and the EMA update.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op('train')
    # Created after ema.apply(), so the EMA shadow variables are checkpointed too.
    saver = tf.train.Saver()

    log_every = 300
    # tf.train.Saver.save fails when the checkpoint directory is missing.
    os.makedirs(MODEL_PATH, exist_ok=True)
    checkpoint_path = os.path.join(MODEL_PATH, MODEL_NAME)
    last_step = MAX_STEP - 1
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(MAX_STEP):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            xs_reshape = np.reshape(xs, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
            # NOTE(review): the global_step read is not ordered relative to the
            # increment done by train_op, so `step` may be off by one.
            _, loss_value, step = sess.run([train_op, losses, global_step],
                                           feed_dict={x: xs_reshape, y_: ys})
            # Also checkpoint the final step; previously the last save happened
            # at i == 2700 and the remaining steps were never saved.
            if i % log_every == 0 or i == last_step:
                print('{} step, loss: {}'.format(step, loss_value))
                saver.save(sess, checkpoint_path, global_step)
def main(argv=None):
    """Entry point invoked by tf.app.run: load MNIST with one-hot labels and train.

    Args:
        argv: leftover command-line arguments forwarded by tf.app.run; unused.
    """
    data_dir = './mnist_data/'
    dataset = input_data.read_data_sets(data_dir, one_hot=True)
    train(dataset)
if __name__ == '__main__':
    # tf.app.run parses command-line flags, then calls main(argv).
    # Naming main explicitly is equivalent to its default lookup of `main`
    # in the __main__ module.
    tf.app.run(main=main)
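# 【tensorflow】LeNet
# Reposted from (转载自): blog.csdn.net/qq_30159015/article/details/80817505
# The following labels are navigation text from the source blog page, kept
# for provenance: 猜你喜欢 ("You may also like"), 今日推荐 ("Today's picks"),
# 周排行 ("Weekly ranking").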