多任务学习
单任务学习
样本之间没有关联性。
缺点:训练出来的模型不具有泛化性;不共享信息使得学习能力下降。
多任务学习
多任务学习的构建原则
- 建模任务之间的相关性
- 同时对多个任务的模型参数进行联合学习,挖掘其中的共享信息;
- 考虑 人物之间的差异性,增强模型的适应能力。
多任务学习的两种主要方式
-
基于参数的共享
例如:神经网络隐层节点的共享
-
基于正则化约束的共享
例如:均值约束、联合特征学习等。
参数共享多任务学习
单任务学习中,每个单任务建立一个模型。每个任务的模型不适用于另一个任务。
基于正则化约束的共享
例子:
1.多输入多输出(主辅任务)
# 导入包
from keras.layers import Input,Embedding,LSTM,Dense
from keras.models import Model
import keras
#搭建模型
main_input=Input(shape=(100,),dtype='int32',name='main_input')
x=Embedding(output_dim=512,input_dim=10000,input_length=100)(main_input)
lstm_out=LSTM(32)(x)
# 辅助输出
auxiliary_output=Dense(1,activation='sigmoid',name='aux_output')(lstm_out)
# 辅助输入
auxiliary_input=Input(shape=(5,),name='aux_input')
x=keras.layers.concatenate([lstm_out,auxiliary_input])
x=Dense(64,activation='relu')(x)
x=Dense(64,activation='relu')(x)
x=Dense(64,activation='relu')(x)
# 主要输出
main_output=Dense(1,activation='sigmoid',name='main_output')(x)
model=Model(inputs=[main_input,auxiliary_input],outputs=[main_output,auxiliary_output])
# 查看模型结构
model.summary()
# 查看模型网络结构图并保存
from tensorflow.keras.utils import plot_model
plot_model(model, to_file='model.png',show_shapes='true')
# 编译模型
model.compile(optimizer='rmsprop',
loss={
'main_output':'binary_crossentropy','aux_output':'binary_crossentropy'},
loss_weights={
'main_output':1.,'aux_output':0.2})
model.fit({
'main_input':headline_data,'aux_input':additional_data},
{
'main_output':labels,'aux_output':labels},
epochs=50,batch_size=32)
2.共享网络层
# 导入包
import keras
from keras.layers import Input,LSTM,Dense
from keras.models import Model
# 搭建模型
tweet_a=Input(shape=(280,256))
tweet_b=Input(shape=(280,256))
shared_lstm=LSTM(64)
encoded_a=shared_lstm(tweet_a)
encoded_b=shared_lstm(tweet_b)
merged_vector=keras.layers.concatenate([encoded_a,encoded_b],axis=-1)
predictions=Dense(1,activation='sigmoid')(merged_vector)
model=Model(inputs=[tweet_a,tweet_b],outputs=predictions)
# 编译模型
model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['accuracy'])
model.fit([data_a,data_b],labels,epochs=10)
# 查看模型结构
model.summary()
# 查看模型网络图
from tensorflow.keras.utils import plot_model
plot_model(model, to_file='model.png',show_shapes='true')
查看各节点输出
a = Input(shape=(280, 256))
lstm = LSTM(32)
encoded_a = lstm(a)
lstm.output == encoded_a
# >>><KerasTensor: shape=(None, 32) dtype=bool (created by layer 'tf.__operators__.eq_3')>
参考文献:
https://keras.io/zh/getting-started/functional-api-guide/