[ディープラーニング] 実験 10 Keras を使用してロジスティック回帰を完了する

Keras を使用した完全なロジスティック回帰

Keras は、ニューラル ネットワークと深層学習モデルを効率的に実装できるオープンソースの深層学習フレームワークです。ニューヨーク大学の Francois Chollet によって開発されたこの API は、開発者が時間と労力を節約してモデルを迅速に構築できるように、使いやすい高レベルの API を提供することを目的としています。Keras は、TensorFlow、Theano、CNTK など、さまざまな基盤となる深層学習フレームワークと互換性があります。使いやすさと柔軟性により、ディープラーニングの分野で最も人気のあるフレームワークの 1 つとなっています。

Keras は、データからモデルまでの深層学習をより簡単かつ迅速に行えるように設計されています。Keras を深層学習に使用する場合、ニューラル層、活性化関数、オプティマイザー、損失関数などのハイパーパラメータを定義するために複数行のコードを記述する必要はなく、1 行のコードだけで済みます。さらに、Keras は、画像分類、自然言語処理、テキスト分類、シーケンス分析などのタスクを処理するために使用できる豊富な事前トレーニング済みモデルも提供するため、深層学習モデルの開発とトレーニング時間を大幅に削減できます。

Keras には次の機能もあります。

  1. シンプルで使いやすい: Keras は Python で書かれており、シンプルな API インターフェイスを提供するため、ユーザーはモデルの設計と調整にさらに注意を払うことができます。

  2. 拡張が簡単: Keras は、TensorFlow、Theano、CNTK などのさまざまな深層学習フレームワークと互換性があり、そのコンピューティング能力を効率的なトレーニングと推論に使用できます。

  3. 迅速な実装: Keras はさまざまな事前トレーニング済みモデルを提供するため、モデルを最初から開発する必要がなくなり、高品質の深層学習モデルを迅速に構築できます。

  4. 複数の言語をサポート: Keras は、Python プログラミング言語だけでなく、R や Java などの他のプログラミング言語もサポートします。

  5. オープン ソース コミュニティ: Keras には、GitHub 上に巨大なオープン ソース コミュニティがあり、開発者が学び、より良く使用するための豊富なチュートリアルとサンプルが提供されています。

つまり、Keras は、深層学習モデルを実装するための使いやすく効率的なフレームワークであり、深層学習モデルの開発と実装の効率を大幅に向上させることができます。

1. Keras ライブラリをインポートする

# 导入相关库
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from sklearn import datasets

import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
Using TensorFlow backend.

2. データセットの生成

# 生成样本数据集,两个特征列,两个分类二分类不需要onehot编码,直接将类别转换为0和1,分别代表正样本的概率。
X, y = datasets.make_classification(n_samples = 200, n_features = 2, n_informative = 2, n_redundant = 0,
                                   n_repeated = 0, n_classes = 2, n_clusters_per_class = 1)
