Knowledge Distillation examples

根据《BERT and Knowledge Distillation》一文中关于知识蒸馏的介绍,我们已经了解了它的定义和基本流程,下面通过例子具体看一下如何用代码来实现。

首先导入所需的包

from __future__ import absolute_import,unicode_literals,division,print_function

import os
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

# Show which TensorFlow version is running (the walkthrough targets TF 2.x).
print(tf.__version__)

本例依赖于 TensorFlow 2.0,数据集为 MNIST。数据集既可以通过 tf.keras.datasets.mnist 导入,也可以使用 TensorFlow 的 Datasets 包 tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True) 导入。导入后再做一些预处理:将像素值归一化到 [0, 1]、把每张图片展平为一维向量,并将标签转换为 one-hot 编码。

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Convert pixel values to float32 and scale them into [0, 1].
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255

print(train_images.shape)

# Flatten every image into a single feature vector for the Dense layers.
train_images = train_images.reshape(train_images.shape[0], -1)
test_images = test_images.reshape(test_images.shape[0], -1)

# One-hot encode the class labels.
train_labels = tf.keras.utils.to_categorical(train_labels.astype('float32'))
test_labels = tf.keras.utils.to_categorical(test_labels.astype('float32'))

为了简便起见,老师网络和学生网络都只使用 tf.keras.layers 中的 Dense 和 Dropout 层(外加输出端的 softmax 激活层)。

def build_teacherNet(input_size, class_num, training=False):
    """Build the (large) teacher MLP.

    Args:
        input_size: Length of each flattened input vector.
        class_num: Number of output classes.
        training: If truthy, a ``Dropout(0.2)`` layer is inserted after the
            first hidden layer. This changes the architecture itself; it is
            not the Keras call-time ``training`` flag.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model whose last two layers are
        a Dense layer named ``'logits'`` (pre-softmax scores) and a softmax
        ``Activation`` named ``'softmax'``.
    """
    model = tf.keras.Sequential(name='teacherNet')
    model.add(tf.keras.layers.Dense(1200, input_shape=(input_size,), activation='relu'))
    if training:  # PEP 8: test truthiness instead of comparing with `== True`
        model.add(tf.keras.layers.Dropout(0.2))
    model.add(tf.keras.layers.Dense(1200, activation='relu'))
    # Named so the distillation code can fetch the pre-softmax output by name.
    model.add(tf.keras.layers.Dense(class_num, name='logits'))
    model.add(tf.keras.layers.Activation('softmax', name='softmax'))

    model.summary()  # print the layer table

    return model
 
 """
 Model: "teacherNet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 1200)              942000    
_________________________________________________________________
dropout (Dropout)            (None, 1200)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1200)              1441200   
_________________________________________________________________
logits (Dense)               (None, 10)                12010     
_________________________________________________________________
softmax (Activation)         (None, 10)                0         
=================================================================
Total params: 2,395,210
Trainable params: 2,395,210
Non-trainable params: 0
 """

学生网络

def build_studentNet(input_size, class_num):
    """Build the small student MLP: one 10-unit ReLU layer, logits, softmax.

    Args:
        input_size: Length of each flattened input vector.
        class_num: Number of output classes.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model named ``'studentNet'``
        whose pre-softmax layer is named ``'logits'`` and whose final layer
        is a softmax ``Activation`` named ``'softmax'``.
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential(
        [
            layers.Dense(10, input_shape=(input_size,), activation='relu'),
            layers.Dense(class_num, name='logits'),
            layers.Activation('softmax', name='softmax'),
        ],
        name='studentNet',
    )

    model.summary()
    return model
"""
Model: "studentNet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_2 (Dense)              (None, 10)                7850      
_________________________________________________________________
logits (Dense)               (None, 10)                110       
_________________________________________________________________
softmax (Activation)         (None, 10)                0         
=================================================================
Total params: 7,960
Trainable params: 7,960
Non-trainable params: 0
_________________________________________________________________
"""

首先直接训练老师网络 30 个 epoch,最终准确率在 91% 左右(测试集上约为 91.9%)。

# Train the teacher network.
# `opt` stays defined at module level because the student / distillation
# snippets below refer to it. The teacher gets its own optimizer instance:
# a Keras optimizer should not be shared between models with different
# variables (TF >= 2.11 raises an error when it is).
opt = tf.keras.optimizers.SGD(learning_rate=0.001)

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='teacher', histogram_freq=1)

# The teacher was never constructed in this walkthrough; build it with the
# Dropout layer enabled, as the complete script does.
teacherNet = build_teacherNet(train_images.shape[-1], 10, training=True)

teacherNet.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
                   loss=tf.keras.losses.CategoricalCrossentropy(),
                   metrics=['accuracy'])

# 30 epochs, matching the text above and the logged "Epoch 30/30".
teacherNet.fit(train_images, train_labels,
               validation_data=(test_images, test_labels),
               epochs=30,
               batch_size=100,
               verbose=2,
               callbacks=[tensorboard_callback])

test_loss, test_acc = teacherNet.evaluate(test_images, test_labels)
print(test_acc)

# Save the architecture as JSON.
json_config = teacherNet.to_json()
with open('model_config.json', 'w') as json_file:
    json_file.write(json_config)

# Save the weights.
teacherNet.save_weights('model_weights.h5')

"""
Epoch 30/30
60000/60000 - 11s - loss: 0.3211 - accuracy: 0.9080 - val_loss: 0.2859 - val_accuracy: 0.9190
10000/10000 [==============================] - 1s 138us/sample - loss: 0.2859 - accuracy: 0.9190
"""

然后再单独训练学生网络 10 个 epoch,最终准确率只有 78% 左右(测试集上约为 79%),可以看出简单模型的学习能力要差一些。

# Train the student network on its own (hard labels only) as a baseline.
# studentNet was never constructed in this walkthrough; build it here.
studentNet = build_studentNet(train_images.shape[-1], 10)

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='student', histogram_freq=1)
# Fresh optimizer: the module-level `opt` may already have been used for the
# teacher's variables, and TF >= 2.11 optimizers refuse to update a
# different model afterwards.
studentNet.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
                   loss=tf.keras.losses.CategoricalCrossentropy(),
                   metrics=['accuracy'])
studentNet.fit(train_images, train_labels,
               validation_data=(test_images, test_labels),
               epochs=10,
               batch_size=100,
               verbose=2,
               callbacks=[tensorboard_callback])

"""
Epoch 10/10
60000/60000 - 1s - loss: 0.9100 - accuracy: 0.7811 - val_loss: 0.8628 - val_accuracy: 0.7915
"""

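下面通过知识蒸馏的方式,使用老师网络的输出作为指导信息来帮助学生网络进行训练。首先定义加入温度参数 $T$ 的 Softmax 函数 $q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$,以及蒸馏过程中对应的损失函数。损失项包含两部分:$Loss_{hard}$(学生网络在 $T = 1$ 时的输出与真实标签之间的交叉熵)和 $Loss_{soft}$(老师网络与学生网络在温度 $T$ 下的输出之间的交叉熵),总损失为 $Loss = \alpha \cdot Loss_{soft} + (1 - \alpha) \cdot Loss_{hard}$。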

def softmax_with_temp(logits, temp=1):
    """Temperature-scaled softmax over the last axis.

    Computes ``exp(z_i / T) / sum_j exp(z_j / T)``. A larger ``temp`` yields
    a softer (higher-entropy) distribution; ``temp=1`` is the plain softmax.

    Args:
        logits: Unnormalised scores; the softmax is taken over the last axis.
        temp: Temperature T (expected to be positive).

    Returns:
        A tensor with the same shape as ``logits`` whose last axis sums to 1.
    """
    # Subtract the per-row maximum (not the maximum over the whole batch) so
    # that every row's largest exponent is exp(0) = 1. With a batch-wide max,
    # a row whose logits sit far below another row's can underflow to all
    # zeros and produce 0 / 0 = NaN.
    shifted = (logits - tf.math.reduce_max(logits, axis=-1, keepdims=True)) / temp
    exp_logits = tf.math.exp(shifted)
    return exp_logits / tf.math.reduce_sum(exp_logits, axis=-1, keepdims=True)

def custom_ce(y_true, y_soft, y_pred, y_soft_pred, alpha=0.5, temp=1):
    """Distillation loss: weighted sum of soft-target and hard-label cross-entropy.

    ``loss = alpha * temp**2 * CE(y_soft, y_soft_pred) + (1 - alpha) * CE(y_true, y_pred)``

    Args:
        y_true: One-hot ground-truth labels.
        y_soft: Teacher probabilities softened with temperature T.
        y_pred: Student probabilities at temperature 1.
        y_soft_pred: Student probabilities softened with the same T.
        alpha: Weight of the soft term; ``1 - alpha`` weights the hard term.
        temp: Temperature used to build the soft distributions. The soft
            term is multiplied by ``temp ** 2`` (Hinton et al., 2015) because
            its gradients shrink as 1 / T**2. The default of 1 leaves the
            soft term unscaled, i.e. the original behaviour.

    Returns:
        Scalar loss. For (batch, classes) inputs each cross-entropy is summed
        over classes and averaged over the batch.
    """
    # Clip probabilities so log() never sees 0.
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
    origin_loss = -tf.math.reduce_mean(tf.math.reduce_sum(y_true * tf.math.log(y_pred), axis=-1, keepdims=False))

    y_soft = tf.clip_by_value(y_soft, 1e-7, 1 - 1e-7)
    y_soft_pred = tf.clip_by_value(y_soft_pred, 1e-7, 1 - 1e-7)
    soft_loss = -tf.math.reduce_mean(tf.math.reduce_sum(y_soft * tf.math.log(y_soft_pred), axis=-1, keepdims=False))
    # Keep the soft term's gradient magnitude comparable to the hard term's.
    soft_loss = soft_loss * (temp ** 2)

    return alpha * soft_loss + (1 - alpha) * origin_loss

蒸馏时需要的是老师网络在 softmax 之前的输出(logits),以便再用带温度的 softmax 进行软化。为此,首先从文件重新加载训练好的老师网络,再去掉其最后的 softmax 层。

# Reload the trained teacher: architecture from JSON, then weights from HDF5.
with open('model_config.json') as json_file:
    json_config = json_file.read()
teacher_model = tf.keras.models.model_from_json(json_config)
teacher_model.load_weights('model_weights.h5')
teacher_model.summary()

# Sub-model that shares the teacher's weights but stops at the 'logits'
# Dense layer, i.e. it returns pre-softmax scores for temperature scaling.
logits_layer = teacher_model.get_layer('logits')
teacher_model_ex_softmax = tf.keras.Model(inputs=teacher_model.input,
                                          outputs=logits_layer.output)

配置评估项和损失项

# Symbolic tensor for the output of the student's pre-softmax 'logits' layer.
logits = studentNet.get_layer('logits').output

# Accuracy on training batches and on the test set.
train_acc = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
test_acc = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

# Running mean of the per-batch distillation loss.
train_loss = tf.keras.metrics.Mean(name='train_loss', dtype=tf.float32)

单步训练

def train_step(images, labels):
    """Run one distillation update of ``studentNet`` on a batch.

    Relies on module-level ``studentNet``, ``teacher_model_ex_softmax``,
    ``opt``, ``temp`` (the temperature, set to 3 in the complete script),
    ``train_acc`` and ``train_loss``.

    Returns:
        The scalar distillation loss for the batch.
    """
    with tf.GradientTape() as tape:
        # Student logits: run studentNet only up to its 'logits' layer.
        # studentNet ends in a softmax, and feeding its probabilities into
        # softmax_with_temp would soften the wrong quantity (softmax applied
        # twice).
        student_logits = images
        for layer in studentNet.layers:
            student_logits = layer(student_logits, training=True)
            if layer.name == 'logits':
                break

        # Student distribution at T=1 (hard term) and at temperature T (soft term).
        unsoft_pred = softmax_with_temp(student_logits, 1)
        soft_pred = softmax_with_temp(student_logits, temp)

        # Teacher logits in inference mode, so any Dropout layer in it stays off.
        teacher_logits = teacher_model_ex_softmax(images, training=False)
        softened_teacher_prob = softmax_with_temp(teacher_logits, temp)

        # Distillation loss (soft + hard terms).
        loss_value = custom_ce(labels, softened_teacher_prob, unsoft_pred, soft_pred)

    # Gradient update of the student's variables only.
    grads = tape.gradient(loss_value, studentNet.trainable_variables)
    opt.apply_gradients(zip(grads, studentNet.trainable_variables))

    # Accuracy depends only on the argmax, which logits and probabilities share.
    train_acc(labels, student_logits)
    train_loss(loss_value)

    return loss_value

训练

# Distillation training loop. `dataset` must be built as in `load_data` of the
# complete script; `epochs` and `batch_size` are its hyper-parameters.
total_steps = int(len(train_images) * epochs / batch_size)
step = 0

for x, y in dataset:
    train_step(x, y)
    step += 1

    if step % 1000 == 0:
        # Evaluate the test set only when reporting: a full pass over the test
        # set after every step is very expensive. Reset first so the printed
        # value is the current model's accuracy, not an average over every
        # earlier evaluation.
        test_acc.reset_states()
        test_acc(test_labels, studentNet(test_images, training=False))
        print("step: {} -- total steps is: {} -- train loss: {} -- train accuracy: {} -- test accuracy: {} ".format(step, total_steps, train_loss.result(), train_acc.result(), test_acc.result()))

"""
step: 1000 -- total steps is: 6000 -- train loss: 1.3406617641448975 -- train accuracy: 0.7934600114822388 -- test accuracy: 0.7982499003410339
step: 2000 -- total steps is: 6000 -- train loss: 1.3140987157821655 -- train accuracy: 0.80035001039505 -- test accuracy: 0.8044642210006714
step: 3000 -- total steps is: 6000 -- train loss: 1.2919725179672241 -- train accuracy: 0.8056633472442627 -- test accuracy: 0.8093476295471191
step: 4000 -- total steps is: 6000 -- train loss: 1.2724905014038086 -- train accuracy: 0.8100699782371521 -- test accuracy: 0.8135648965835571
step: 5000 -- total steps is: 6000 -- train loss: 1.254680871963501 -- train accuracy: 0.8143399953842163 -- test accuracy: 0.8172408938407898
step: 6000 -- total steps is: 6000 -- train loss: 1.2391554117202759 -- train accuracy: 0.8179349899291992 -- test accuracy: 0.8206441402435303
"""

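从结果中可以看到,虽然学生网络的结构简单,但在老师网络的指导下,其测试准确率从单独训练时的约 79% 提升到了约 82%,说明老师网络的知识对学生网络的训练有帮助。不过需要注意,上面的蒸馏训练是在已经单独训练过 10 个 epoch 的学生网络基础上继续进行的(蒸馏开始后 step 1000 时训练准确率就已接近 79%),因此提升中有一部分来自额外的训练轮数。若要公平比较,应当从一个重新初始化的学生网络开始蒸馏,再与训练相同轮数的单独训练结果进行对比。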


完整代码

from __future__ import absolute_import,unicode_literals,division,print_function

import os
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

# Report the TensorFlow version in use (the script targets TF 2.x).
print(tf.__version__)

# Distillation hyper-parameters:
#   temp       - softmax temperature T applied to teacher and student logits
#   epochs     - passes over the training set in the distillation dataset
#   batch_size - examples per mini-batch
temp, epochs, batch_size = 3, 10, 100

def load_data(epochs, batch_size, shuffle_buffer=10000):
    """Load and preprocess MNIST and build the training ``tf.data`` pipeline.

    Images are converted to float32 in [0, 1] and flattened to one vector per
    example; labels are one-hot encoded.

    Args:
        epochs: Number of passes over the training set the dataset yields.
        batch_size: Examples per batch.
        shuffle_buffer: Buffer size for ``Dataset.shuffle``. The pipeline
            previously did not shuffle, so every epoch presented the batches
            in the same order. Pass 0 or None to disable shuffling.

    Returns:
        ``(train_images, train_labels, test_images, test_labels, dataset)``,
        where ``dataset`` yields ``(images, labels)`` batches.
    """
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    train_images = train_images.astype('float32')
    test_images = test_images.astype('float32')
    train_images /= 255
    test_images /= 255

    print(train_images.shape)

    # Flatten each image into one feature vector for the Dense layers.
    train_images = train_images.reshape((len(train_images), -1))
    test_images = test_images.reshape((len(test_images), -1))

    train_labels = tf.keras.utils.to_categorical(train_labels.astype('float32'))
    test_labels = tf.keras.utils.to_categorical(test_labels.astype('float32'))

    dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    if shuffle_buffer:
        # Shuffle before repeat: the order is reshuffled on every pass and the
        # epoch boundaries are preserved.
        dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.repeat(epochs).batch(batch_size)

    return train_images, train_labels, test_images, test_labels, dataset
	
def build_teacherNet(input_size, class_num, training=False):
    """Build the (large) teacher MLP.

    Args:
        input_size: Length of each flattened input vector.
        class_num: Number of output classes.
        training: If truthy, a ``Dropout(0.2)`` layer is inserted after the
            first hidden layer. This changes the architecture itself; it is
            not the Keras call-time ``training`` flag.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model whose last two layers are
        a Dense layer named ``'logits'`` (pre-softmax scores) and a softmax
        ``Activation`` named ``'softmax'``.
    """
    model = tf.keras.Sequential(name='teacherNet')
    model.add(tf.keras.layers.Dense(1200, input_shape=(input_size,), activation='relu'))
    if training:  # PEP 8: test truthiness instead of comparing with `== True`
        model.add(tf.keras.layers.Dropout(0.2))
    model.add(tf.keras.layers.Dense(1200, activation='relu'))
    # Named so the distillation code can fetch the pre-softmax output by name.
    model.add(tf.keras.layers.Dense(class_num, name='logits'))
    model.add(tf.keras.layers.Activation('softmax', name='softmax'))

    model.summary()  # print the layer table
    return model
 
def build_studentNet(input_size, class_num):
    """Build the small student MLP: one 10-unit ReLU layer, logits, softmax.

    Args:
        input_size: Length of each flattened input vector.
        class_num: Number of output classes.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model named ``'studentNet'``
        whose pre-softmax layer is named ``'logits'`` and whose final layer
        is a softmax ``Activation`` named ``'softmax'``.
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential(
        [
            layers.Dense(10, input_shape=(input_size,), activation='relu'),
            layers.Dense(class_num, name='logits'),
            layers.Activation('softmax', name='softmax'),
        ],
        name='studentNet',
    )

    model.summary()
    return model

def train_teacherNet(train_images, train_labels, test_images, test_labels,
                     epochs=30, batch_size=100):
    """Build, train, evaluate and save the teacher network.

    The architecture is written to ``model_config.json`` and the weights to
    ``model_weights.h5`` (overwriting any existing files) so the distillation
    stage can reload them.

    Args:
        train_images, train_labels: Training inputs and one-hot labels.
        test_images, test_labels: Held-out data used for validation during
            ``fit`` and for the final evaluation.
        epochs: Training epochs (default 30, the previous hard-coded value).
        batch_size: Mini-batch size (default 100).

    Returns:
        The trained teacher model (previously None; existing callers that
        ignore the return value are unaffected).
    """
    # The Dropout layer is included because this model is trained here.
    teacherNet = build_teacherNet(train_images.shape[-1], 10, training=True)

    teacherNet.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
                       loss=tf.keras.losses.CategoricalCrossentropy(),
                       metrics=['accuracy'])

    teacherNet.fit(train_images, train_labels,
                   validation_data=(test_images, test_labels),
                   epochs=epochs,
                   batch_size=batch_size,
                   verbose=2)

    _, test_acc = teacherNet.evaluate(test_images, test_labels)
    print(test_acc)

    # Save the architecture as JSON.
    with open('model_config.json', 'w') as json_file:
        json_file.write(teacherNet.to_json())

    # Save the weights.
    teacherNet.save_weights('model_weights.h5', overwrite=True)
    print('train and save teacher network complete.')
    return teacherNet

def train_studentNet(studentNet, train_images, train_labels, test_images, test_labels,
                     epochs=10, batch_size=100):
    """Train ``studentNet`` in place on hard labels only (the no-distillation baseline).

    Compiles the model with SGD (learning rate 0.001) and categorical
    cross-entropy, then fits it while logging to TensorBoard under
    ``student/``.

    Args:
        studentNet: Model to compile and train; its weights are updated in place.
        train_images, train_labels: Training inputs and one-hot labels.
        test_images, test_labels: Validation data.
        epochs: Training epochs (default 10, the previous hard-coded value).
        batch_size: Mini-batch size (default 100).

    Returns:
        The ``History`` object returned by ``fit`` (previously None).
    """
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='student', histogram_freq=1)

    studentNet.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
                       loss=tf.keras.losses.CategoricalCrossentropy(),
                       metrics=['accuracy'])
    return studentNet.fit(train_images, train_labels,
                          validation_data=(test_images, test_labels),
                          epochs=epochs,
                          batch_size=batch_size,
                          verbose=2,
                          callbacks=[tensorboard_callback])
	       
# Knowledge distillation helpers
def softmax_with_temp(logits, temp=1):
    """Temperature-scaled softmax over the last axis.

    Computes ``exp(z_i / T) / sum_j exp(z_j / T)``. A larger ``temp`` yields
    a softer (higher-entropy) distribution; ``temp=1`` is the plain softmax.

    Args:
        logits: Unnormalised scores; the softmax is taken over the last axis.
        temp: Temperature T (expected to be positive).

    Returns:
        A tensor with the same shape as ``logits`` whose last axis sums to 1.
    """
    # Subtract the per-row maximum (not the maximum over the whole batch) so
    # that every row's largest exponent is exp(0) = 1. With a batch-wide max,
    # a row whose logits sit far below another row's can underflow to all
    # zeros and produce 0 / 0 = NaN.
    shifted = (logits - tf.math.reduce_max(logits, axis=-1, keepdims=True)) / temp
    exp_logits = tf.math.exp(shifted)
    return exp_logits / tf.math.reduce_sum(exp_logits, axis=-1, keepdims=True)

def custom_ce(y_true, y_soft, y_pred, y_soft_pred, alpha=0.5, temp=1):
    """Distillation loss: weighted sum of soft-target and hard-label cross-entropy.

    ``loss = alpha * temp**2 * CE(y_soft, y_soft_pred) + (1 - alpha) * CE(y_true, y_pred)``

    Args:
        y_true: One-hot ground-truth labels.
        y_soft: Teacher probabilities softened with temperature T.
        y_pred: Student probabilities at temperature 1.
        y_soft_pred: Student probabilities softened with the same T.
        alpha: Weight of the soft term; ``1 - alpha`` weights the hard term.
        temp: Temperature used to build the soft distributions. The soft
            term is multiplied by ``temp ** 2`` (Hinton et al., 2015) because
            its gradients shrink as 1 / T**2. The default of 1 leaves the
            soft term unscaled, i.e. the original behaviour.

    Returns:
        Scalar loss. For (batch, classes) inputs each cross-entropy is summed
        over classes and averaged over the batch.
    """
    # Clip probabilities so log() never sees 0.
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
    origin_loss = -tf.math.reduce_mean(tf.math.reduce_sum(y_true * tf.math.log(y_pred), axis=-1, keepdims=False))

    y_soft = tf.clip_by_value(y_soft, 1e-7, 1 - 1e-7)
    y_soft_pred = tf.clip_by_value(y_soft_pred, 1e-7, 1 - 1e-7)
    soft_loss = -tf.math.reduce_mean(tf.math.reduce_sum(y_soft * tf.math.log(y_soft_pred), axis=-1, keepdims=False))
    # Keep the soft term's gradient magnitude comparable to the hard term's.
    soft_loss = soft_loss * (temp ** 2)

    return alpha * soft_loss + (1 - alpha) * origin_loss


def train():
    """Run the whole experiment: teacher, baseline student, then distillation.

    Steps:
      1. Load MNIST and build the repeated, batched training dataset.
      2. Train the teacher and save it to ``model_config.json`` /
         ``model_weights.h5``.
      3. Train a baseline student on hard labels only, for comparison.
      4. Reload the teacher and distil it into a *freshly initialised*
         student, printing progress every 1000 steps.

    Uses the module-level hyper-parameters ``temp``, ``epochs`` and
    ``batch_size``.
    """
    ## 1. Load data
    train_images, train_labels, test_images, test_labels, dataset = load_data(epochs, batch_size)

    ## 2. Train (and save) the teacher network
    train_teacherNet(train_images, train_labels, test_images, test_labels)

    ## 3. Baseline: a student trained on its own with hard labels only
    studentNet = build_studentNet(train_images.shape[-1], 10)
    train_studentNet(studentNet, train_images, train_labels, test_images, test_labels)

    # The distilled student must start from fresh weights. Distilling into the
    # baseline student (as before) merely continued its training, so any gain
    # could not be attributed to distillation.
    kd_student = build_studentNet(train_images.shape[-1], 10)
    # Shares kd_student's weights but stops at the pre-softmax 'logits' layer.
    Kd_studentNet = tf.keras.Model(inputs=kd_student.input,
                                   outputs=kd_student.get_layer('logits').output)

    ## 4. Reload the trained teacher and drop its softmax so it yields logits
    with open('model_config.json') as json_file:
        json_config = json_file.read()
    teacher_model = tf.keras.models.model_from_json(json_config)
    teacher_model.load_weights('model_weights.h5')
    teacher_model_ex_softmax = tf.keras.Model(inputs=teacher_model.input,
                                              outputs=teacher_model.get_layer('logits').output)

    opt = tf.keras.optimizers.SGD(learning_rate=0.001)
    train_acc = tf.keras.metrics.CategoricalAccuracy('train_accuracy')
    test_acc = tf.keras.metrics.CategoricalAccuracy('test_accuracy')
    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)

    def train_step(images, labels):
        """One distillation update of ``Kd_studentNet``; returns the batch loss."""
        with tf.GradientTape() as tape:
            pred = Kd_studentNet(images, training=True)  # student logits
            unsoft_pred = softmax_with_temp(pred, 1)
            soft_pred = softmax_with_temp(pred, temp)

            # Inference mode keeps the teacher's Dropout layer inactive.
            teacher_logits = teacher_model_ex_softmax(images, training=False)
            softened_teacher_prob = softmax_with_temp(teacher_logits, temp)
            loss_value = custom_ce(labels, softened_teacher_prob, unsoft_pred, soft_pred)

        grads = tape.gradient(loss_value, Kd_studentNet.trainable_variables)
        opt.apply_gradients(zip(grads, Kd_studentNet.trainable_variables))

        # Accuracy depends only on the argmax, which logits and probabilities share.
        train_acc(labels, pred)
        train_loss(loss_value)

        return loss_value

    ## 5. Distillation training loop
    total_steps = int(len(train_images) * epochs / batch_size)
    for step, (x, y) in enumerate(dataset, start=1):
        train_step(x, y)

        if step % 1000 == 0:
            # Evaluate the test set only when reporting: a full pass after
            # every step is very expensive. Reset first so the printed value
            # is the current model's accuracy, not an average over all
            # earlier evaluations.
            test_acc.reset_states()
            test_acc(test_labels, kd_student(test_images, training=False))
            print("step: {} -- total steps is: {} -- train loss: {} -- train accuracy: {} -- test accuracy: {} ".format(step, total_steps, train_loss.result(), train_acc.result(), test_acc.result()))

if __name__ == '__main__':
    # Only run the full teacher -> student -> distillation experiment when executed as a script.
    train()

更多关于知识蒸馏的代码可参考:
Knowledge_distillation_via_TF2.0
Efficient-Neural-Network-Bilibili-pytorch

发布了295 篇原创文章 · 获赞 103 · 访问量 20万+

猜你喜欢

转载自blog.csdn.net/Forlogen/article/details/104474484
今日推荐