【TensorFlow Framework Learning】A Simple Logistic Regression Implementation in TensorFlow

Copyright notice: please credit the source when reposting. Thanks. https://blog.csdn.net/u010591976/article/details/82215994

Introduction to Softmax Regression

We know that every MNIST image shows a digit from 0 to 9, and we want the model to output, for a given image, the probability that it represents each digit. For example, for an image containing a 9 the model might assign an 80% probability to the digit 9 and a 5% probability to the digit 8 (8 and 9 share a similar upper loop), with even smaller probabilities for the remaining digits.
Recognizing handwritten digits with a linear softmax layer is a classic application of the softmax regression model. Softmax assigns probabilities to the different classes, and even when we later train more sophisticated models, the final step will still use softmax to produce class probabilities.


Review:
TensorFlow functions used:

  • tf.nn.softmax() builds the softmax model (see the sketch after this list)
  • tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    solves with the gradient-descent optimizer
  • argmax() returns the index of the largest value

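Softmax turns a vector of scores z into a probability distribution: softmax(z)_i = exp(z_i) / Σ_j exp(z_j). As a quick illustration of what tf.nn.softmax computes for each row, here is a minimal NumPy sketch (not part of the script below; the helper name softmax is ours):

import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))  # subtract the max for numerical stability
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])
print(softmax(logits))         # ≈ [0.659, 0.242, 0.099]: larger score, larger probability
print(softmax(logits).sum())   # 1.0: a valid probability distribution
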
# Logistic regression with TensorFlow: a linear softmax regression model for handwritten-digit recognition

import tensorflow as tf
import numpy as np
import matplotlib.pylab as plt
from tensorflow.examples.tutorials.mnist import input_data

#MNIST data input
'''
In the MNIST training set loaded here, mnist.train.images is a tensor of shape
[55000, 784]: input_data holds out 5,000 of the 60,000 original training images
for validation. The first dimension indexes images and the second indexes the
pixels of each image. Each element is the intensity of one pixel in one image,
a value between 0 and 1.
'''
mnist = input_data.read_data_sets('data/', one_hot=True)
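# one_hot=True: a label such as 3 is encoded as [0,0,0,1,0,0,0,0,0,0]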

#placeholder is a symbolic input; None means the first dimension (the batch size) can have any length
x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
#W has shape [784,10] and is initialized to zeros: one weight per pixel (784) per class (10)
#b has shape [10] and is initialized to zeros
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#Logistic regression model
#plain logistic regression is binary classification; softmax generalizes it to multi-class
actv = tf.nn.softmax(tf.matmul(x,W)+ b) #predictions: 10 class probabilities per sample
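# shape check: x [None,784] x W [784,10] -> [None,10]; b [10] broadcasts across rows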

#Cost function: cross-entropy -log(P), where P is the probability the model assigns to the true class
#(note: this is cross-entropy, not mean squared error; reduction_indices is the deprecated alias of axis)
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
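# e.g. if the true class gets probability 0.8, that sample contributes -log(0.8) ≈ 0.22;
# a confidently wrong prediction with P = 0.01 would contribute -log(0.01) ≈ 4.6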

#optimizer learning rate
learning_rate = 0.01
#gradient-descent optimizer: training is simply minimizing the cost function
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
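# each sess.run(optm, ...) below takes one gradient step: W <- W - learning_rate * dcost/dW (and likewise for b)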

#Prediction: argmax() returns the index of the largest value
#tf.argmax(actv,1) is the index of the largest prediction in each row (one row per sample); tf.argmax(y,1) is the index of the true label
#(with axis 0 instead of 1, argmax would run down each column rather than across each row)
#pred compares the predicted index with the label index, giving True or False per sample
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
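# e.g. a row of actv like [0.05, ..., 0.80] has argmax 9; one-hot label [0, ..., 1] has argmax 9 -> pred True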
#accuracy: tf.cast maps True to 1.0 and False to 0.0; their mean is the fraction of correct predictions
accr = tf.reduce_mean(tf.cast(pred, "float"))
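# e.g. pred = [True, False, True, True] -> cast -> [1., 0., 1., 1.] -> mean = 0.75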

#Initialize variables
init = tf.global_variables_initializer()

#number of passes over the entire training set
train_epochs = 50
#number of samples per mini-batch
batch_size = 100
display_step = 5

#Session
sess = tf.Session()
sess.run(init)

#mini-batch learning
#50 passes over all training samples
for epoch in range(train_epochs):
    avg_cost = 0 #running mean of the loss for this epoch
    num_batch = int(mnist.train.num_examples/batch_size) #number of mini-batches per epoch (integer division)
    #one epoch: num_batch mini-batches of batch_size samples each
    for i in range(num_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        feeds = {x: batch_xs, y: batch_ys}
        # one gradient-descent step on this mini-batch
        sess.run(optm, feed_dict=feeds)
        # accumulate the mean loss over this epoch's mini-batches
        avg_cost += sess.run(cost, feed_dict=feeds)/num_batch
    #display: print progress every display_step epochs
    if epoch % display_step == 0:
        feeds_train = {x: batch_xs, y: batch_ys} #train accuracy is measured on the last mini-batch only
        feeds_test = {x: mnist.test.images, y: mnist.test.labels}
        train_acc = sess.run(accr, feed_dict=feeds_train)
        test_acc = sess.run(accr, feed_dict=feeds_test)
        print("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f" % (epoch, train_epochs, avg_cost, train_acc, test_acc))

print('Done')

output:

F:\Anaconda\python.exe D:/PycharmProjects/tensorflow逻辑回归实现.py
WARNING:tensorflow:From D:/PycharmProjects/tensorflow逻辑回归实现.py:15: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From F:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
Extracting data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From F:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From F:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From F:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From F:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
2018-08-30 15:04:53.025707: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
Epoch: 000/050 cost: 1.177109451 train_acc: 0.870 test_acc: 0.852
Epoch: 005/050 cost: 0.440929337 train_acc: 0.850 test_acc: 0.895
Epoch: 010/050 cost: 0.383323831 train_acc: 0.880 test_acc: 0.905
Epoch: 015/050 cost: 0.357286127 train_acc: 0.860 test_acc: 0.909
Epoch: 020/050 cost: 0.341520180 train_acc: 0.930 test_acc: 0.913
Epoch: 025/050 cost: 0.330535250 train_acc: 0.900 test_acc: 0.914
Epoch: 030/050 cost: 0.322328042 train_acc: 0.880 test_acc: 0.915
Epoch: 035/050 cost: 0.315957263 train_acc: 0.880 test_acc: 0.917
Epoch: 040/050 cost: 0.310728670 train_acc: 0.870 test_acc: 0.918
Epoch: 045/050 cost: 0.306382290 train_acc: 0.970 test_acc: 0.919
Done

Process finished with exit code 0
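
To inspect an individual prediction once training has finished, the trained graph can be reused directly. A minimal sketch, assuming it is appended to the script above while sess is still open; the names img and probs are ours:

img = mnist.test.images[0:1]                 # one test image, shape [1, 784]
probs = sess.run(actv, feed_dict={x: img})   # predicted class distribution, shape [1, 10]
print("predicted digit:", probs.argmax())    # index of the most probable class
print("confidence: %.3f" % probs.max())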

Reference blogs:
http://tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
https://blog.csdn.net/SA14023053/article/details/51884894
