Tensorflow2.0学习笔记------堆叠模型

早有耳闻Tensorflow2.0相比1.x版本是重大飞跃,需要赶紧给自己充充电!

所有相关代码测试均在colab中进行,不会配置的同学可以参见https://blog.csdn.net/hesongzefairy/article/details/105411219

2.0版本集成了keras后,使用tf.keras来搭建网络,完全继承了keras的优势之处

Step1:导入tf.keras(这这个导入在Pycharm会显示黄色,其实不是报错,是Pycharm的bug)

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

Step2:使用堆叠模型tf.keras.Sequential()构建一个四层网络

model = tf.keras.Sequential()
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

Step3:设置训练流程

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
             loss=tf.keras.losses.categorical_crossentropy,
             metrics=[tf.keras.metrics.categorical_accuracy])

Step4:制作数据集

train_x = np.random.random((1000, 100))
train_y = np.random.random((1000, 10))

val_x = np.random.random((200, 100))
val_y = np.random.random((200, 10))

Step5:训练

model.fit(train_x, train_y, epochs=10, batch_size=100,
          validation_data=(val_x, val_y))

Step6:改进数据集的保存方式tf.data并重新训练

dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
dataset = dataset.batch(32)
dataset = dataset.repeat()

val_dataset = tf.data.Dataset.from_tensor_slices((val_x, val_y))
val_dataset = val_dataset.batch(32)
val_dataset = val_dataset.repeat()

model.fit(dataset, epochs=10, steps_per_epoch=30,
          validation_data=val_dataset, validation_steps=3)

Step7:模型评估与测试(未使用tf.data和使用tf.data分别测试)

test_x = np.random.random((2000, 100))
test_y = np.random.random((2000, 10))
model.evaluate(test_x, test_y, batch_size=32)

test_data = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_data = test_data.batch(32).repeat()
model.evaluate(test_data, steps=30)

# predict
result = model.predict(test_x, batch_size=32)
print(result)

参考资料:

https://zhuanlan.zhihu.com/p/58825020

https://colab.research.google.com/notebooks/intro.ipynb#scrollTo=-Rh3-Vt9Nev9

https://www.tensorflow.org/

发布了80 篇原创文章 · 获赞 184 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/hesongzefairy/article/details/105416119