tf.kerasのクイックスタート-カスタムインジケーター

1はじめに

この記事では、tensorflow2.0議論中のカスタムインジケーターの問題について説明します。公式サイトアドレス:こちら

2.背景

tf.keras.metrics.Metricクラスからサブクラスへのカスタムインジケーターを簡単に作成できます42つのメソッドを実装する必要があります。

  • __init__(self)、インジケーターの状態変数を作成する場所。
  • update_state(self, y_true, y_pred, sample_weight=None)、ターゲットy_trueとモデルの予測を使用して、y_pred状態変数を更新しました。
  • result(self)、状態変数を使用して最終結果を計算します。
  • reset_states(self)、インジケーターの状態を再初期化するために使用されます。

2.1ケース

これは、インターネット上のドキュメントと組み合わせたアイリス分類の例であり、小さな修理ケースを示しています。

import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.datasets import load_iris
x_data = load_iris().data  # 特征,【花萼长度,花萼宽度,花瓣长度,花瓣宽度】
y_data = load_iris().target # 分类

class CategoricalTruePositives(tf.keras.metrics.Metric):
    def __init__(self, name="categorical_true_positives", **kwargs):
        super(CategoricalTruePositives, self).__init__(name=name, **kwargs)
        self.true_positives = self.add_weight(name="ctp", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.reshape(tf.argmax(y_pred, axis=1), shape=(-1, 1))
        print(y_pred)
        print(y_true)
        values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
        values = tf.cast(values, "float32")
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, "float32")
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values)) # 自加操作,还有个assign_sub

    def result(self):
        return self.true_positives / 100

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        self.true_positives.assign(0.0)

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, input_shape=(4,), activation='relu'))
model.add(tf.keras.layers.Dense(3, input_shape=(4,), activation='softmax'))

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=[CategoricalTruePositives(), 'mse'],
)
history = model.fit(x_data, y_data, batch_size=64, epochs=30)

for key in history.history.keys():
    plt.plot(history.epoch, history.history[key])
    
plt.legend(history.history.keys())

ここに画像の説明を挿入します

おすすめ

転載: blog.csdn.net/qq_26460841/article/details/113651615