先尝试了 Keras Embedding 层的用法(Embedding 的输出再接一个 Conv2D):
文档介绍见:https://keras-cn.readthedocs.io/en/latest/layers/embedding_layer/#embedding_1
from keras import backend as K
from keras.layers import (
    LSTM,
    Activation,
    Bidirectional,
    Conv2D,
    Dense,
    Dropout,
    Embedding,
    Flatten,
    Input,
    Permute,
    Reshape,
)
from keras.layers.core import Lambda
from keras.models import Model
def GetModel():
    """Build, summarize and return a small Embedding -> Conv2D model.

    Integer sequences of length 66 are embedded into 13-dimensional
    vectors (vocabulary size 50). A trailing channel axis is then added so
    the embedded sequence can be treated as a 66x13 single-channel image,
    and a 3x3 Conv2D with 7 filters is applied.

    Returns:
        A keras ``Model`` mapping ``(batch, 66)`` inputs to
        ``(batch, 64, 11, 7)`` outputs. ``model.summary()`` is printed as
        a side effect.
    """
    inputs = Input(shape=(66,))
    # Embedding(input_dim=50, output_dim=13): token ids must lie in [0, 50).
    net = Embedding(50, 13)(inputs)                      # (None, 66, 13)
    # Conv2D expects a channel axis: (None, 66, 13) -> (None, 66, 13, 1).
    net = Lambda(lambda x: K.expand_dims(x, -1))(net)
    # 3x3 kernel with the default "valid" padding shrinks each spatial dim by 2.
    net = Conv2D(7, 3)(net)                              # (None, 64, 11, 7)
    model = Model(inputs=inputs, outputs=net)
    model.summary()
    # Hand the model back so callers can compile/fit it.
    return model


# Output of model.summary() for the model above:
'''
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 66) 0
_________________________________________________________________
embedding_1 (Embedding) (None, 66, 13) 650
_________________________________________________________________
lambda_1 (Lambda) (None, 66, 13, 1) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 64, 11, 7) 70
=================================================================
Total params: 720
Trainable params: 720
Non-trainable params: 0
_________________________________________________________________
'''
看了下 Keras 的两个包装器(Wrapper):TimeDistributed 和 Bidirectional,很有用:https://keras-cn.readthedocs.io/en/latest/layers/wrapper/#bidirectional
关于Keras LSTM一些参数的解释:https://blog.csdn.net/u011327333/article/details/78501054
一些实验(在上面的 Embedding + Conv2D 之后逐步加入 LSTM / Bidirectional 层,观察输出形状和参数量的变化):
def GetModel():
    """Build, summarize and return an Embedding -> Conv2D -> LSTM model.

    The Conv2D feature map is reshaped so that each of its 64 rows becomes
    one timestep of a sequence. That sequence is fed to an LSTM that
    returns only its final output.

    Returns:
        A keras ``Model`` mapping ``(batch, 66)`` inputs to
        ``(batch, 666)`` outputs. ``model.summary()`` is printed as a side
        effect.
    """
    inputs = Input(shape=(66,))
    # Embedding(input_dim=50, output_dim=13): token ids must lie in [0, 50).
    net = Embedding(50, 13)(inputs)                      # (None, 66, 13)
    # Conv2D expects a channel axis: (None, 66, 13) -> (None, 66, 13, 1).
    net = Lambda(lambda x: K.expand_dims(x, -1))(net)
    net = Conv2D(7, 3)(net)                              # (None, 64, 11, 7)
    # Merge width and channels per row: 11 * 7 = 77 features per timestep.
    net = Reshape((-1, 77))(net)                         # (None, 64, 77)
    # return_sequences defaults to False: keep only the last timestep.
    net = LSTM(666)(net)                                 # (None, 666)
    model = Model(inputs=inputs, outputs=net)
    model.summary()
    # Hand the model back so callers can compile/fit it.
    return model


