TF の DNN: Tensorflow フレームワークの浅いニューラル ネットワーク アルゴリズム (h5 モデル ファイル/モデル パラメーター/json モデル構造/モデルのエクスポートと onnx ファイル形式でのロード) を使用したタイタニック号データ セット (ワンホット エンコーディング) に基づいて、分類予測を実現します。アプリケーションの場合

TF の DNN: Tensorflow フレームワークの浅いニューラル ネットワーク アルゴリズム (h5 モデル ファイル/モデル パラメーター/json モデル構造/モデルのエクスポートと onnx ファイル形式でのロード) を使用したタイタニック号データ セット (ワンホット エンコーディング) に基づいて、分類予測を実現します。アプリケーションの場合

目次

Tensorflow フレームワークの浅いニューラル ネットワーク アルゴリズム (h5 モデル ファイル/モデル パラメーター/json モデル構造/モデルのエクスポートと onnx ファイル形式でのロード) を使用したタイタニック号データ セット (ワンホット エンコーディング) に基づいて、分類予測アプリケーション ケースを実現します。

# 1. データセットを定義する

# 金型フィーチャーを定義する

#2. データの前処理

#2.1、欠損値の処理

# 2.2、機能エンコーディング

# T1、ワンホットエンコーディング

# T2、ラベルエンコーディング

# 2.3. 個別の機能とラベル

#3. モデルのトレーニングと評価

# 3.1. セグメンテーションデータセット

# 3.2. モデルを構築する

#3.3. モデルのトレーニングと予測

# モデルの予測結果を出力する

# 3.4. モデルの評価: AUC、F1

#3.5. モデルのエクスポートと推論

# T1、推論のためにモデル ファイル (構造とパラメーター) をエクスポートしてロードします

# T2、モデルの重みパラメータのエクスポートとロード

# T3. モデル構造のエクスポートとロード (ただし、推論を実現するにはモデルの重みをロードする必要があります): JSON ファイルからモデル構造をロードし、load_weights() 関数を使用してモデルの重みパラメーターをロードし、完全なモデルを復元します。

T4、nx モデル ファイルのエクスポートとロード


関連記事
TF の DNN: Tensorflow フレームワークの浅いニューラル ネットワーク アルゴリズム (h5 モデル ファイル/モデル パラメーター/json モデル構造/モデルのエクスポートと onnx ファイル形式でのロード) を使用したタイタニック号のデータ セット (ワンホット エンコーディング) に基づいて、次のことを実現します。分類 予測アプリケーションケース
TF の DNN: タイタニック データセット (ワンホット エンコーディング) に基づき、Tensorflow フレームワークの浅いニューラル ネットワーク アルゴリズムを使用 (h5 モデル ファイル/モデル パラメーター/json モデル構造/モデルのエクスポートと onnx ファイルへのロード)フォーマット)実装 分類予測アプリケーションケースの実装コード

Tensorflow フレームワークの浅いニューラル ネットワーク アルゴリズム (h5 モデル ファイル/モデル パラメーター/json モデル構造/モデルのエクスポートと onnx ファイル形式でのロード) を使用したタイタニック号データ セット (ワンホット エンコーディング) に基づいて、分類予測アプリケーション ケースを実現します。

# 1. データセットを定義する

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
None
   PassengerId  Survived  Pclass  ...     Fare Cabin  Embarked
0            1         0       3  ...   7.2500   NaN         S
1            2         1       1  ...  71.2833   C85         C
2            3         1       3  ...   7.9250   NaN         S
3            4         1       1  ...  53.1000  C123         S
4            5         0       3  ...   8.0500   NaN         S

[5 rows x 12 columns]

# 金型フィーチャーを定義する

after featuresIN………………………………………………
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 6 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   Survived  891 non-null    int64  
 1   Pclass    891 non-null    int64  
 2   Age       714 non-null    float64
 3   Fare      891 non-null    float64
 4   Sex       891 non-null    object 
 5   Embarked  889 non-null    object 
dtypes: float64(2), int64(2), object(2)
memory usage: 41.9+ KB
None

#2. データの前処理

#2.1、欠損値の処理

# 2.2、機能エンコーディング

# T1、ワンホットエンコーディング

# T2、ラベルエンコーディング

OHEncode………………………………………………
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 9 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   Survived    891 non-null    int64  
 1   Pclass      891 non-null    int64  
 2   Age         891 non-null    float64
 3   Fare        891 non-null    float64
 4   Sex_female  891 non-null    uint8  
 5   Sex_male    891 non-null    uint8  
 6   Embarked_C  891 non-null    uint8  
 7   Embarked_Q  891 non-null    uint8  
 8   Embarked_S  891 non-null    uint8  
dtypes: float64(2), int64(2), uint8(5)
memory usage: 32.3 KB
None

