解决tf升级warning:Please use alternatives such as official/mnist/dataset.py from tensorflow/models

warning描述

WARNING:tensorflow:From /home/xiaoshumiao/PycharmProjects/tensorflow/main/classification.py:23: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.

对应代码

导入数据集

mnist = input_data.read_data_sets('../MNIST_data',one_hot=True)

生成训练集

#  batch_xs,batch_ys = mnist.train.next_batch(100)

得到的数据是100*784格式

生成测试集

mnist.test.images
mnist.test.labels

新的导入方式

使用新的导入数据方式,通过keras导入

(train_x_image, train_y), (test_x_image, test_y) = keras.datasets.mnist.load_data(path='/home/xiaoshumiao/.keras/datasets/mnist.npz')

这里我之前下载过,所以直接调用下载的文件就可以,如果没有下载不用指定,就会下载到~/.kreas中。第二次就可以制定path了。

train:图片–>一维向量&数据标准化

但是通过keras导入的训练数据也就是mnist数据集是图片的形式,并不是之前训练网络用的一维784的向量。当然后面会改进,请关注我的博客,后续更新。

下面的代码就是把二维图片2828变成1784的方法,两种都可以:

train_x = train_x_image.reshape(train_x_image.shape[0], -1) / 255.
x_train=train_x_image.reshape(60000,784).astype('float32')/255.

此处对于数据还进行了标准化,因为标准化可以提高准确率。

label:数字标号–>onehot编码

而对于y标签,通过kreas导入的数据也不是onehot形式,而是对应的类的标号,比如“1”。因此需要将其转变为onthot编码。

train_y = keras.utils.to_categorical(train_y, num_classes=10)

此处的num_classes=10可以不加。

扫描二维码关注公众号,回复: 10045945 查看本文章
发布了21 篇原创文章 · 获赞 5 · 访问量 370

猜你喜欢

转载自blog.csdn.net/def_init_myself/article/details/104860899