X, y
   (array([[ 0.26364611,  0.77250816],
           [ 0.91698377,  0.9802208 ],
           [ 0.82634329,  0.9821341 ],
           [-0.83833456,  0.88223515],
           [ 1.11509338,  0.98632275],
           [ 1.04196821,  0.97892474],
           [ 0.77695264,  1.06320914],
           [-2.16804253,  0.15267335],
           [-1.96973867,  0.99244728],
           [-1.35368845,  1.25840447],
           [-0.52455148,  2.2351536 ],
           [ 1.08554563,  1.03795405],
           [ 0.88261697,  0.97793289],
           [-1.03718795,  0.53830131],
           [ 0.94628633,  0.96289949],
           [ 1.16190683,  1.01806263],
           [-2.07795249,  0.32376505],
           [ 0.9370119 ,  1.01060097],
           [ 0.92750449,  0.98713143],
           [-0.35800128,  1.4498587 ],
           [-0.96709704,  1.77632874],
           [-0.55995817,  1.58782776],
           [ 0.88919948,  1.00133032],
           [ 1.16465115,  1.05117935],
           [-1.6969619 ,  1.80088135],
           [ 1.06292602,  1.04594288],
           [-0.07792111,  0.98391779],
           [-1.05188451,  1.26871626],
           [-0.83494005,  0.93958161],
           [ 1.10371115,  1.03558148],
           [ 0.98674372,  1.04567265],
           [-1.08345028,  1.18601788],
           [-2.06487683,  0.17118219],
           [ 1.02734931,  0.99326938],
           [-0.11345441,  1.08515199],
           [ 0.97705823,  1.01751506],
           [-0.10872522,  0.91580496],
           [-1.27087508, -0.19146954],
           [ 0.87616438,  0.97685435],
           [ 0.89526079,  0.98651642],
           [ 0.96521071,  1.0206381 ],
           [ 1.0530243 ,  0.93365071],
           [ 0.994778  ,  0.99724912],
           [ 0.98176246,  1.03168734],
           [ 0.74458014,  0.97066564],
           [ 0.91748012,  0.9524803 ],
           [-1.92749946,  0.07784549],
           [ 0.7790389 ,  0.95517882],
           [ 0.11824333,  1.81065221],
           [ 0.97490265,  0.95326328],
           [ 1.00355225,  0.96521073],
           [ 1.08398178,  0.97814922],
           [-1.0749128 ,  1.77825305],
           [ 0.74886096,  1.39448605],
           [-0.1950267 ,  1.57178284],
           [ 1.069671  ,  0.97202065],
           [ 0.85757149,  1.01910676],
           [-1.02014343,  1.14016873],
           [-1.25252256,  0.02906454],
           [ 0.93948239,  1.44153932],
           [ 1.28777891,  1.00133477],
           [-1.7010408 ,  0.0821629 ],
           [ 0.8390028 ,  0.97712472],
           [ 0.99480479,  1.05717262],
           [ 1.20707509,  0.97462669],
           [-2.18786288,  1.4515569 ],
           [ 1.16027197,  1.09086817],
           [ 1.02771087,  0.9907291 ],
           [ 0.71829704,  0.98817911],
           [ 0.88605935,  0.99158972],
           [ 1.03589316,  0.99557438],
           [ 1.15489923,  0.95378093],
           [ 1.0668616 ,  0.99316509],
           [ 1.04848333,  1.09471239],
           [-1.05108888, -0.071106  ],
           [-1.19977682,  1.49257613],
           [ 1.12232276,  0.99293853],
           [-0.36977293,  1.59581   ],
           [-0.27363841,  1.46272407],
           [ 1.18075342,  0.95907983],
           [ 1.01486256,  0.97501177],
           [-0.41533403,  1.72366429],
           [-0.18337732,  2.26674615],
           [ 1.06777804,  1.00982417],
           [ 1.17411206,  0.98088369],
           [ 0.95355889,  1.05238272],
           [-0.39459255,  1.97600217],
           [ 0.90103447,  0.94080238],
           [ 0.87268023,  1.00348657],
           [-1.93323667,  1.04826094],
           [ 0.10460058,  1.16348717],
           [-1.85815599,  1.32669461],
           [ 0.90426972,  0.97521677],
           [-0.58409513,  0.9870014 ],
           [-1.74011619, -0.21416096],
           [-1.51931589,  0.34938829],
           [ 1.02631005,  0.99378866],
           [ 1.02869184,  0.99995857],
           [ 0.79862419,  1.00291807],
           [-1.34714457,  0.78937109],
           [-2.54273315,  0.96748855],
           [-1.86729291,  0.37250653],
           [-0.89843699,  0.43898384],
           [-1.83077543,  0.43636701],
           [-0.89141966,  1.57275938],
           [-0.96662858,  0.8196104 ],
           [ 0.87417528,  1.00989496],
           [ 0.93997582,  0.95616278],
           [-1.85338565,  1.00940185],
           [ 0.89565224,  0.95460192],
           [-0.76327569,  0.93526008],
           [-1.78345269,  1.53378105],
           [ 0.77408528,  1.01387371],
           [-1.47669576,  1.43472266],
           [ 1.19417792,  1.0440538 ],
           [ 1.15595665,  0.96823244],
           [ 0.84068935,  1.01792225],
           [ 1.11747629,  1.05722511],
           [ 0.23722569,  1.54396395],
           [-1.24609914,  0.30094681],
           [-0.18745572,  1.04657197],
           [ 0.90607352,  0.96120285],
           [-2.02612   ,  0.44082817],
           [ 0.8762596 ,  1.00607109],
           [ 0.98791921,  1.02441508],
           [-0.65307666,  1.22493946],
           [ 0.94162298,  1.28044258],
           [ 0.8622878 ,  0.99707326],
           [-0.27590245,  1.1547649 ],
           [ 0.99268975,  1.02885589],
           [ 1.0635428 ,  1.03445117],
           [-2.1378345 ,  0.62797163],
           [-1.40559883,  0.26079323],
           [ 1.07732353,  1.01373432],
           [-1.74785838,  1.25425571],
           [-0.51461996,  1.2583831 ],
           [ 1.02632384,  1.00203908],
           [ 0.84413823,  2.99872324],
           [ 1.10319604,  0.9615482 ],
           [ 0.95870127,  1.0461775 ],
           [-1.61872726,  0.55348188],
           [ 1.22219183,  1.00893646],
           [-0.04807925,  1.69061295],
           [-3.86851327, -0.36829707],
           [-0.84318558,  0.71791949],
           [ 0.95549697,  1.02457587],
           [ 0.15484069,  0.80992914],
           [ 1.1947279 ,  1.02301068],
           [-0.88323476,  1.52212056],
           [ 0.82715121,  0.99856576],
           [-0.97808876,  2.01262021],
           [-1.66906556,  0.70668215],
           [ 1.29672679,  0.64929896],
           [-0.45096669,  1.88364922],
           [-2.70110985,  0.36698604],
           [ 1.0795718 ,  1.02443886],
           [ 0.99150574,  0.98348741],
           [-0.65205587,  1.86131659],
           [-0.56754302,  1.87827013],
           [ 1.12356817,  1.06645171],
           [-2.72752499,  0.43018586],
           [-2.74061782, -0.08021407],
           [-0.3200331 ,  1.09683115],
           [ 1.0768664 ,  1.0085724 ],
           [-3.6325113 ,  0.67221516],
           [ 0.25830215,  0.79172286],
           [ 1.07796662,  1.00493526],
           [ 0.89606453,  0.98028498],
           [-0.94518278,  1.52377526],
           [ 0.90935946,  0.90695147],
           [ 1.0148515 ,  1.06783713],
           [ 1.16686534,  0.99312304],
           [-1.31640844,  0.32636521],
           [-1.39485695,  0.47605367],
           [-0.50763796,  2.04039346],
           [-0.58489137,  1.16215935],
           [-1.21643673,  1.16555051],
           [-2.9813908 , -0.02123246],
           [ 1.05056765,  1.0129612 ],
           [ 1.01961575,  1.03539024],
           [ 1.01227271,  0.96751672],
           [ 0.12444867,  1.38342266],
           [ 0.99713663,  0.96095512],
           [ 0.98185855,  0.9941474 ],
           [ 0.92998157,  1.03644759],
           [-0.18646788,  2.02399395],
           [-1.79776907,  0.97067984],
           [-3.23433111,  0.54897531],
           [-2.18617596,  0.33414794],
           [ 1.16844027,  1.01821873],
           [ 1.0428281 ,  1.01154471],
           [ 0.9159169 ,  1.02463567],
           [-1.3578118 ,  0.67183832],
           [-0.58824562,  1.08975919],
           [ 1.01775857,  1.00733938],
           [ 1.14847576,  1.01783862],
           [-1.1115874 ,  0.42278247],
           [ 0.84772713,  0.99733494],
           [ 1.00417018,  0.93763177],
           [ 0.56134549,  1.20390517]]),
    array([1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1,
           0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
           0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1,
           0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0,
           0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0,
           1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1,
           1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1,
           1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1,
           1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0,
           0, 1]))

