多任务学习(一)

多任务学习

单任务学习

样本之间没有关联性。

缺点:训练出来的模型不具有泛化性;不共享信息使得学习能力下降。

多任务学习

多任务学习的构建原则

  • 建模任务之间的相关性
  • 同时对多个任务的模型参数进行联合学习,挖掘其中的共享信息;
  • 考虑 人物之间的差异性,增强模型的适应能力。

多任务学习的两种主要方式

  • 基于参数的共享

    例如:神经网络隐层节点的共享

  • 基于正则化约束的共享

    例如:均值约束、联合特征学习等。

参数共享多任务学习

在这里插入图片描述

单任务学习中,每个单任务建立一个模型。每个任务的模型不适用于另一个任务。

基于正则化约束的共享

在这里插入图片描述

例子:

1.多输入多输出(主辅任务)

# 导入包
from keras.layers import Input,Embedding,LSTM,Dense
from keras.models import Model
import keras

#搭建模型
main_input=Input(shape=(100,),dtype='int32',name='main_input')
x=Embedding(output_dim=512,input_dim=10000,input_length=100)(main_input)
lstm_out=LSTM(32)(x)

# 辅助输出
auxiliary_output=Dense(1,activation='sigmoid',name='aux_output')(lstm_out)

# 辅助输入
auxiliary_input=Input(shape=(5,),name='aux_input')
x=keras.layers.concatenate([lstm_out,auxiliary_input])

x=Dense(64,activation='relu')(x)
x=Dense(64,activation='relu')(x)
x=Dense(64,activation='relu')(x)
# 主要输出
main_output=Dense(1,activation='sigmoid',name='main_output')(x)

model=Model(inputs=[main_input,auxiliary_input],outputs=[main_output,auxiliary_output])
# 查看模型结构
model.summary()

在这里插入图片描述

# 查看模型网络结构图并保存
from tensorflow.keras.utils import plot_model
plot_model(model, to_file='model.png',show_shapes='true')

在这里插入图片描述

# 编译模型
model.compile(optimizer='rmsprop',
              loss={
    
    'main_output':'binary_crossentropy','aux_output':'binary_crossentropy'},
              loss_weights={
    
    'main_output':1.,'aux_output':0.2})

model.fit({
    
    'main_input':headline_data,'aux_input':additional_data},
          {
    
    'main_output':labels,'aux_output':labels},
          epochs=50,batch_size=32)

2.共享网络层

# 导入包
import keras
from keras.layers import Input,LSTM,Dense
from keras.models import Model

# 搭建模型
tweet_a=Input(shape=(280,256))
tweet_b=Input(shape=(280,256))

shared_lstm=LSTM(64)

encoded_a=shared_lstm(tweet_a)
encoded_b=shared_lstm(tweet_b)

merged_vector=keras.layers.concatenate([encoded_a,encoded_b],axis=-1)

predictions=Dense(1,activation='sigmoid')(merged_vector)

model=Model(inputs=[tweet_a,tweet_b],outputs=predictions)
# 编译模型
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

model.fit([data_a,data_b],labels,epochs=10)
# 查看模型结构
model.summary()

# 查看模型网络图
from tensorflow.keras.utils import plot_model
plot_model(model, to_file='model.png',show_shapes='true')

在这里插入图片描述

查看各节点输出

a = Input(shape=(280, 256))

lstm = LSTM(32)
encoded_a = lstm(a)

lstm.output == encoded_a
# >>><KerasTensor: shape=(None, 32) dtype=bool (created by layer 'tf.__operators__.eq_3')>

参考文献:
https://keras.io/zh/getting-started/functional-api-guide/

猜你喜欢

转载自blog.csdn.net/fyfy96/article/details/121283623