TensorFlow2学习六之构建神经网络

一、搭建神经网络模型

第一步 导入相关模块
第二步 准备测试集(x_train,y_train)和训练集(x_test,y_test)
第三步 搭建网络结构model = tf.keras.models.Sequential
第四步 配置训练方法model.compile
第五步 执行训练过程model.fit
第六步 model.summary打印出网络的结构和参数统计

tf.keras.models.Sequential

model = tf.keras.models.Sequential ([ 网络结构 ]) #描述各层网络

网络结构:

  1. 拉直层: tf.keras.layers.Flatten( )
  2. 全连接层: tf.keras.layers.Dense(神经元个数, activation= "激活函数“ ,kernel_regularizer=正则化)

activation可选: relu、 softmax、 sigmoid 、 tanh
kernel_regularizer可选: tf.keras.regularizers.l1()、tf.keras.regularizers.l2()

  1. 卷积层: tf.keras.layers.Conv2D(filters = 卷积核个数, kernel_size = 卷积核尺寸,strides = 卷积步长, padding = " valid" or “same”)
  2. LSTM层: tf.keras.layers.LSTM()

compile()配置神经网络的训练方法

model.compile(optimizer = 优化器, 
				loss = 损失函数, 
				metrics = [“准确率”] )

优化器(Optimizer)

‘sgd’ or tf.keras.optimizers.SGD (lr=学习率,momentum=动量参数)
‘adagrad’ or tf.keras.optimizers.Adagrad (lr=学习率)
‘adadelta’ or tf.keras.optimizers.Adadelta (lr=学习率)
‘adam’ or tf.keras.optimizers.Adam (lr=学习率, beta_1=0.9, beta_2=0.999)

损失函数(loss)

‘mse’ or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy’ or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

评测指标(Metrics)

‘accuracy’ :y_和y都是数值,如y_=[1] y=[1]
‘categorical_accuracy’ :y_和y都是独热码(概率分布),如y_=[0,1,0] y=[0.256,0.695,0.048]
‘sparse_categorical_accuracy’ :y_是数值,y是独热码(概率分布),如y_=[1] y=[0.256,0.695,0.048]

fit()执行训练过程

model.fit (训练集的输入特征, 训练集的标签, 
			batch_size= , epochs= , 
			validation_data=(测试集的输入特征,测试集的标签),
			validation_split=从训练集划分多少比例给测试集,
			validation_freq = 多少次epoch测试一次)

summary()打印出网络的结构和参数统计

在这里插入图片描述

复现鸢尾花分类

import tensorflow as tf
from sklearn import datasets
import numpy as np

x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

# 实现数据集的乱序
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

# 搭建网络结构
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(3,                                               # 神经元个数
                          activation='softmax',                            # 选用softmax激活函数
                          kernel_regularizer=tf.keras.regularizers.l2()    # 正则化方法
                          )
])
# 配置训练方法
# 选择SparseCategoricalCrossentropy损失函数,由于激活函数为softmax函数,使得输出为概率分布函数而不是原始输出,所以from_logits是False
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),   # 选择SGD优化器,学习率为0.1
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])  # 评测指标
# 执行训练过程。validation_split告知选择20%的数据作为测试集,validation_freq告知每迭代20次训练集要在测试集中验证一次准确率。
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

model.summary()

结果:
在这里插入图片描述

二、class搭建神经网络模型

Sequential可以搭建出上层输出就是下层输入的顺序网络结构,无法搭建非顺序网络结构。拥有跳连的非顺序网络结构,可以用类class封装神经网络结构。

__init__( ) 定义所需网络结构块
call( ) 调用__init__()中搭建好的积木,写出前向传播。
class MyModel(Model):
	def __init__(self):
		super(MyModel, self).__init__()
		定义网络结构块
	def call(self, x):
		调用网络结构块,实现前向传播
		return y
model = MyModel()

类模块搭建鸢尾花分类

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as np

x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

# 在__init__()函数中定义了要在call函数中调用的具有三个神经元的全连接网络Dense
# call()函数中调用了的d1实现从输入x到输出y的前向传播
class IrisModel(Model):
    def __init__(self):
        super(IrisModel, self).__init__()
        self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, x):
        y = self.d1(x)
        return y

model = IrisModel()

model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

model.summary()

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_41754907/article/details/113031849