[Deep Learning] Experiment 10: Logistic Regression with Keras

This experiment implements logistic regression with Keras.

Keras is an open-source deep learning framework for building neural networks and deep learning models efficiently. Created by François Chollet, it aims to provide an easy-to-use, high-level API so that developers can build models quickly, saving time and effort. Keras runs on top of several underlying deep learning backends, such as TensorFlow, Theano, and CNTK, and its ease of use and flexibility have made it one of the most popular frameworks in deep learning.

Keras is designed to make the path from data to model easier and faster. When building a deep learning model with Keras, defining network layers, activation functions, optimizers, and loss functions typically takes only about one line of code each, rather than many lines of boilerplate. Keras also provides a wealth of pre-trained models for tasks such as image classification, natural language processing, text classification, and sequence analysis, which greatly reduces the time needed to develop and train deep learning models.
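
For a sense of how compact this is, here is a minimal sketch of the kind of model this experiment builds later; each concept takes roughly one line:

from keras.models import Sequential
from keras.layers import Dense

# Layer and activation in one line, optimizer and loss in another
model = Sequential()
model.add(Dense(units=1, activation='sigmoid', input_dim=2))
model.compile(optimizer='sgd', loss='binary_crossentropy')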

Keras also has the following features:

  1. Simple and easy to use: Keras is written in Python and provides a clean, high-level API, letting users focus on designing and tuning their models.

  2. Easy to extend: Keras is compatible with several deep learning backends, such as TensorFlow, Theano, and CNTK, and uses their computational engines for efficient training and inference.

  3. Quick implementation: Keras ships with a variety of pre-trained models, so high-quality deep learning models can be built without starting from scratch (a short sketch follows this list).

  4. Supports multiple languages: besides Python, Keras can also be used from other languages, most notably through its R interface, and Keras models can be loaded from other environments as well.

  5. Open-source community: Keras has a large open-source community on GitHub, with abundant tutorials and examples for developers to learn from.
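
As an illustration of point 3, loading a pre-trained image classifier takes only a couple of lines. This sketch uses the keras.applications module and downloads the ImageNet weights on first use; it is only an aside and is not needed for the rest of this experiment:

from keras.applications.vgg16 import VGG16

# Load VGG16 with weights pre-trained on ImageNet (downloaded on first use)
pretrained = VGG16(weights='imagenet')
pretrained.summary()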

In short, Keras is an easy-to-use, efficient framework for implementing deep learning models, and it can greatly speed up their development.

1. Import the Keras library

# Import the required libraries
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from sklearn import datasets

import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
Using TensorFlow backend.
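
The imports above use the standalone keras package running on the TensorFlow backend, as the "Using TensorFlow backend." message indicates. If your environment only has TensorFlow 2.x installed (an assumption about your setup, not part of the original experiment), the equivalent imports are:

# Equivalent imports when using the Keras bundled with TensorFlow 2.x
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten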

2. Generate the dataset

# Generate a sample dataset with two feature columns and two classes. Binary classification does not
# need one-hot encoding: the class labels are simply 0 and 1, and the model's output is interpreted
# as the probability of the positive class.
X, y = datasets.make_classification(n_samples = 200, n_features = 2, n_informative = 2, n_redundant = 0,
                                   n_repeated = 0, n_classes = 2, n_clusters_per_class = 1)
X, y
   (array([[ 0.26364611,  0.77250816],
           [ 0.91698377,  0.9802208 ],
           [ 0.82634329,  0.9821341 ],
           [-0.83833456,  0.88223515],
           [ 1.11509338,  0.98632275],
           [ 1.04196821,  0.97892474],
           [ 0.77695264,  1.06320914],
           [-2.16804253,  0.15267335],
           [-1.96973867,  0.99244728],
           [-1.35368845,  1.25840447],
           [-0.52455148,  2.2351536 ],
           [ 1.08554563,  1.03795405],
           [ 0.88261697,  0.97793289],
           [-1.03718795,  0.53830131],
           [ 0.94628633,  0.96289949],
           [ 1.16190683,  1.01806263],
           [-2.07795249,  0.32376505],
           [ 0.9370119 ,  1.01060097],
           [ 0.92750449,  0.98713143],
           [-0.35800128,  1.4498587 ],
           [-0.96709704,  1.77632874],
           [-0.55995817,  1.58782776],
           [ 0.88919948,  1.00133032],
           [ 1.16465115,  1.05117935],
           [-1.6969619 ,  1.80088135],
           [ 1.06292602,  1.04594288],
           [-0.07792111,  0.98391779],
           [-1.05188451,  1.26871626],
           [-0.83494005,  0.93958161],
           [ 1.10371115,  1.03558148],
           [ 0.98674372,  1.04567265],
           [-1.08345028,  1.18601788],
           [-2.06487683,  0.17118219],
           [ 1.02734931,  0.99326938],
           [-0.11345441,  1.08515199],
           [ 0.97705823,  1.01751506],
           [-0.10872522,  0.91580496],
           [-1.27087508, -0.19146954],
           [ 0.87616438,  0.97685435],
           [ 0.89526079,  0.98651642],
           [ 0.96521071,  1.0206381 ],
           [ 1.0530243 ,  0.93365071],
           [ 0.994778  ,  0.99724912],
           [ 0.98176246,  1.03168734],
           [ 0.74458014,  0.97066564],
           [ 0.91748012,  0.9524803 ],
           [-1.92749946,  0.07784549],
           [ 0.7790389 ,  0.95517882],
           [ 0.11824333,  1.81065221],
           [ 0.97490265,  0.95326328],
           [ 1.00355225,  0.96521073],
           [ 1.08398178,  0.97814922],
           [-1.0749128 ,  1.77825305],
           [ 0.74886096,  1.39448605],
           [-0.1950267 ,  1.57178284],
           [ 1.069671  ,  0.97202065],
           [ 0.85757149,  1.01910676],
           [-1.02014343,  1.14016873],
           [-1.25252256,  0.02906454],
           [ 0.93948239,  1.44153932],
           [ 1.28777891,  1.00133477],
           [-1.7010408 ,  0.0821629 ],
           [ 0.8390028 ,  0.97712472],
           [ 0.99480479,  1.05717262],
           [ 1.20707509,  0.97462669],
           [-2.18786288,  1.4515569 ],
           [ 1.16027197,  1.09086817],
           [ 1.02771087,  0.9907291 ],
           [ 0.71829704,  0.98817911],
           [ 0.88605935,  0.99158972],
           [ 1.03589316,  0.99557438],
           [ 1.15489923,  0.95378093],
           [ 1.0668616 ,  0.99316509],
           [ 1.04848333,  1.09471239],
           [-1.05108888, -0.071106  ],
           [-1.19977682,  1.49257613],
           [ 1.12232276,  0.99293853],
           [-0.36977293,  1.59581   ],
           [-0.27363841,  1.46272407],
           [ 1.18075342,  0.95907983],
           [ 1.01486256,  0.97501177],
           [-0.41533403,  1.72366429],
           [-0.18337732,  2.26674615],
           [ 1.06777804,  1.00982417],
           [ 1.17411206,  0.98088369],
           [ 0.95355889,  1.05238272],
           [-0.39459255,  1.97600217],
           [ 0.90103447,  0.94080238],
           [ 0.87268023,  1.00348657],
           [-1.93323667,  1.04826094],
           [ 0.10460058,  1.16348717],
           [-1.85815599,  1.32669461],
           [ 0.90426972,  0.97521677],
           [-0.58409513,  0.9870014 ],
           [-1.74011619, -0.21416096],
           [-1.51931589,  0.34938829],
           [ 1.02631005,  0.99378866],
           [ 1.02869184,  0.99995857],
           [ 0.79862419,  1.00291807],
           [-1.34714457,  0.78937109],
           [-2.54273315,  0.96748855],
           [-1.86729291,  0.37250653],
           [-0.89843699,  0.43898384],
           [-1.83077543,  0.43636701],
           [-0.89141966,  1.57275938],
           [-0.96662858,  0.8196104 ],
           [ 0.87417528,  1.00989496],
           [ 0.93997582,  0.95616278],
           [-1.85338565,  1.00940185],
           [ 0.89565224,  0.95460192],
           [-0.76327569,  0.93526008],
           [-1.78345269,  1.53378105],
           [ 0.77408528,  1.01387371],
           [-1.47669576,  1.43472266],
           [ 1.19417792,  1.0440538 ],
           [ 1.15595665,  0.96823244],
           [ 0.84068935,  1.01792225],
           [ 1.11747629,  1.05722511],
           [ 0.23722569,  1.54396395],
           [-1.24609914,  0.30094681],
           [-0.18745572,  1.04657197],
           [ 0.90607352,  0.96120285],
           [-2.02612   ,  0.44082817],
           [ 0.8762596 ,  1.00607109],
           [ 0.98791921,  1.02441508],
           [-0.65307666,  1.22493946],
           [ 0.94162298,  1.28044258],
           [ 0.8622878 ,  0.99707326],
           [-0.27590245,  1.1547649 ],
           [ 0.99268975,  1.02885589],
           [ 1.0635428 ,  1.03445117],
           [-2.1378345 ,  0.62797163],
           [-1.40559883,  0.26079323],
           [ 1.07732353,  1.01373432],
           [-1.74785838,  1.25425571],
           [-0.51461996,  1.2583831 ],
           [ 1.02632384,  1.00203908],
           [ 0.84413823,  2.99872324],
           [ 1.10319604,  0.9615482 ],
           [ 0.95870127,  1.0461775 ],
           [-1.61872726,  0.55348188],
           [ 1.22219183,  1.00893646],
           [-0.04807925,  1.69061295],
           [-3.86851327, -0.36829707],
           [-0.84318558,  0.71791949],
           [ 0.95549697,  1.02457587],
           [ 0.15484069,  0.80992914],
           [ 1.1947279 ,  1.02301068],
           [-0.88323476,  1.52212056],
           [ 0.82715121,  0.99856576],
           [-0.97808876,  2.01262021],
           [-1.66906556,  0.70668215],
           [ 1.29672679,  0.64929896],
           [-0.45096669,  1.88364922],
           [-2.70110985,  0.36698604],
           [ 1.0795718 ,  1.02443886],
           [ 0.99150574,  0.98348741],
           [-0.65205587,  1.86131659],
           [-0.56754302,  1.87827013],
           [ 1.12356817,  1.06645171],
           [-2.72752499,  0.43018586],
           [-2.74061782, -0.08021407],
           [-0.3200331 ,  1.09683115],
           [ 1.0768664 ,  1.0085724 ],
           [-3.6325113 ,  0.67221516],
           [ 0.25830215,  0.79172286],
           [ 1.07796662,  1.00493526],
           [ 0.89606453,  0.98028498],
           [-0.94518278,  1.52377526],
           [ 0.90935946,  0.90695147],
           [ 1.0148515 ,  1.06783713],
           [ 1.16686534,  0.99312304],
           [-1.31640844,  0.32636521],
           [-1.39485695,  0.47605367],
           [-0.50763796,  2.04039346],
           [-0.58489137,  1.16215935],
           [-1.21643673,  1.16555051],
           [-2.9813908 , -0.02123246],
           [ 1.05056765,  1.0129612 ],
           [ 1.01961575,  1.03539024],
           [ 1.01227271,  0.96751672],
           [ 0.12444867,  1.38342266],
           [ 0.99713663,  0.96095512],
           [ 0.98185855,  0.9941474 ],
           [ 0.92998157,  1.03644759],
           [-0.18646788,  2.02399395],
           [-1.79776907,  0.97067984],
           [-3.23433111,  0.54897531],
           [-2.18617596,  0.33414794],
           [ 1.16844027,  1.01821873],
           [ 1.0428281 ,  1.01154471],
           [ 0.9159169 ,  1.02463567],
           [-1.3578118 ,  0.67183832],
           [-0.58824562,  1.08975919],
           [ 1.01775857,  1.00733938],
           [ 1.14847576,  1.01783862],
           [-1.1115874 ,  0.42278247],
           [ 0.84772713,  0.99733494],
           [ 1.00417018,  0.93763177],
           [ 0.56134549,  1.20390517]]),
    array([1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1,
           0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
           0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1,
           0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0,
           0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0,
           1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1,
           1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1,
           1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1,
           1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0,
           0, 1]))
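
Before building the model, it can help to look at the generated points. This optional sketch (not part of the original experiment) plots the two classes using the X, y, and plt defined above:

# Quick visual check of the two generated classes
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.xlabel('feature 0')
plt.ylabel('feature 1')
plt.show()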

3. Construct a neural network model

# Build the neural network model
model = Sequential()
model.add(Dense(input_dim = 2, units = 1))
model.add(Activation('sigmoid'))
# Choose the loss function and optimizer
model.compile(loss = 'binary_crossentropy', optimizer = 'sgd')
WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
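
This single Dense unit followed by a sigmoid is exactly logistic regression: the model computes sigmoid(w·x + b) and is trained to minimize the binary cross-entropy loss. The sketch below writes out the same computation in plain NumPy for illustration; w_ and b_ are hypothetical starting values, whereas the real parameters are learned during training:

# The computation behind Dense(1) + sigmoid + binary_crossentropy, in NumPy
def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    y_prob = np.clip(y_prob, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))

w_ = np.array([[0.0], [0.0]])            # hypothetical weights, shape (2, 1)
b_ = np.array([0.0])                     # hypothetical bias
y_prob = sigmoid(X.dot(w_) + b_)[:, 0]   # probability of the positive class
print(binary_crossentropy(y, y_prob))    # loss for these (untrained) parameters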

4. Train the model

# Training loop
print("Training ----------")
for step in range(501):
    cost = model.train_on_batch(X, y)
    if step % 50 == 0:
        print("After %d training steps, the cost is %f" % (step, cost))
Training ----------
After 0 training steps, the cost is 0.370295
After 50 training steps, the cost is 0.349558
After 100 training steps, the cost is 0.331982
After 150 training steps, the cost is 0.316872
After 200 training steps, the cost is 0.303725
After 250 training steps, the cost is 0.292170
After 300 training steps, the cost is 0.281923
After 350 training steps, the cost is 0.272767
After 400 training steps, the cost is 0.264532
After 450 training steps, the cost is 0.257079
After 500 training steps, the cost is 0.250299
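
The manual train_on_batch loop makes the number of updates explicit. As an alternative (not a step to run in addition to the loop above), the same training can be expressed with model.fit; a sketch, with the numbers chosen to mirror the loop:

# Equivalent training with fit(): 500 passes over the full batch of 200 samples
history = model.fit(X, y, epochs=500, batch_size=200, verbose=0)
print("final loss:", history.history['loss'][-1])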

5. Test the model

# Testing (evaluated here on the same data the model was trained on)
print("Testing ----------")
cost = model.evaluate(X, y, batch_size = 40)
print("test cost:", cost)
W, b = model.layers[0].get_weights()
print('Weights = ', W, '\nbiases = ', b)
Testing ----------
200/200 [==============================] - 0s 53us/step
test cost: 0.25016908943653104
Weights =  [[-1.7198342 ]
 [-0.18482684]] 
biases =  [0.47288144]
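
The learned parameters can be used to check the model by hand: model.predict(X) should equal sigmoid(X·W + b). A small verification sketch using the W and b retrieved above:

# Reconstruct the predictions from the learned parameters
manual_prob = 1.0 / (1.0 + np.exp(-(X.dot(W) + b)))
keras_prob = model.predict(X)
print(np.allclose(manual_prob, keras_prob, atol=1e-5))   # True, up to float32 precision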

6. Analyze the model

# Plot the training results
Y_pred = model.predict(X)
# Convert probabilities to class labels: values below 0.5 become 0, values of 0.5 and above become 1
Y_pred = (Y_pred >= 0.5).astype('int')
# Scatter plots: predicted labels (top) vs. true labels (bottom)
plt.subplot(2,1,1).scatter(X[:,0], X[:,1], c=Y_pred[:,0])
plt.subplot(2,1,2).scatter(X[:,0], X[:,1], c=y)
plt.show()
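
A quick way to quantify the result is to compare the predicted labels with the true ones and to draw the decision boundary, i.e. the line where W[0]*x0 + W[1]*x1 + b = 0. A sketch using the arrays already computed above:

# Training accuracy of the fitted logistic regression
accuracy = np.mean(Y_pred[:, 0] == y)
print("training accuracy:", accuracy)

# Decision boundary: the set of points where W[0]*x0 + W[1]*x1 + b = 0
x0 = np.linspace(X[:, 0].min(), X[:, 0].max(), 100)
x1 = -(W[0, 0] * x0 + b[0]) / W[1, 0]
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.plot(x0, x1, 'r-')
plt.ylim(X[:, 1].min() - 1, X[:, 1].max() + 1)
plt.show()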


Appendix: articles in this series

No. Article Link
1 Boston house price forecast https://want595.blog.csdn.net/article/details/132181950
2 Iris dataset analysis https://want595.blog.csdn.net/article/details/132182057
3 Feature processing https://want595.blog.csdn.net/article/details/132182165
4 Cross-validation https://want595.blog.csdn.net/article/details/132182238
5 Constructing a Neural Network Example https://want595.blog.csdn.net/article/details/132182341
6 Complete linear regression using TensorFlow https://want595.blog.csdn.net/article/details/132182417
7 Complete logistic regression using TensorFlow https://want595.blog.csdn.net/article/details/132182496
8 TensorBoard case https://want595.blog.csdn.net/article/details/132182584
9 Complete linear regression using Keras https://want595.blog.csdn.net/article/details/132182723
10 Complete logistic regression using Keras https://want595.blog.csdn.net/article/details/132182795
11 Complete cat and dog recognition using Keras pre-trained model https://want595.blog.csdn.net/article/details/132243928
12 Training models using PyTorch https://want595.blog.csdn.net/article/details/132243989
13 Use Dropout to suppress overfitting https://want595.blog.csdn.net/article/details/132244111
14 Using CNN to complete MNIST handwriting recognition (TensorFlow) https://want595.blog.csdn.net/article/details/132244499
15 Using CNN to complete MNIST handwriting recognition (Keras) https://want595.blog.csdn.net/article/details/132244552
16 Using CNN to complete MNIST handwriting recognition (PyTorch) https://want595.blog.csdn.net/article/details/132244641
17 Using GAN to generate handwritten digit samples https://want595.blog.csdn.net/article/details/132244764
18 natural language processing https://want595.blog.csdn.net/article/details/132276591
