TensorFlow 2.0 custom training


Preface

The tf.keras API in TensorFlow 2.0 is well designed: after defining a model, you can call model.fit directly to train it. However, some situations call for a custom training loop, for example when the data needs special processing during training, when the model has multiple inputs and outputs, or when you need a custom loss or accuracy calculation.

1. Automatic differentiation with tf.GradientTape

import tensorflow as tf

w=tf.Variable([[1.0]])
with tf.GradientTape() as t:
    loss=w*w
grad=t.gradient(loss,w)  # derivative of loss with respect to w
grad

Running result: the gradient is 2 (a 1x1 tensor, the same shape as w).
loss is equal to the square of w, so the derivative of loss with respect to w is 2w. With w=1, the derivative is 2.
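
The value 2 can be cross-checked without TensorFlow. The sketch below is only a sanity check in plain Python: it estimates the same derivative with a central finite difference. The helper name numeric_derivative and the step size eps are assumptions made for this illustration and are not part of any TensorFlow API.

```python
def numeric_derivative(f, x, eps=1e-6):
    """Estimate f'(x) with a central finite difference (illustrative helper)."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


def loss(w):
    # Same function as in the tape example: loss = w * w
    return w * w


estimate = numeric_derivative(loss, 1.0)
print(estimate)  # a value very close to 2, matching 2w at w=1
```

tf.GradientTape does not work this way internally (it records operations and differentiates them exactly), but a numerical estimate like this is a handy way to confirm that a gradient is plausible.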

You can also differentiate with respect to a constant:

w=tf.constant(3.0)
with tf.GradientTape() as t:
    t.watch(w)
    loss=w*w
t.gradient(loss,w)

Running result: 6 (2w with w=3).
Unlike a tf.Variable, a constant is not tracked by the tape automatically. The watch call in the example tells the tape to track w so that a gradient can be computed for it.

By default, a GradientTape releases its resources as soon as gradient is called, so gradient cannot be called on it a second time. To compute several gradients from the same tape, set the persistent parameter to True.

w=tf.constant(3.0)
with tf.GradientTape(persistent=True) as t:
    t.watch(w)
    y=w*w
    z=y*y
t.gradient(y,w),t.gradient(z,w)

Running result: 6 and 108.
The derivative of y with respect to w is 2w, and the derivative of z with respect to w is 4w^3 (since z=y*y=w^4). Substituting w=3 gives 6 and 108.
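
To make the idea of "recording operations and replaying them backwards" more concrete, here is a minimal reverse-mode differentiation sketch in plain Python. It is a toy under stated assumptions, not how TensorFlow is implemented: the Value class and the gradient function are hypothetical names invented for this example, and only multiplication is supported.

```python
class Value:
    """A scalar that remembers how it was computed (hypothetical helper)."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        # Each entry is (parent_value, local_derivative_wrt_that_parent).
        self.parents = parents

    def __mul__(self, other):
        # d(a*b)/da = b and d(a*b)/db = a
        return Value(self.data * other.data,
                     ((self, other.data), (other, self.data)))


def gradient(target, source):
    """Return d(target)/d(source) by walking the recorded graph backwards."""
    order, seen = [], set()

    def visit(node):
        # Depth-first walk so that parents are listed before their consumers.
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(target)
    for node in order:
        node.grad = 0.0          # clear results of any earlier call
    target.grad = 1.0
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += local * node.grad   # chain rule
    return source.grad


w = Value(3.0)
y = w * w
z = y * y
dy_dw = gradient(y, w)
dz_dw = gradient(z, w)
print(dy_dw, dz_dw)  # 6.0 108.0
```

Because this toy keeps the recorded graph around, gradient can be called more than once, which is roughly the behaviour that persistent=True asks of a real tape.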

2. tf.keras.metrics module

1. tf.keras.metrics.Mean

Computes a running mean: each call returns the mean of all values passed in so far.
Example:

m=tf.keras.metrics.Mean('acc')
m(10)

The current mean is 10.

m(20)

The current mean is 15, that is, (10+20)/2.

m([30,40])

You can also pass a list directly.
The current mean is 25, that is, (10+20+30+40)/4.

When the code runs as a script, the value returned by each call is not displayed automatically (the results above come from an interactive session). To read the accumulated mean at any point, call m.result().

m.result()

The result is 25.
To clear the accumulated values, use the reset_states method (later TensorFlow releases name this method reset_state).

m.reset_states()
m.result()

The mean is reset to 0.
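
The behaviour shown above amounts to keeping a running total and a running count. The following plain-Python sketch mimics it for illustration; RunningMean and its method names are hypothetical and are not the Keras implementation (which, among other things, also supports sample weights).

```python
class RunningMean:
    """Hypothetical stand-in that accumulates a total and a count."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        # Accept either a single number or a list of numbers.
        if not isinstance(values, (list, tuple)):
            values = [values]
        self.total += sum(values)
        self.count += len(values)
        return self.result()

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset(self):
        self.total = 0.0
        self.count = 0


m_sketch = RunningMean()
print(m_sketch.update(10))        # 10.0
print(m_sketch.update(20))        # 15.0
print(m_sketch.update([30, 40]))  # 25.0
m_sketch.reset()
print(m_sketch.result())          # 0.0
```

Note that the list [30, 40] contributes two values to the count, which is why the last mean is taken over four numbers rather than three.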

2. tf.keras.metrics.SparseCategoricalAccuracy

This is TensorFlow's built-in accuracy metric for integer class labels. Like tf.keras.metrics.Mean, it accumulates across calls: each call returns the accuracy over all samples seen so far. The difference is in the inputs: it takes the true labels and the predicted probabilities instead of plain numbers.

The code is as follows (example):

a=tf.keras.metrics.SparseCategoricalAccuracy('acc')
labels=[0,1,2,3]
pred=[[0.8,0.05,0.05,0.1],
     [0.05,0.8,0.05,0.1],
     [0.15,0.15,0.65,0.05],
     [0.05,0.65,0.2,0.1]]

Here there are four samples. The labels are one-dimensional, and the model output is a two-dimensional 4*4 matrix: the first 4 is the number of samples and the second 4 is the number of categories. For a four-class problem, the model's final softmax activation outputs four probabilities per sample, and the index of the largest probability is the predicted category.

To view the predicted category, call tf.argmax

tf.argmax(pred,axis=1)

The predicted categories are 0 1 2 1, while the true categories are 0 1 2 3.

Calculate accuracy

a(labels,pred)

The accuracy is 0.75.
Note: the metric does not report the accuracy of the current call alone. It reports the accuracy accumulated over every call since the last reset. On the first call the two are the same, which is why the result here is 0.75. If it is then called again with a batch of four that is entirely correct (accuracy 1), the output is (0.75+1)/2=0.875, that is, 7 correct out of 8.
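
The accumulation described in the note can be reproduced in a few lines of plain Python. This is only an illustrative sketch: RunningAccuracy is a hypothetical class, and the second batch of labels is made up so that every prediction is correct.

```python
class RunningAccuracy:
    """Hypothetical helper that mimics an accumulating accuracy metric."""

    def __init__(self):
        self.correct = 0
        self.count = 0

    def update(self, labels, probabilities):
        for label, row in zip(labels, probabilities):
            # Index of the largest probability, i.e. argmax over the row.
            predicted = max(range(len(row)), key=row.__getitem__)
            self.correct += int(predicted == label)
            self.count += 1
        return self.result()

    def result(self):
        return self.correct / self.count if self.count else 0.0


labels = [0, 1, 2, 3]
pred = [[0.8, 0.05, 0.05, 0.1],
        [0.05, 0.8, 0.05, 0.1],
        [0.15, 0.15, 0.65, 0.05],
        [0.05, 0.65, 0.2, 0.1]]

acc = RunningAccuracy()
first = acc.update(labels, pred)          # 3 of 4 correct
second = acc.update([0, 1, 2, 1], pred)   # made-up labels: 4 of 4 correct
print(first, second)  # 0.75 0.875
```

Because the sketch counts correct samples over all samples, batches of different sizes would be weighted by their size instead of being averaged one batch at a time.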


3. Custom training in practice: handwritten digit recognition

import tensorflow as tf
# Load the data; the first call downloads it, which may take a moment
(train_image,train_labels),(test_image,test_labels)=tf.keras.datasets.mnist.load_data()

# Divide by 255 to scale pixel values from 0-255 to 0-1, and set the dtype
train_image=tf.cast(train_image/255,tf.float32)
test_image=tf.cast(test_image/255,tf.float32)

# Cast the labels to int64
train_labels=tf.cast(train_labels,tf.int64)
test_labels=tf.cast(test_labels,tf.int64)

# Wrap the tensors as tf.data datasets
train_dataset=tf.data.Dataset.from_tensor_slices(
    (train_image,train_labels)
)
test_dataset=tf.data.Dataset.from_tensor_slices(
    (test_image,test_labels)
)

# Shuffle the training data and set the batch size
train_dataset=train_dataset.shuffle(60000).batch(64)
test_dataset=test_dataset.batch(64)

# Build the model
model=tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28)),
    tf.keras.layers.Dense(512,activation='relu'),
    tf.keras.layers.Dense(256,activation='relu'),
    tf.keras.layers.Dense(128,activation='relu'),
    tf.keras.layers.Dense(10,activation='softmax')
])
# Optimizer
optimizer=tf.keras.optimizers.Adam()
# Loss function: sparse categorical cross-entropy for integer class labels such as 0 1 2 3
# from_logits=False because the model already ends with softmax
loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

# Running mean loss and accuracy for the training and test sets
train_loss=tf.keras.metrics.Mean('train_loss')
train_acc=tf.keras.metrics.SparseCategoricalAccuracy('train_acc')

test_loss=tf.keras.metrics.Mean('test_loss')
test_acc=tf.keras.metrics.SparseCategoricalAccuracy('test_acc')

# One training step
def train_step(model,images,labels):
    with tf.GradientTape() as t:
        pred=model(images)
        loss_step=loss_fn(labels,pred)
    # Compute the gradients of the loss with respect to the model weights
    grads =t.gradient(loss_step,model.trainable_variables)
    # Apply the gradients to update the weights
    optimizer.apply_gradients(zip(grads,model.trainable_variables))
    # Update the running loss and accuracy
    train_loss(loss_step)
    train_acc(labels,pred)

# One test step (no gradients, only metrics)
def test_step(model,images,labels):
    pred=model(images)
    loss_step=loss_fn(labels,pred)

    test_loss(loss_step)
    test_acc(labels,pred)

# Training loop
def train():
    for epoch in range(5):
        for(images,labels) in train_dataset:
            train_step(model,images,labels)
        for(images,labels) in test_dataset:
            test_step(model,images,labels)
        print("Epoch {} train loss is {},train acc is {} test loss is {} test acc is {}".format(epoch,train_loss.result(),train_acc.result(),
                                                            test_loss.result(),test_acc.result()))
        # Reset all four metrics so that each epoch reports its own values
        train_loss.reset_states()
        train_acc.reset_states()
        test_loss.reset_states()
        test_acc.reset_states()
train()
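
The shape of the loop above (forward pass, loss, gradient, weight update, metric bookkeeping, reset per epoch) does not depend on TensorFlow. The sketch below shows the same structure on a toy problem in plain Python, assuming a one-parameter model pred = w * x, made-up data for y = 2x, a gradient derived by hand, and plain gradient descent in place of Adam. None of these choices come from the article; they only keep the example small.

```python
# Toy data for y = 2x (made up for illustration).
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]

w = 0.0                # the single trainable parameter
learning_rate = 0.05   # assumed value, small enough to converge here


def train_step(w, x, y):
    pred = w * x                    # forward pass
    loss = (pred - y) ** 2          # squared error
    grad = 2.0 * (pred - y) * x     # d(loss)/dw, derived by hand
    w = w - learning_rate * grad    # gradient descent update
    return w, loss


epoch_losses = []
for epoch in range(20):
    total, count = 0.0, 0           # per-epoch running mean, reset every epoch
    for x, y in data:
        w, loss = train_step(w, x, y)
        total += loss
        count += 1
    epoch_losses.append(total / count)

print(w)  # approaches 2, the slope of the toy data
```

In the TensorFlow version, tf.GradientTape replaces the hand-derived gradient, optimizer.apply_gradients replaces the manual update, and the tf.keras.metrics objects replace the total and count variables.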

Summary

This article introduced custom training in TensorFlow, walked through several examples of tf.GradientTape and the tf.keras.metrics module, and ended with a complete training script for handwritten digit recognition. I hope it helps.

Origin blog.csdn.net/weixin_44599230/article/details/121046896