Tensorflow搭建第一个分类学习神经网络(Classification)

版权声明:站在巨人的肩膀上学习。 https://blog.csdn.net/zgcr654321/article/details/82958075

我们用一个简单的实例来说明分类学习。

我们使用mnist数据集,用下面的分类学习神经网络来学习这个数据库,并计算对手写体数字图片识别的准确率。

首先我们要下载一个数据库。

我们使用MNIST库,这是一个手写体数字库,总共有 60000 张图片,其中 50000 张训练图片,10000 张测试图片。

输出层的激励函数我们使用softmax激励函数,这是一个常用在分类问题上的激励函数。

loss函数我们使用交叉熵函数(cross_entropy)。

训练函数仍使用梯度下降法,交叉熵函数作为训练指引方向。

由于训练集中的图片很多,我们只取100张图片来训练。每训练20次,我们打印一次准确度。准确度即我们搭建的模型的预测值和真实值是否相同的概率,用百分比表示。

代码如下:

import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# 准备数据库(MNIST库,这是一个手写体数据库)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


# 这一步是如果电脑上没有这个数据库则会下载这个数据库到.py文件所在目录并创建一个MNIST_data文件夹
# 注意运行时这个数据库可能要翻墙才能下载下来

def add_layer(inputs, in_size, out_size, activation_function=None):
	Weights = tf.Variable(tf.random_normal([in_size, out_size]))
	biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
	Wx_plus_b = tf.matmul(inputs, Weights) + biases
	if activation_function is None:
		outputs = Wx_plus_b
	# activation_function is None时没有激励函数,是线性关系
	else:
		outputs = activation_function(Wx_plus_b)
	# activation_function不为None时,得到的Wx_plus_b再传入activation_function再处理一下
	return outputs


def compute_accuracy(v_xs, v_ys):
	global prediction
	# 使用global则对全局变量prediction进行操作
	y_pre = sess.run(prediction, feed_dict={xs: v_xs})
	# 使用xs输入数据生成预测值prediction
	correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1))
	# 对于预测值和真实值的差别
	accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
	# 计算预测的准确率
	result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})
	# run得到结果,结果是个百分比
	return result


xs = tf.placeholder(tf.float32, [None, 784])
# None表示不规定样本的数量,784表示每个样本的大小为28X28=784个像素点
ys = tf.placeholder(tf.float32, [None, 10])
# 每张图片表示一个数字,我们的输出是数字0到9,所以是10个输出

prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax)
# 调用add_layer定义输出层,输入数据是784个特征,输出数据是10个特征,激励采用softmax函数
# softmax激励函数一般用于classification

# 搭建分类模型时,loss函数(即最优化目标函数)选用交叉熵函数(cross_entropy)
# 交叉熵用来衡量预测值和真实值的相似程度,如果完全相同,它们的交叉熵等于零。
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction), reduction_indices=[1]))
# 定义训练函数,使用梯度下降法训练,0.5是学习效率,通常小于1,minimize(cross_entropy)指要将cross_entropy减小
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 创建会话,并开始将网络初始化
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(1000):
	batch_xs, batch_ys = mnist.train.next_batch(100)
	# 训练时只从数据集中取100张图片来训练
	sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})
	if i % 50 == 0:
		# 每训练50次打印准确度
		print(compute_accuracy(mnist.test.images, mnist.test.labels))
# 对比mnist中的training data和testing data的准确度

运行结果如下:

WARNING:tensorflow:From C:/Users/1234/Desktop/test/src/test.py:9: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
0.1657
0.4756
0.6107
0.6827
0.7226
0.7455
0.7492
0.7783
0.7833
0.8
0.8049
0.8107
0.8198
0.8172
0.8242
0.8309
0.8324
0.8384
0.8416
0.8422
0.8428
0.846
0.8481
0.8528
0.8513
0.8523
0.8539
0.8557
0.8597
0.8627
0.8585
0.861
0.8628
0.8667
0.8654
0.8664
0.8687
0.8659
0.8686
0.8686
0.8714
0.8701
0.8727
0.8714
0.8753
0.8742
0.8713
0.8777
0.874
0.8754

Process finished with exit code 0

前面的warning提示一般不影响结果,可忽略。不过有时提示会告诉你有个别包无法下载,这时请更换新的梯子再重新运行。

可以看到我们训练的准确度是在不断提高的。

猜你喜欢

转载自blog.csdn.net/zgcr654321/article/details/82958075