DL4: A TensorFlow Implementation of Logistic Regression for Multi-class Classification

Copyright notice: this article is the author's original work. For questions about the library versions involved, contact the author at [email protected]. https://blog.csdn.net/weixin_42600072/article/details/89055761
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST; one_hot=True encodes each label as a length-10 indicator vector
mnist = input_data.read_data_sets('./data/MNIST/', one_hot=True)
trainimg   = mnist.train.images
trainlabel = mnist.train.labels
testimg    = mnist.test.images
testlabel  = mnist.test.labels
print(trainimg.shape)
print(trainlabel.shape)
print(testimg.shape)
print(testlabel.shape)
print(trainimg)
print(trainlabel[0])
(55000, 784)
(55000, 10)
(10000, 784)
(10000, 10)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
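
Note that tensorflow.examples.tutorials was removed in TensorFlow 2.x, so the loader above only works on TF 1.x. If you are on TF 2.x, a rough equivalent (my addition, not from the original post) is to load MNIST through tf.keras.datasets; this split has 60000 training images instead of the 55000 above, because the tutorial loader reserves 5000 for validation:

# TF2-era alternative loader (sketch; not part of the original post)
(x_tr, y_tr), (x_te, y_te) = tf.keras.datasets.mnist.load_data()
x_tr = x_tr.reshape(-1, 784).astype('float32') / 255.0  # flatten and scale to [0, 1]
x_te = x_te.reshape(-1, 784).astype('float32') / 255.0
y_tr = np.eye(10, dtype='float32')[y_tr]  # one-hot encode, matching one_hot=True above
y_te = np.eye(10, dtype='float32')[y_te]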

The logistic regression framework

Parameter initialization

x = tf.placeholder('float', [None, 784])  # None: the batch dimension is left unspecified
y = tf.placeholder('float', [None, 10])
W = tf.Variable(tf.zeros([784, 10]))  # zero init is fine here: the cost below is convex in W and b
b = tf.Variable(tf.zeros([10]))

Logistic regression model: softmax

actv = tf.nn.softmax(tf.matmul(x, W) + b)
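
Softmax turns each row of scores (logits) into a probability distribution over the 10 classes. A minimal NumPy sketch with made-up scores:

scores = np.array([2.0, 1.0, 0.1])  # hypothetical logits for 3 classes
e = np.exp(scores - scores.max())   # subtract the max for numerical stability
print(e / e.sum())                  # ~[0.659 0.242 0.099], sums to 1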

Cost function

cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(actv), axis=1))  # cross-entropy, averaged over the batch
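
Taking tf.log of a softmax output can produce NaNs if a probability underflows to 0. A more stable variant (a suggested alternative, not what the original post uses) feeds the raw logits to TensorFlow's fused op:

# Numerically stabler alternative (sketch): work on logits, not probabilities
logits = tf.matmul(x, W) + b
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))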

Optimizer

learning_rate = 0.01
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
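
GradientDescentOptimizer just applies the plain update W := W - learning_rate * d(cost)/dW at each step. Spelled out by hand with TF1 ops (purely illustrative; you would not write this in practice):

# Manual equivalent of one gradient-descent step (sketch)
grad_W, grad_b = tf.gradients(cost, [W, b])
manual_step = [tf.assign(W, W - learning_rate * grad_W),
               tf.assign(b, b - learning_rate * grad_b)]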

Prediction

pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))  # per-example correctness, as booleans
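
tf.equal compares, per example, the index of the largest predicted probability with the index of the 1 in the one-hot label. A tiny NumPy analogue with made-up values:

probs  = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1]])        # hypothetical model outputs
labels = np.array([[0., 1., 0.],
                   [0., 0., 1.]])           # one-hot ground truth
print(probs.argmax(1) == labels.argmax(1))  # [ True False]

Casting those booleans to float and averaging, which is exactly what the next line does, would give an accuracy of 0.5 for this toy batch.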

Accuracy

accr = tf.reduce_mean(tf.cast(pred, 'float'))  # fraction of correct predictions

Initialization

init = tf.global_variables_initializer()

A note on tf.argmax

sess = tf.InteractiveSession()
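
tf.InteractiveSession installs itself as the default session, which is what lets the .eval() calls below run without naming a session. With a plain tf.Session, the equivalent would be (a toy sketch):

with tf.Session() as s:
    print(s.run(tf.argmax(tf.constant([3, 1, 2]), 0)))  # prints 0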

arr = np.array([[31, 23,  4, 24, 27, 34],
                [18,  3, 25,  0,  6, 35],
                [28, 14, 33, 22, 20,  8],
                [13, 30, 21, 19,  7,  9],
                [16,  1, 26, 32,  2, 29],
                [17, 12,  5, 11, 10, 15]])
#tf.rank(arr).eval()   # rank (number of dimensions): 2
#tf.shape(arr).eval()  # shape: [6 6]
tf.argmax(arr, 0).eval()  # axis 0: index of the max in each column
array([0, 3, 2, 4, 0, 1], dtype=int64)
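
The same call with axis 1 instead returns the row-wise argmax, which is exactly how pred above compares predictions with labels:

tf.argmax(arr, 1).eval()  # axis 1: index of the max in each row
# expected: array([5, 5, 2, 1, 3, 0], dtype=int64)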

The training loop

training_epochs = 200   # full passes over the training data
batch_size      = 100   # examples per gradient step
display_step    = 10    # print metrics every 10 epochs

sess = tf.Session()
sess.run(init)

# MINI-BATCH LEARNING
for epoch in range(training_epochs):
    avg_cost = 0.
    num_batch = int(mnist.train.num_examples / batch_size)
    for i in range(num_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        feeds = {x: batch_xs, y: batch_ys}
        sess.run(optm, feed_dict=feeds)  # one gradient step
        avg_cost += sess.run(cost, feed_dict=feeds) / num_batch
    if epoch % display_step == 0:
        feeds_train = {x: batch_xs, y: batch_ys}  # accuracy on the last mini-batch only
        feeds_test  = {x: mnist.test.images, y: mnist.test.labels}
        train_acc = sess.run(accr, feed_dict=feeds_train)
        test_acc  = sess.run(accr, feed_dict=feeds_test)
        print("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f"
              % (epoch, training_epochs, avg_cost, train_acc, test_acc))
print("DONE")
Epoch: 000/200 cost: 1.176300201 train_acc: 0.890 test_acc: 0.855
Epoch: 010/200 cost: 0.383338681 train_acc: 0.920 test_acc: 0.905
Epoch: 020/200 cost: 0.341450560 train_acc: 0.930 test_acc: 0.912
Epoch: 030/200 cost: 0.322305107 train_acc: 0.950 test_acc: 0.915
Epoch: 040/200 cost: 0.310741583 train_acc: 0.870 test_acc: 0.918
Epoch: 050/200 cost: 0.302640686 train_acc: 0.890 test_acc: 0.919
Epoch: 060/200 cost: 0.296618789 train_acc: 0.940 test_acc: 0.919
Epoch: 070/200 cost: 0.291851509 train_acc: 0.860 test_acc: 0.920
Epoch: 080/200 cost: 0.287941018 train_acc: 0.930 test_acc: 0.921
Epoch: 090/200 cost: 0.284705462 train_acc: 0.930 test_acc: 0.921
Epoch: 100/200 cost: 0.281881044 train_acc: 0.920 test_acc: 0.922
Epoch: 110/200 cost: 0.279476521 train_acc: 0.950 test_acc: 0.922
Epoch: 120/200 cost: 0.277341484 train_acc: 0.900 test_acc: 0.922
Epoch: 130/200 cost: 0.275421649 train_acc: 0.920 test_acc: 0.923
Epoch: 140/200 cost: 0.273702574 train_acc: 0.910 test_acc: 0.922
Epoch: 150/200 cost: 0.272169173 train_acc: 0.900 test_acc: 0.923
Epoch: 160/200 cost: 0.270728730 train_acc: 0.920 test_acc: 0.923
Epoch: 170/200 cost: 0.269369089 train_acc: 0.880 test_acc: 0.923
Epoch: 180/200 cost: 0.268188662 train_acc: 0.930 test_acc: 0.923
Epoch: 190/200 cost: 0.267080625 train_acc: 0.930 test_acc: 0.923
DONE
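
Note that train_acc fluctuates because it is measured on a single mini-batch, while test_acc, measured on all 10000 test images, climbs steadily to about 0.92. Finally, matplotlib is imported at the top but never used; a small sketch (my addition, run in the same session before closing it) displays a test digit next to the trained model's prediction:

idx = 0  # arbitrary test image
probs = sess.run(actv, feed_dict={x: mnist.test.images[idx:idx + 1]})
plt.imshow(mnist.test.images[idx].reshape(28, 28), cmap='gray')
plt.title('predicted: %d, true: %d'
          % (probs.argmax(), mnist.test.labels[idx].argmax()))
plt.show()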
