Lösen Sie das Problem von [TypeError: forward() hat ein unerwartetes Schlüsselwortargument „reduction“ erhalten]

Der Projekthintergrund besteht zunächst darin, BiLSTM + CRF zur Ausführung von NER-Aufgaben zu verwenden. Bei der Definition der Verlustfunktion wurde im Titel ein Fehler gemeldet. Das Folgende ist der Quellcode:

# model.py
class Model(nn.Module):
    def __init__(self):
        ......
        self.crf = CRF(TARGET_SIZE)

    def forward(self, input, mask):
        ......
        return self.crf.decode(out, mask)

    def loss_fn(self, input, target, mask):
        y_pred = self._get_lstm_feature(input)
        return -self.crf.forward(y_pred, target, mask, reduction='mean')

Der Fokus liegt auf loss_fn, daher werden die Initialisierungs- und Feed-Forward-Ebenen vorerst nicht angezeigt. Diese Schreibweise führt beim Training des Modells zu dem im Titel angezeigten Fehler, was bedeutet:

„TypeError: forward() hat ein unerwartetes Schlüsselwortargument ‚reduction‘ erhalten“

Zeigt an, dass die Vorwärtsmethode der CRF-Klasse keinen Reduktionsparameter akzeptiert.

Versuchen Sie, den Reduktionsparameter zu entfernen, trainieren Sie das Modell erneut und melden Sie erneut den folgenden Fehler:

RuntimeError: grad kann nur für skalare Ausgaben implizit erstellt werden

„RuntimeError: grad kann nur implizit für skalare Ausgabe erstellt werden“

Das bedeutet, dass die Verlustfunktion einen Tensor mit mehreren Elementen zurückgibt, was bei der Berechnung von Gradienten nicht zulässig ist. Daher kann der Reduktionsparameter nicht entfernt werden und loss_fn sollte beibehalten werden, um einen Skalarwert zurückzugeben.

Schließlich wird eine andere Schreibweise geändert, die normal ausgeführt und als Modell für das Training verwendet werden kann:

    def loss_fn(self, input, target, mask):
        y_pred = self._get_lstm_feature(input)
        loss = -self.crf.forward(y_pred, target, mask)
        return loss.mean()

 Verwenden Sie zunächst eine Verlustvariable, um den Verlustwert zu empfangen, und verwenden Sie dann bei der Rückgabe „mean()“.

Ich hoffe, es hilft dir

Supongo que te gusta

Origin blog.csdn.net/weixin_45206129/article/details/130310085
Recomendado
Clasificación