# 2.3. 個別の機能とラベル

#3. モデルのトレーニングと評価

# 3.1. セグメンテーションデータセット

# 3.2. モデルを構築する

#3.3. モデルのトレーニングと予測

Epoch 1/100
23/23 [==============================] - 0s 931us/step - loss: 0.9403 - accuracy: 0.5632
Epoch 2/100
23/23 [==============================] - 0s 689us/step - loss: 0.5955 - accuracy: 0.7079
Epoch 3/100
23/23 [==============================] - 0s 647us/step - loss: 0.5803 - accuracy: 0.7205
Epoch 4/100
23/23 [==============================] - 0s 679us/step - loss: 0.5508 - accuracy: 0.7275
Epoch 5/100
23/23 [==============================] - 0s 639us/step - loss: 0.5525 - accuracy: 0.7247
Epoch 6/100

……
Epoch 96/100
23/23 [==============================] - 0s 668us/step - loss: 0.4455 - accuracy: 0.7949
Epoch 97/100
23/23 [==============================] - 0s 697us/step - loss: 0.4518 - accuracy: 0.7879
Epoch 98/100
23/23 [==============================] - 0s 680us/step - loss: 0.4501 - accuracy: 0.8006
Epoch 99/100
23/23 [==============================] - 0s 699us/step - loss: 0.4445 - accuracy: 0.8020
Epoch 100/100
23/23 [==============================] - 0s 632us/step - loss: 0.4461 - accuracy: 0.8020
6/6 [==============================] - 0s 756us/step

# モデルの予測結果を出力する

# 3.4. モデルの評価: AUC、F1

AUC: 0.8761
F1 score: 0.7385
6/6 [==============================] - 0s 798us/step

#3.5. モデルのエクスポートと推論

# T1、推論のためにモデル ファイル (構造とパラメーター) をエクスポートしてロードします

AUC: 0.8761
F1 score: 0.7385
6/6 [==============================] - 0s 798us/step

model_h5 -------------------

     PassengerId  Survived  loaded_model_y_prob
0            172  0.776513             0.776513
1            524  0.112160             0.112160
2            452  0.438001             0.438001
3            170  0.225820             0.225820
4            620  0.154560             0.154560
..           ...       ...                  ...
174          388  0.115106             0.115106
175          338  0.054891             0.054891
176          827  0.680395             0.680395
177          773  0.112122             0.112122
178          221  0.173132             0.173132

[179 rows x 3 columns]

# T2、モデルの重みパラメータのエクスポートとロード

AUC: 0.8773
F1 score: 0.7188
weights_h5 -------------------
6/6 [==============================] - 0s 811us/step
     PassengerId  Survived  loaded_model_weights_y_probt
0            172  0.708009                      0.708009
1            524  0.105870                      0.105870
2            452  0.466656                      0.466656
3            170  0.237784                      0.237784
4            620  0.168124                      0.168124
..           ...       ...                           ...
174          388  0.122256                      0.122256
175          338  0.058414                      0.058414
176          827  0.491990                      0.491990
177          773  0.105840                      0.105840
178          221  0.189526                      0.189526

[179 rows x 3 columns]

# T3. モデル構造のエクスポートとロード (ただし、推論を実現するにはモデルの重みをロードする必要があります): JSON ファイルからモデル構造をロードし、load_weights() 関数を使用してモデルの重みパラメーターをロードし、完全なモデルを復元します。

AUC: 0.8706
F1 score: 0.6984
model_json -------------------
6/6 [==============================] - 0s 547us/step
     PassengerId  Survived  loaded_model_json_y_prob
0            172  0.697425                  0.697425
1            524  0.132381                  0.132381
2            452  0.359049                  0.359049
3            170  0.161974                  0.161974
4            620  0.134586                  0.134586
..           ...       ...                       ...
174          388  0.116566                  0.116566
175          338  0.076110                  0.076110
176          827  0.345227                  0.345227
177          773  0.132386                  0.132386
178          221  0.147759                  0.147759

[179 rows x 3 columns]

Process finished with exit code 0

T4、nx モデル ファイルのエクスポートとロード

AUC: 0.8802
F1 score: 0.7482
model_onnx -------------------
dense_input
     PassengerId  Survived  loaded_model_ONNX_y_prob
0            172  0.808569                  0.808569
1            524  0.215418                  0.215418
2            452  0.462289                  0.462289
3            170  0.361842                  0.361842
4            620  0.241996                  0.241996
..           ...       ...                       ...
174          388  0.108243                  0.108243
175          338  0.056558                  0.056558
176          827  0.738313                  0.738313
177          773  0.215353                  0.215353
178          221  0.174459                  0.174459

[179 rows x 3 columns]

おすすめ

転載: blog.csdn.net/qq_41185868/article/details/130652148