Tensorflow学习二_mnist入门

前两天刚刚装好我的Tensorflow,于是今天通过tensorflow的中文网站(http://www.tensorfly.cn/tfdoc/get_started/introduction.html),

准备开始学习关于tensorflow的入门——mnist手写字母的识别入门。

在此主要记录一下我运行我的第一个代码时,出现的小错误。

一、简单示例

在简介中有一段使用 Python API 撰写的 TensorFlow 示例代码,

直接拿来运行:出现错误

  File "D:/tensorflow/python文件/tensorflow1.py", line 37
    print step, sess.run(W), sess.run(b)
             ^
SyntaxError: invalid syntax

后来通过网上搜索发现,是因为在官网上所用的代码是python2.x,而我使用的是python3,

1、改xrange为range

2、修改print格式

运行成功

import tensorflow as tf
import numpy as np

# 使用 NumPy 生成假数据(phony data), 总共 100 个点.
x_data = np.float32(np.random.rand(2, 100)) # 随机输入
y_data = np.dot([0.100, 0.200], x_data) + 0.300

# 构造一个线性模型
# 
b = tf.Variable(tf.zeros([1]))
W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
y = tf.matmul(W, x_data) + b

# 最小化方差
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# 初始化变量
init = tf.initialize_all_variables()

# 启动图 (graph)
sess = tf.Session()
sess.run(init)

# 拟合平面
for step in range(0, 201):
    sess.run(train)
    if step % 20 == 0:
        print (step, sess.run(W), sess.run(b))
View Code

二、mnist入门

在下载数据集的时候,官网上提供了两种方法:一是下载代码并导入到项目,二是直接用python源代码自动下载和安装。

在这里,我是直接用python源代码下载和安装。

#导入数据集
import input_data 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

#实现回归模型
import tensorflow as tf
x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)

#训练模型
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

#评估模型
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

 报错:

ImportError: No module named 'input_data'

将import input_data 代码换成  from tensorflow.examples.tutorials.mnist import input_data

运行成功:

运行结果的正确率是92%

猜你喜欢

转载自www.cnblogs.com/smile321/p/11205527.html
今日推荐