深度学习之 MNIST 手写数字识别入门

使用 TensorFlow 框架和 Python,实现一个简单的全连接神经网络(两个隐藏层)对 MNIST 手写数字进行分类,并对超参数进行调参。代码如下:
 

#! /usr/bin/python
# -*- coding:utf-8 -*-

"""
a simple mnist classifier

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None

# Read the MNIST data sets from data_dir with one-hot encoded labels.
data_dir = './minst_data'
mnist = input_data.read_data_sets(data_dir, one_hot=True)

# Number of units in each of the two hidden layers.
h1_nodes = 1400
h2_nodes = 1400


def _he_normal(fan_in, fan_out):
    """Return a (fan_in, fan_out) float32 weight Variable.

    Entries are standard-normal samples scaled by sqrt(2 / fan_in).
    """
    samples = tf.cast(np.random.randn(fan_in, fan_out), tf.float32)
    return tf.Variable(samples * np.sqrt(2.0 / fan_in))


# Input placeholder: any batch size, 784 features per example.
x = tf.placeholder(tf.float32, shape=[None, 784])

# First hidden layer parameters.
W1 = _he_normal(784, h1_nodes)
b1 = tf.Variable(tf.zeros([h1_nodes]))

# Second hidden layer parameters.
W2 = _he_normal(h1_nodes, h2_nodes)
b2 = tf.Variable(tf.zeros([h2_nodes]))

# Output layer parameters (10 outputs), both initialised to zero.
W3 = tf.Variable(tf.zeros([h2_nodes, 10]))
b3 = tf.Variable(tf.zeros([10]))

# Forward pass: two affine + ReLU layers ...
pre_act1 = tf.matmul(x, W1) + b1
h1 = tf.nn.relu(pre_act1)
pre_act2 = tf.matmul(h1, W2) + b2
h2 = tf.nn.relu(pre_act2)

# ... followed by an affine output layer; y holds the raw logits (no softmax).
y = tf.matmul(h2, W3) + b3

# Ground-truth placeholder: one row of 10 values per example.
y_ = tf.placeholder(tf.float32, shape=[None, 10])

# Mean softmax cross-entropy between the ground truth y_ and the logits y.
# The keyword argument is `labels` (plural); `label` raises a TypeError.
cross_entroy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

# L2 regularization.
# Regularization coefficient.
Lambda = 0.0004
regularizer = tf.contrib.layers.l2_regularizer(Lambda)
# Penalise the three weight matrices W1, W2 and W3; the biases are not
# regularized.
regularization = regularizer(W1) + regularizer(W2) + regularizer(W3)

# Total loss: data term plus weight penalty.
loss = cross_entroy + regularization

# One plain gradient-descent update on the total loss.
learning_rate = 1
training_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Create the session and initialise every variable defined so far.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Mini-batch size and the number of whole mini-batches per epoch.
batch_size = 100
batchs = mnist.train.num_examples // batch_size

# Per-example correctness: the predicted class (argmax of the logits) must
# equal the labelled class (argmax of y_).  tf.equal yields a boolean vector;
# the previous tf.reduce_mean(a, b) call treated the labels as a reduction
# axis and did not compare anything.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

# Accuracy: fraction of correct predictions (booleans cast to 0.0 / 1.0).
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Training: run enpoch_num passes of `batchs` mini-batch updates each and
# report metrics at the end of every epoch.
enpoch_num = 20
for epoch in range(enpoch_num):
    for batch in range(batchs):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(training_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Loss is evaluated on the last mini-batch of the epoch only.
    loss_out = sess.run(loss, feed_dict={x: batch_xs, y_: batch_ys})
    train_accuracy = sess.run(
        accuracy,
        feed_dict={x: mnist.train.images, y_: mnist.train.labels})
    # Evaluate on the held-out test split, not on the training data again.
    test_accuracy = sess.run(
        accuracy,
        feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    # Print the evaluated loss value (loss_out), not the `loss` tensor object.
    print("epoch{}--train_accuracy:{}--test_accuracy:{}--loss:{}".format(
        epoch, train_accuracy, test_accuracy, loss_out))
  

猜你喜欢

转载自blog.csdn.net/weixin_41694971/article/details/81331504