基于Mnist数据集的单层神经元识别图像

Mnist识别模糊手写数字

一,导入mnist数据集

简介mnist数据集(内含网盘数据集):https://blog.csdn.net/RObot_123/article/details/103220099

手动下载网址(官网):http://yann.lecun.com/exdb/mnist/
在这里插入图片描述

1.利用tensorflow下载mnist数据集

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

上面代码能自动下载mnist数据集到代码目录的“MNIST_data”文件夹下

2.查看数据集里的内容

print ('输入数据打印:',mnist.train.images)
print ('输入数据打印shape:',mnist.train.images.shape)

import pylab 
im = mnist.train.images[1]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()


print ('输入数据打印shape:',mnist.test.images.shape)
print ('输入数据打印shape:',mnist.validation.images.shape)

输出信息如下:
在这里插入图片描述

序号 内容
1 解压数据集
2 打印解压的图片信息
3 打印图片shape
4 显示训练集中的图-序号1
5 打印测试数据集与验证数据shape

有关shape(形状)的介绍:https://blog.csdn.net/RObot_123/article/details/103102627

二,分析mnist样本特点定义变量

因为 输入的图片是55000×784个矩阵
所以 创建一个**[None,784]的占位符x和一个[None,10]的占位符y**
最后 用feed机制将图片和标签输入进去

import tensorflow as tf #导入tensorflow库
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import pylab 

tf.reset_default_graph()
# 定义占位符
x = tf.placeholder(tf.float32, [None, 784]) # mnist data维度长度 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # 0-9 数字=> 10 种类别

三,构建模型

1.定义学习参数

  • 定义权重变量W
  • 定义偏值变量b
# 定义学习参数
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

2.定义输出节点

  • softmax分类
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax分类

3.定义反向传播的结构

  • 损失函数:交叉熵函数
  • 设置学习率:0.01
  • 优化器:GradientDescentOptimizer(梯度下降算法)
# 损失函数
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))

#参数设置
learning_rate = 0.01
# 使用梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

四,训练模型并输出中间状态参数

  • 训练次数(迭代次数):25
  • 设置批次量:100
  • 显示步长:1
  • 启用Session进行运算处理
training_epochs = 25
batch_size = 100
display_step = 1
#saver = tf.train.Saver()
#model_path = "log/521model.ckpt"

# 启动session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())# Initializing OP

    # 启动循环开始训练
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)
        # 遍历全部数据集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # Compute average loss
            avg_cost += c / total_batch
        # 显示训练中的详细信息
        if (epoch+1) % display_step == 0:
            print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print( " Finished!")

输出信息:
在这里插入图片描述

五,测试模型

  • 输出(pred)与标签(y)进行比较
  • reduce_mean对corrcet_prediction求平均值
    # 测试 model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # 计算准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

模型准确率:
在这里插入图片描述

六,保存模型

  • 建议saver和路径
  • 保存模型
saver = tf.train.Saver()
model_path = "log/mnisat_model.ckpt"

调用 saver

  	# Save model weights to disk
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)

输出信息:
在这里插入图片描述
实际保存状况:
在这里插入图片描述

七,读取模型

首先注释掉session会话后的代码,然后将如下代码添加到session里去

#读取模型
print("Starting 2nd session...")
with tf.Session() as sess:
    # Initialize variables
    sess.run(tf.global_variables_initializer())
    # Restore model weights from previously saved model
    saver.restore(sess, model_path)
    
     # 测试 model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # 计算准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
    
    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})
    print(outputval,predv,batch_ys)

    im = batch_xs[0]
    im = im.reshape(-1,28)
    pylab.imshow(im)
    pylab.show()
    
    im = batch_xs[1]
    im = im.reshape(-1,28)
    pylab.imshow(im)
    pylab.show() 

在这里插入图片描述

八,完整代码

1.验证数据集(简略)

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

print ('输入数据打印:',mnist.train.images)
print ('输入数据打印shape:',mnist.train.images.shape)

import pylab 
im = mnist.train.images[1]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
print ('输入数据打印shape:',mnist.test.images.shape)
print ('输入数据打印shape:',mnist.validation.images.shape)

2.验证数据集2(较详细)

#导入mnist数据集
from tensorflow.examples.tutorials.mnist import input_data #从网上下载mnist数据集的模块
mnist = input_data.read_data_sets('MNIST_data/',one_hot = False) #从指定文件夹导入数据集的数据
#分析mnist数据集
print('输入训练数据集数据:',mnist.train.images) #打引导如数据集的数据
print('输入训练数据集shape:',mnist.train.images.shape) #打印训练数据集的形状
print('输入测试数据集shape:',mnist.test.images.shape) #用于评估训练过程中的准确度
print('输入验证数据集shape:',mnist.validation.images.shape) #用于评估最终模型的准确度
print('输入标签的shape:',mnist.train.labels.shape)
#展示mnist数据集
import pylab 
im = mnist.test.images[6] #train中的第六张图
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()

3.识别数据集模糊手写数字

import tensorflow as tf #导入tensorflow库
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import pylab 

tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784]) # mnist data维度 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # 0-9 数字=> 10 classes

# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

# 构建模型
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax分类

# Minimize error using cross entropy
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))

#参数设置
learning_rate = 0.01
# 使用梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

training_epochs = 25
batch_size = 100
display_step = 1
saver = tf.train.Saver()
model_path = "log/mnist_model.ckpt"

# 启动session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())# Initializing OP

    # 启动循环开始训练
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)
        # 遍历全部数据集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # Compute average loss
            avg_cost += c / total_batch
        # 显示训练中的详细信息
        if (epoch+1) % display_step == 0:
            print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print( " Finished!")

    # 测试 model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # 计算准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Save model weights to disk
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)



##读取模型
#print("Starting 2nd session...")
#with tf.Session() as sess:
#    # Initialize variables
#    sess.run(tf.global_variables_initializer())
#    # Restore model weights from previously saved model
#    saver.restore(sess, model_path)
#    
#     # 测试 model
#    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
#    # 计算准确率
#    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
#    
#    output = tf.argmax(pred, 1)
#    batch_xs, batch_ys = mnist.train.next_batch(2)
#    outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})
#    print(outputval,predv,batch_ys)
#
#    im = batch_xs[0]
#    im = im.reshape(-1,28)
#    pylab.imshow(im)
#    pylab.show()
#    
#    im = batch_xs[1]
#    im = im.reshape(-1,28)
#    pylab.imshow(im)
#    pylab.show()



上文若有任何错误或不妥欢迎指出,谢谢!

发布了34 篇原创文章 · 获赞 9 · 访问量 3027

猜你喜欢

转载自blog.csdn.net/RObot_123/article/details/103218114