glamredhel:
Cuando se utiliza la transferencia de aprendizaje en Tensorflow, sé que uno puede bloquear capas de la formación continua, haciendo:
for layer in pre_trained_model.layers:
layer.trainable = False
¿Es posible bloquear filtros específicos en la capa en su lugar? Al igual que en - si toda la capa contiene 64 filtros, es posible:
- cerradura sólo algunos de ellos, que parecen contener filtros razonables y re-entrenar a los que no lo hacen?
O
- quitar los filtros injustificadamente a futuro a partir de capas y volver a entrenar sin ellos? (Por ejemplo, para ver si los filtros nueva formación va a cambiar mucho)
Vlad:
Una solución posible es implementar capa personalizada que divide convolución en distintos number of filters
circunvoluciones y conjuntos de cada canal (que es una convolución con un canal de salida) a trainable
o a not trainable
. Por ejemplo:
import tensorflow as tf
import numpy as np
class Conv2DExtended(tf.keras.layers.Layer):
def __init__(self, filters, kernel_size, **kwargs):
self.filters = filters
self.conv_layers = [tf.keras.layers.Conv2D(1, kernel_size, **kwargs) for _ in range(filters)]
super().__init__()
def build(self, input_shape):
_ = [l.build(input_shape) for l in self.conv_layers]
super().build(input_shape)
def set_trainable(self, channels):
"""Sets trainable channels."""
for i in channels:
self.conv_layers[i].trainable = True
def set_non_trainable(self, channels):
"""Sets not trainable channels."""
for i in channels:
self.conv_layers[i].trainable = False
def call(self, inputs):
results = [l(inputs) for l in self.conv_layers]
return tf.concat(results, -1)
Y ejemplo de uso:
inputs = tf.keras.layers.Input((28, 28, 1))
conv = Conv2DExtended(filters=4, kernel_size=(3, 3))
conv.set_non_trainable([1, 2]) # only channels 0 and 3 are trainable
res = conv(inputs)
res = tf.keras.layers.Flatten()(res)
res = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(res)
model = tf.keras.models.Model(inputs, res)
model.compile(optimizer=tf.keras.optimizers.SGD(),
loss='binary_crossentropy',
metrics=['accuracy'])
model.fit(np.random.normal(0, 1, (10, 28, 28, 1)),
np.random.randint(0, 2, (10)),
batch_size=2,
epochs=5)