3. ニューラルネットワークモデルを構築する

# 构建神经网络模型
model = Sequential()
model.add(Dense(input_dim = 2, units = 1))
model.add(Activation('sigmoid'))
# 选定loss函数和优化器
model.compile(loss = 'binary_crossentropy', optimizer = 'sgd')
WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

4. トレーニングモデル

# 训练过程
print("Training ----------")
for step in range(501):
    cost = model.train_on_batch(X, y)
    if step % 50 == 0:
        print("After %d trainings, the cost: %f" % (step, cost))
Training ----------
After 0 trainings, the cost: 0.370295
After 50 trainings, the cost: 0.349558
After 100 trainings, the cost: 0.331982
After 150 trainings, the cost: 0.316872
After 200 trainings, the cost: 0.303725
After 250 trainings, the cost: 0.292170
After 300 trainings, the cost: 0.281923
After 350 trainings, the cost: 0.272767
After 400 trainings, the cost: 0.264532
After 450 trainings, the cost: 0.257079
After 500 trainings, the cost: 0.250299

5. モデルをテストする

# 测试过程
print("Testing ----------")
cost = model.evaluate(X, y, batch_size = 40)
print("test cost:", cost)
W, b = model.layers[0].get_weights()
print('Weights = ', W, '\nbiases = ', b)
Testing ----------
200/200 [==============================] - 0s 53us/step
test cost: 0.25016908943653104
Weights =  [[-1.7198342 ]
 [-0.18482684]] 
