keras が提供する損失関数 API を使用すると、勾配を逆伝播することができず、損失関数が減少しません。
質問:
keras が提供する損失関数 API を使用する場合、勾配を逆伝播することはできません
コード:
from tensorflow.keras.losses import categorical_crossentropy
def train_generator(x, y, z, eps, dcgan, siamese_model, loss=None):
with tf.GradientTape(persistent=True) as t:
fake_x = dcgan.generator([z, y])
loss_G = -tf.reduce_mean(dcgan.discriminator(fake_x))
preds = aux_model(fake_x)
aux_mean = categorical_crossentropy(y, preds)
aux_loss = tf.reduce_mean(aux_mean)
total_loss = aux_loss + loss_G
gradient_g = t.gradient(total_loss, dcgan.generator.trainable_variables)
dcgan.optimizer_G.apply_gradients(zip(gradient_g, dcgan.generator.trainable_variables))
理由を推測してください:
Keras インターフェースは最初にデータを前処理してから tensorflow のバックエンドを呼び出すことがありますが、これにより関数勾配チェーンが切断され、チェーン導出を通じて勾配降下を実行できなくなります。
keras ソース コードで categorical_crossentropy がどのように定義されているかを確認してください。
root@Ie1c58c4ee0020126c:~# find / -iname keras
/usr/local/lib/python3.7/dist-packages/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/contrib/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/contrib/keras/api/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras/api/_v1/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras/api/_v2/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras/api/keras
vim /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras/losses.py
loss.py
def categorical_crossentropy(y_true,
y_pred,
from_logits=False,
label_smoothing=0):
"""Computes the categorical crossentropy loss.
Args:
y_true: tensor of true targets.
y_pred: tensor of predicted targets.
from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
we assume that `y_pred` encodes a probability distribution.
label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.
Returns:
Categorical crossentropy loss value.
"""
y_pred = ops.convert_to_tensor(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)
label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())
def _smooth_labels():
num_classes = math_ops.cast(array_ops.shape(y_true)[1], y_pred.dtype)
return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)
y_true = smart_cond.smart_cond(label_smoothing,
_smooth_labels, lambda: y_true)
return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
keras インターフェイスのデータに対して前処理が行われ、その後 tensorflow バックエンド (K) の categorical_crossentropy インターフェイスが呼び出されたことがわかります。これにより、勾配チェーンが切断され、チェーン導出と逆伝播を通じて勾配を更新できなくなりました。
解決:
損失関数を自分で実装するか、keras の loss.py ファイルでソース コードを見つけて、tensorflow バックエンドによって提供されるインターフェイスを損失関数として直接呼び出します。
最後にコードを次のように変更しました。
from tensorflow.keras import backend as K
def train_generator(x, y, z, eps, dcgan, siamese_model, loss=None):
with tf.GradientTape(persistent=True) as t:
fake_x = dcgan.generator([z, y])
loss_G = -tf.reduce_mean(dcgan.discriminator(fake_x))
preds = aux_model(fake_x)
aux_mean = K.categorical_crossentropy(y, preds)
aux_loss = tf.reduce_mean(aux_mean)
total_loss = aux_loss + loss_G
gradient_g = t.gradient(total_loss, dcgan.generator.trainable_variables)
dcgan.optimizer_G.apply_gradients(zip(gradient_g, dcgan.generator.trainable_variables))
完璧に解決しましたので、同じ悩みを抱えている方はぜひ参考にしてください。