# Output of model.summary() for the model above:
'''
=================================================================
input_1 (InputLayer) (None, 66) 0
_________________________________________________________________
embedding_1 (Embedding) (None, 66, 13) 650
_________________________________________________________________
lambda_1 (Lambda) (None, 66, 13, 1) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 64, 11, 7) 70
_________________________________________________________________
reshape_1 (Reshape) (None, 64, 77) 0
_________________________________________________________________
lstm_1 (LSTM) (None, 666) 1982016
=================================================================
'''
def GetModel():
    """Build, summarize and return a model whose LSTM emits every timestep.

    This is the same as the Embedding -> Conv2D -> Reshape -> LSTM(666)
    model, except that the LSTM uses ``return_sequences=True``. It
    therefore outputs one 666-dim vector per timestep instead of only the
    last one.

    Returns:
        A keras ``Model`` mapping ``(batch, 66)`` inputs to
        ``(batch, 64, 666)`` outputs. ``model.summary()`` is printed as a
        side effect.
    """
    inputs = Input(shape=(66,))
    # Embedding(input_dim=50, output_dim=13): token ids must lie in [0, 50).
    net = Embedding(50, 13)(inputs)                      # (None, 66, 13)
    # Conv2D expects a channel axis: (None, 66, 13) -> (None, 66, 13, 1).
    net = Lambda(lambda x: K.expand_dims(x, -1))(net)
    net = Conv2D(7, 3)(net)                              # (None, 64, 11, 7)
    # Merge width and channels per row: 11 * 7 = 77 features per timestep.
    net = Reshape((-1, 77))(net)                         # (None, 64, 77)
    # Parameter count is unchanged; only the output keeps the time axis.
    net = LSTM(666, return_sequences=True)(net)          # (None, 64, 666)
    model = Model(inputs=inputs, outputs=net)
    model.summary()
    # Hand the model back so callers can compile/fit it.
    return model


# Output of model.summary() for the model above:
'''
=================================================================
input_1 (InputLayer) (None, 66) 0
_________________________________________________________________
embedding_1 (Embedding) (None, 66, 13) 650
_________________________________________________________________
lambda_1 (Lambda) (None, 66, 13, 1) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 64, 11, 7) 70
_________________________________________________________________
reshape_1 (Reshape) (None, 64, 77) 0
_________________________________________________________________
lstm_1 (LSTM) (None, 64, 666) 1982016
=================================================================
'''
def GetModel():
    """Build, summarize and return a stacked bidirectional-LSTM model.

    Pipeline: Embedding -> Conv2D -> Reshape to a 64-step sequence ->
    Bidirectional LSTM(666) -> Bidirectional LSTM(520) -> per-timestep
    Dense(20) -> Flatten.

    Returns:
        A keras ``Model`` mapping ``(batch, 66)`` inputs to
        ``(batch, 1280)`` outputs (64 timesteps x 20 units).
        ``model.summary()`` is printed as a side effect.
    """
    inputs = Input(shape=(66,))
    # Embedding(input_dim=50, output_dim=13): token ids must lie in [0, 50).
    net = Embedding(50, 13)(inputs)                      # (None, 66, 13)
    # Conv2D expects a channel axis: (None, 66, 13) -> (None, 66, 13, 1).
    net = Lambda(lambda x: K.expand_dims(x, -1))(net)
    net = Conv2D(7, 3)(net)                              # (None, 64, 11, 7)
    # Merge width and channels per row: 11 * 7 = 77 features per timestep.
    net = Reshape((-1, 77))(net)                         # (None, 64, 77)
    # Bidirectional concatenates the forward and backward outputs, doubling
    # the feature size: 2 * 666 = 1332. Sequences are kept so the next
    # recurrent layer receives every timestep.
    net = Bidirectional(
        LSTM(666, return_sequences=True, return_state=False)
    )(net)                                               # (None, 64, 1332)
    net = Bidirectional(
        LSTM(520, return_sequences=True, return_state=False)
    )(net)                                               # (None, 64, 1040)
    # Dense acts on the last axis only, i.e. independently per timestep.
    net = Dense(20)(net)                                 # (None, 64, 20)
    net = Flatten()(net)                                 # (None, 1280)
    model = Model(inputs=inputs, outputs=net)
    model.summary()
    # The original raised ZeroDivisionError here (`1/0`), apparently to halt
    # the script right after printing the summary. Return the model instead
    # so the function is usable.
    return model


# Output of model.summary() for the model above:
'''
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 66) 0
_________________________________________________________________
embedding_1 (Embedding) (None, 66, 13) 650
_________________________________________________________________
lambda_1 (Lambda) (None, 66, 13, 1) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 64, 11, 7) 70
_________________________________________________________________
reshape_1 (Reshape) (None, 64, 77) 0
_________________________________________________________________
bidirectional_1 (Bidirection (None, 64, 1332) 3964032
_________________________________________________________________
bidirectional_2 (Bidirection (None, 64, 1040) 7708480
_________________________________________________________________
dense_1 (Dense) (None, 64, 20) 20820
_________________________________________________________________
flatten_1 (Flatten) (None, 1280) 0
=================================================================
'''