FCN结构in Speech的Keras简单实现
本文通过Keras简单实现了一种FCN结构in Speech。Github有其他的实现代码,本文仅是通过自己的理解,对参考文献[1]中的网络进行搭建。若对其它代码有兴趣,请移步Github[2] (不清楚是否为论文作者创作的源代码)。
参考文献:
[1] Z. Ouyang, H. Yu, W. Zhu and B. Champagne, “A Fully Convolutional Neural Network for Complex Spectrogram Processing in Speech Enhancement,” ICASSP 2019 - 2019 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Brighton, United Kingdom, 2019, pp. 5756-5760, doi: 10.1109/ICASSP.2019.8683423.
[2] https://github.com/phpstorm1/SE-FCN
上代码
# FCN
def FCN_ICASSP(self):
    """Build and train the fully convolutional network (FCN) of
    Ouyang et al., "A Fully Convolutional Neural Network for Complex
    Spectrogram Processing in Speech Enhancement", ICASSP 2019.

    Six dilated residual conv stages (Conv2D used to emulate Conv1D
    along the frequency axis) feed a summed skip connection; a single
    time frame is then sliced out and mapped through two 1-D convs to
    a single-channel sigmoid output. The compiled model is trained,
    per-epoch times are recorded, and the model is saved to disk.

    Reads from self: frame_number_one_sample, train_input, train_output.
    Side effects: prints the model summary and timing, trains the
    network, and writes 'FCN_ICASSP_model.h5'.
    """
    frame_number_one_sample = self.frame_number_one_sample
    train_input_logstft = self.train_input
    train_output_target = self.train_output
    # Split off a validation set as needed -- code omitted in the
    # original post. The omitted code is expected to define
    # partial_train_input_logstft / partial_train_output_targetsnr and
    # val_train_input_logstft / val_train_output_targetsnr used below.

    class TimeHistory(keras.callbacks.Callback):
        """Callback recording per-epoch durations and total wall time."""

        # NOTE: the original used mutable default args (logs={}); the
        # defaults are never read, so None is the safe idiomatic form.
        def on_train_begin(self, logs=None):
            self.times = []
            self.totaltime = time.time()

        def on_train_end(self, logs=None):
            self.totaltime = time.time() - self.totaltime

        def on_epoch_begin(self, epoch, logs=None):
            self.epoch_time_start = time.time()

        def on_epoch_end(self, epoch, logs=None):
            self.times.append(time.time() - self.epoch_time_start)

    # Number of STFT frequency bins. The original snippet left this as
    # an undefined name `x` with a note to "set according to the STFT";
    # set it to match your analysis window.
    freq_bins = 161  # TODO: set to the actual frequency-bin count

    net_input = Input(shape=(freq_bins, frame_number_one_sample, 1))

    def _residual_stage(tensor, dilation):
        """One residual stage: 1x1 shortcut conv in parallel with a
        dilated 5x3 conv -> ReLU -> 1x1 skip conv; returns the residual
        sum and the skip tensor (collected for the global skip path)."""
        shortcut = layers.Conv2D(48, (1, 1), strides=(1, 1),
                                 dilation_rate=(1, 1))(tensor)
        conv = layers.Conv2D(48, (5, 3), strides=(1, 1),
                             dilation_rate=(dilation, 1),
                             padding='same')(tensor)
        act = layers.Activation('relu')(conv)
        skip = layers.Conv2D(48, (1, 1), strides=(1, 1),
                             dilation_rate=(1, 1))(act)
        return layers.add([skip, shortcut]), skip

    # Dilation doubles each stage along the frequency axis: 1, 2, ..., 32.
    stage_out = net_input
    skips = []
    for dilation in (1, 2, 4, 8, 16, 32):
        stage_out, skip = _residual_stage(stage_out, dilation)
        skips.append(skip)
    skip_connection = layers.add(skips)

    # Slice a single time frame out of the feature map.
    # NOTE(review): index 6 presumably selects the centre frame of the
    # input context window -- confirm against frame_number_one_sample.
    def _take_frame(tensor, index):
        return tensor[:, :, index, :]

    slice_layers = layers.Lambda(_take_frame,
                                 output_shape=(freq_bins, 1, 48),
                                 arguments={'index': 6})(skip_connection)
    reshape_layers2 = layers.Reshape((freq_bins, 1, 48))(slice_layers)

    # Conv1d-style layers (kernel spans the frequency axis only).
    x7 = layers.Conv2D(96, (3, 1), strides=(1, 1), dilation_rate=(1, 1),
                       padding='same')(reshape_layers2)
    x7_act = layers.Activation('relu')(x7)
    net_output = layers.Conv2D(1, (3, 1), activation='sigmoid',
                               strides=(1, 1), dilation_rate=(1, 1),
                               padding='same')(x7_act)

    model = Model(net_input, net_output)
    model.compile(optimizer='adam',
                  loss='binary_crossentropy')
    model.summary()

    time_callback = TimeHistory()
    # Adjust epochs and batch_size as needed.
    model.fit(partial_train_input_logstft,
              partial_train_output_targetsnr,
              epochs=100,
              batch_size=96,
              callbacks=[time_callback],
              validation_data=(val_train_input_logstft,
                               val_train_output_targetsnr))
    print(time_callback.times)
    print(time_callback.totaltime)
    model.save('FCN_ICASSP_model.h5')
    print('model have train all ready')