从mnist数据集中导入数据,并存在./data/中。
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./data/',one_hot=True)
打印出训练集,检验集,测试集的个数。
print("train data size: " + str(mnist.train.num_examples))
print("validation data size: "+str(mnist.validation.num_examples))
print("test data size: "+str(mnist.test.num_examples))
输出:
train data size: 55000
validation data size: 5000
test data size: 10000
打印出第一个训练集的特征(28*28维)和标签。
print(mnist.train.labels[0])
print(mnist.train.images[0])
标签:
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
本来特征不想放上了,毕竟784个数字,放上来也没什么意义。但是想了想,我转化一下,变成28*28的格子,大家看一下:
for i in range(28):
for j in range(28):
if mnist.train.images[0][j+28*i] > 0.0:
print(" 1",end=" ")
else:
print(" 0",end=" "),
print("")
我觉得是个3,但是标签说这是个7,好吧,大概是给7的中间划了一道。
一个BATCH一个BATCH地读取:
BATCH_SIZE = 200
xs,ys = mnist.train.next_batch(BATCH_SIZE)
print("xs shape: " + str(xs.shape))
print("ys shape: " + str(ys.shape))
xs shape: (200, 784)
ys shape: (200, 10)
一些以后会用到的函数:
tf.get_collection() #从集合中去全部变量,生成一个列表
tf.add_n([]) #列表内对应元素相加
tf.cast(x,dtype) #把x转为dtype类型
tf.argmax(x,axis) #返回最大值所在索引号,如tf.argmax([1,0,0],1)返回0.注意后面的1表示在第一维。
os.path.join("home","name") #返回home/name
#字符串.split() #按指定拆分符对字符串切片,
#返回分割后的列表,如
'./model/mnist_model-1001'.split('/')[-1]
#返回1001
with tf.Fraph().as_default() as g: #其内定义的节点在计算图g中,一般用于复现已经定义好的神经网络。
使用onehot=true,表示数组中只有一个元素的值是1.0,其他元素的值是0.0。
扫描二维码关注公众号,回复:
9959841 查看本文章