biases =  [0.47288144]

6. 解析モデル

# 将训练结果绘出
Y_pred = model.predict(X)
# 将概率转化为类标号,概率在0-0.5时,转为0,概率在0.5-1时转为1
Y_pred = (Y_pred*2).astype('int')  
# 绘制散点图 参数:x横轴 y纵轴
plt.subplot(2,1,1).scatter(X[:,0], X[:,1], c=Y_pred[:,0])
plt.subplot(2,1,2).scatter(X[:,0], X[:,1], c=y)
plt.show()

1

添付ファイル: 一連の記事

シリアルナンバー 記事ディレクトリ 直接リンク
1 ボストンの住宅価格予測 https://want595.blog.csdn.net/article/details/132181950
2 アヤメのデータセット分析 https://want595.blog.csdn.net/article/details/132182057
3 特徴処理 https://want595.blog.csdn.net/article/details/132182165
4 相互検証 https://want595.blog.csdn.net/article/details/132182238
5 ニューラルネットワークの構築例 https://want595.blog.csdn.net/article/details/132182341
6 TensorFlow を使用した完全な線形回帰 https://want595.blog.csdn.net/article/details/132182417
7 TensorFlow を使用した完全なロジスティック回帰 https://want595.blog.csdn.net/article/details/132182496
8 TensorBoard のケース https://want595.blog.csdn.net/article/details/132182584
9 Keras を使用した完全な線形回帰 https://want595.blog.csdn.net/article/details/132182723
10 Keras を使用した完全なロジスティック回帰 https://want595.blog.csdn.net/article/details/132182795
11 Keras の事前トレーニング済みモデルを使用した完全な猫と犬の認識 https://want595.blog.csdn.net/article/details/132243928
12 PyTorch を使用したモデルのトレーニング https://want595.blog.csdn.net/article/details/132243989
13 ドロップアウトを使用してオーバーフィッティングを抑制する https://want595.blog.csdn.net/article/details/132244111
14 CNN を使用して MNIST 手書き認識を完了する (TensorFlow) https://want595.blog.csdn.net/article/details/132244499
15 CNN を使用して MNIST 手書き認識を完了する (Keras) https://want595.blog.csdn.net/article/details/132244552
16 CNN を使用して MNIST 手書き認識を完了する (PyTorch) https://want595.blog.csdn.net/article/details/132244641
17 GAN を使用して手書き数字サンプルを生成する https://want595.blog.csdn.net/article/details/132244764
18 自然言語処理 https://want595.blog.csdn.net/article/details/132276591

おすすめ

転載: blog.csdn.net/m0_68111267/article/details/132182795