TensorFlow Tutorials (3): Building a Logistic Regression Model with TensorFlow

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


def init_weights(shape):
    # small random initial weights (stddev 0.01) keep the initial logits near zero
    return tf.Variable(tf.random_normal(shape, stddev=0.01))


def model(X, W):
    # same form as linear regression: return raw logits, since the cost function
    # below applies softmax and cross entropy internally
    return tf.matmul(X, W)
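

# For reference: softmax, which the cost op below applies internally, maps raw
# logits to class probabilities. A minimal numpy sketch (illustration only; the
# training code never calls this helper):
def softmax_np(logits):
    e = np.exp(logits - np.max(logits, axis=-1, keepdims=True))  # shift by max for numerical stability
    return e / e.sum(axis=-1, keepdims=True)                     # normalize exponentials to probabilities
# e.g. softmax_np(np.array([2.0, 1.0, 0.1])) is approx [0.659, 0.242, 0.099]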


mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
train_X, train_Y, test_X, test_Y = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
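
# Each row of train_X is a flattened 28x28 = 784 pixel vector and each row of
# train_Y a one-hot 10-vector; with the standard split, train_X has 55000 rows
# and test_X 10000. A cheap sanity check:
assert train_X.shape[1] == 784 and train_Y.shape[1] == 10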

X = tf.placeholder("float", [None, 784]) # create symbolic variables
Y = tf.placeholder("float", [None, 10])

W = init_weights([784, 10]) # like in linear regression, we need a shared variable weight matrix for logistic regression

py_x = model(X, W)

# define the cost function: mean cross entropy (softmax is applied internally)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))  # keyword args required in TF >= 1.0
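# For reference, the fused op above is a numerically stable shorthand for the
# explicit two-step version; a sketch of the equivalent (but less stable) form,
# not used here:
#   probs = tf.nn.softmax(py_x)
#   cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(probs), axis=1))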
# construct optimizer
train_op = tf.train.GradientDescentOptimizer(0.05).minimize(cost) 
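# minimize() above bundles the gradient computation and the plain SGD update
# W <- W - 0.05 * dcost/dW; written out by hand it would look roughly like this
# sketch (not used here):
#   grad = tf.gradients(cost, [W])[0]
#   train_op = tf.assign(W, W - 0.05 * grad)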

# Launch the graph in a session
with tf.Session() as sess:
    # you need to initialize all variables
    tf.global_variables_initializer().run()  # initialize_all_variables() is deprecated

    for i in range(100):
        # slide over the training set in minibatches of 128 via (start, end)
        # index pairs; a trailing partial batch smaller than 128 is dropped
        for start, end in zip(range(0, len(train_X), 128), range(128, len(train_X)+1, 128)):
            sess.run(train_op, feed_dict={X: train_X[start:end], Y: train_Y[start:end]})
        # report test-set accuracy after each pass over the training data
        print(i, np.mean(np.argmax(test_Y, axis=1) ==
                         sess.run(tf.argmax(py_x, 1), feed_dict={X: test_X})))
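
    # Once training finishes, per-image predictions come straight from the
    # logits; a sketch that would run here, inside the same session:
    #   preds = sess.run(tf.argmax(py_x, 1), feed_dict={X: test_X[:5]})
    #   print("predicted:", preds, "actual:", np.argmax(test_Y[:5], axis=1))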
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
(0, 0.8841)
(1, 0.89680000000000004)
(2, 0.90310000000000001)
(3, 0.90739999999999998)
(4, 0.90939999999999999)
(5, 0.91090000000000004)
(6, 0.91210000000000002)
(7, 0.91310000000000002)
(8, 0.91490000000000005)
(9, 0.91569999999999996)
(10, 0.91590000000000005)
(11, 0.91700000000000004)
(12, 0.91720000000000002)
(13, 0.91739999999999999)
(14, 0.91769999999999996)
(15, 0.91800000000000004)
(16, 0.91849999999999998)
(17, 0.91910000000000003)
(18, 0.91959999999999997)
(19, 0.91990000000000005)
(20, 0.91979999999999995)
(21, 0.91990000000000005)
(22, 0.92030000000000001)
(23, 0.92030000000000001)
(24, 0.9204)
(25, 0.92110000000000003)
(26, 0.92090000000000005)
(27, 0.92120000000000002)
(28, 0.92130000000000001)
(29, 0.92159999999999997)
(30, 0.92179999999999995)
(31, 0.92200000000000004)
(32, 0.92179999999999995)
(33, 0.92159999999999997)
(34, 0.92149999999999999)
(35, 0.92159999999999997)
(36, 0.92149999999999999)
(37, 0.9214)
(38, 0.92200000000000004)
(39, 0.92200000000000004)
(40, 0.92220000000000002)
(41, 0.92200000000000004)
(42, 0.92190000000000005)
(43, 0.92200000000000004)
(44, 0.92190000000000005)
(45, 0.92179999999999995)
(46, 0.92200000000000004)
(47, 0.92200000000000004)
(48, 0.92220000000000002)
(49, 0.92220000000000002)
(50, 0.92200000000000004)
(51, 0.92220000000000002)
(52, 0.92230000000000001)
(53, 0.92220000000000002)
(54, 0.92220000000000002)
(55, 0.9224)
(56, 0.92249999999999999)
(57, 0.92279999999999995)
(58, 0.92290000000000005)
(59, 0.92300000000000004)
(60, 0.92310000000000003)
(61, 0.9234)
(62, 0.9234)
(63, 0.9234)
(64, 0.92369999999999997)
(65, 0.92359999999999998)
(66, 0.92369999999999997)
(67, 0.92369999999999997)
(68, 0.92379999999999995)
(69, 0.92359999999999998)
(70, 0.92359999999999998)
(71, 0.9234)
(72, 0.92349999999999999)
(73, 0.9234)
(74, 0.92369999999999997)
(75, 0.92369999999999997)
(76, 0.92369999999999997)
(77, 0.92359999999999998)
(78, 0.92369999999999997)
(79, 0.92369999999999997)
(80, 0.92369999999999997)
(81, 0.92359999999999998)
(82, 0.92390000000000005)
(83, 0.92379999999999995)
(84, 0.92369999999999997)
(85, 0.92369999999999997)
(86, 0.92379999999999995)
(87, 0.92379999999999995)
(88, 0.92390000000000005)
(89, 0.92390000000000005)
(90, 0.92390000000000005)
(91, 0.92369999999999997)
(92, 0.92359999999999998)
(93, 0.92359999999999998)
(94, 0.92369999999999997)
(95, 0.92369999999999997)
(96, 0.92379999999999995)
(97, 0.92379999999999995)
(98, 0.92379999999999995)
(99, 0.92379999999999995)
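
Accuracy climbs quickly in the first few epochs and then plateaus around 92.4%, which is in line with what a single linear softmax layer typically reaches on MNIST.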

Copyright notice: This is the blogger's original article; it may not be reproduced without the blogger's permission. https://blog.csdn.net/u013719780/article/details/53784797
