[Deep Learning] Experiment 09 Use Keras to complete linear regression

Complete linear regression using Keras

Keras is a Python-based deep learning framework that can use TensorFlow, Theano, or CNTK as its backend. It was created and is maintained by François Chollet. Its goal is to make implementing deep learning models fast and simple, and it is designed to be user-friendly, extensible, and easy to debug and experiment with.

Keras provides a series of high-level APIs and convenient tools that allow users to quickly build and train deep learning models without paying attention to the underlying details. Keras supports various types of network structures, including convolutional neural networks, recurrent neural networks, autoencoders, etc., and can be easily trained and tested on different data sets.

The main features of Keras are:

  1. Easy to use, quick to get started: Keras provides a simple and easy-to-use API, and users can implement complex deep learning models with just a few lines of code.

  2. Supports multiple backends: Keras can use TensorFlow, Theano, or CNTK as its backend, and users can choose the backend that suits their needs.

  3. Highly extensible: Keras provides a modular API, and users can add custom layers and functions as needed, as well as modify existing code.

  4. Convenient debugging and experimentation: Keras provides visualization tools for monitoring a model's training status and test results in real time, and supports callback functions such as early stopping and learning rate adjustment.

  5. Supports GPU acceleration: Keras can use GPUs for calculations to accelerate the training and inference process of deep learning models.

In short, Keras is an excellent deep learning framework, which makes the construction and training of deep learning models easier and faster, and can help users focus more on the design and application of the model.

1. Import the Keras library

import warnings
warnings.filterwarnings("ignore")

import numpy as np
np.random.seed(1337)

from keras.models import Sequential
from keras.layers import Dense
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
Using TensorFlow backend.

2. Create a data set

# Create the data set
# Create 200 evenly spaced samples in the interval [-1, 1]
X = np.linspace(-1, 1, 200)
X
   array([-1.        , -0.98994975, -0.9798995 , -0.96984925, -0.95979899,
          -0.94974874, -0.93969849, -0.92964824, -0.91959799, -0.90954774,
          -0.89949749, -0.88944724, -0.87939698, -0.86934673, -0.85929648,
          -0.84924623, -0.83919598, -0.82914573, -0.81909548, -0.80904523,
          -0.79899497, -0.78894472, -0.77889447, -0.76884422, -0.75879397,
          -0.74874372, -0.73869347, -0.72864322, -0.71859296, -0.70854271,
          -0.69849246, -0.68844221, -0.67839196, -0.66834171, -0.65829146,
          -0.64824121, -0.63819095, -0.6281407 , -0.61809045, -0.6080402 ,
          -0.59798995, -0.5879397 , -0.57788945, -0.5678392 , -0.55778894,
          -0.54773869, -0.53768844, -0.52763819, -0.51758794, -0.50753769,
          -0.49748744, -0.48743719, -0.47738693, -0.46733668, -0.45728643,
          -0.44723618, -0.43718593, -0.42713568, -0.41708543, -0.40703518,
          -0.39698492, -0.38693467, -0.37688442, -0.36683417, -0.35678392,
          -0.34673367, -0.33668342, -0.32663317, -0.31658291, -0.30653266,
          -0.29648241, -0.28643216, -0.27638191, -0.26633166, -0.25628141,
          -0.24623116, -0.2361809 , -0.22613065, -0.2160804 , -0.20603015,
          -0.1959799 , -0.18592965, -0.1758794 , -0.16582915, -0.15577889,
          -0.14572864, -0.13567839, -0.12562814, -0.11557789, -0.10552764,
          -0.09547739, -0.08542714, -0.07537688, -0.06532663, -0.05527638,
          -0.04522613, -0.03517588, -0.02512563, -0.01507538, -0.00502513,
           0.00502513,  0.01507538,  0.02512563,  0.03517588,  0.04522613,
           0.05527638,  0.06532663,  0.07537688,  0.08542714,  0.09547739,
           0.10552764,  0.11557789,  0.12562814,  0.13567839,  0.14572864,
           0.15577889,  0.16582915,  0.1758794 ,  0.18592965,  0.1959799 ,
           0.20603015,  0.2160804 ,  0.22613065,  0.2361809 ,  0.24623116,
           0.25628141,  0.26633166,  0.27638191,  0.28643216,  0.29648241,
           0.30653266,  0.31658291,  0.32663317,  0.33668342,  0.34673367,
           0.35678392,  0.36683417,  0.37688442,  0.38693467,  0.39698492,
           0.40703518,  0.41708543,  0.42713568,  0.43718593,  0.44723618,
           0.45728643,  0.46733668,  0.47738693,  0.48743719,  0.49748744,
           0.50753769,  0.51758794,  0.52763819,  0.53768844,  0.54773869,
           0.55778894,  0.5678392 ,  0.57788945,  0.5879397 ,  0.59798995,
           0.6080402 ,  0.61809045,  0.6281407 ,  0.63819095,  0.64824121,
           0.65829146,  0.66834171,  0.67839196,  0.68844221,  0.69849246,
           0.70854271,  0.71859296,  0.72864322,  0.73869347,  0.74874372,
           0.75879397,  0.76884422,  0.77889447,  0.78894472,  0.79899497,
           0.80904523,  0.81909548,  0.82914573,  0.83919598,  0.84924623,
           0.85929648,  0.86934673,  0.87939698,  0.88944724,  0.89949749,
           0.90954774,  0.91959799,  0.92964824,  0.93969849,  0.94974874,
           0.95979899,  0.96984925,  0.9798995 ,  0.98994975,  1.        ])
# Shuffle the data set
np.random.shuffle(X)
X
   array([-0.70854271,  0.1758794 , -0.30653266,  0.74874372, -0.02512563,
           0.33668342, -0.85929648,  0.01507538, -0.13567839,  0.72864322,
           0.24623116, -0.74874372, -0.78894472,  0.50753769,  0.03517588,
           0.35678392, -0.55778894,  0.2361809 , -0.25628141, -0.44723618,
           0.2160804 , -0.43718593, -0.64824121,  0.69849246, -0.03517588,
          -0.45728643,  0.86934673,  0.73869347,  0.53768844, -0.67839196,
          -0.75879397,  0.55778894,  0.28643216, -0.05527638, -0.86934673,
           0.1959799 , -0.57788945, -0.9798995 , -0.6080402 , -0.63819095,
           0.84924623,  0.41708543,  0.13567839,  0.79899497, -0.47738693,
           0.46733668,  0.59798995, -0.80904523, -0.98994975, -0.36683417,
          -0.5678392 , -0.00502513, -0.53768844, -0.37688442, -0.65829146,
          -0.1959799 ,  0.06532663,  0.44723618, -0.01507538, -0.6281407 ,
           0.02512563, -0.71859296, -0.14572864, -0.46733668,  0.07537688,
           0.85929648,  0.76884422,  0.40703518, -0.68844221,  0.68844221,
          -0.29648241,  0.66834171, -0.95979899, -0.33668342,  0.26633166,
          -0.82914573,  1.        , -0.5879397 , -0.69849246, -0.20603015,
           0.63819095, -0.88944724, -0.40703518, -0.32663317,  0.15577889,
          -0.41708543,  0.10552764,  0.20603015, -0.04522613,  0.00502513,
          -0.31658291,  0.43718593,  0.42713568,  0.45728643, -0.59798995,
          -0.66834171,  0.83919598,  0.75879397, -0.24623116,  0.71859296,
          -0.92964824,  0.39698492,  0.61809045, -0.84924623, -0.87939698,
          -0.96984925,  0.87939698,  0.6281407 ,  0.25628141,  0.27638191,
           0.12562814,  0.09547739, -0.89949749,  0.80904523, -0.16582915,
          -0.12562814,  0.30653266,  0.49748744,  0.5879397 , -0.51758794,
          -0.10552764,  0.54773869, -0.94974874,  0.92964824,  0.16582915,
          -0.83919598, -0.35678392, -0.48743719,  0.08542714, -0.61809045,
           0.18592965,  0.57788945,  0.65829146,  0.38693467,  0.91959799,
          -0.26633166, -0.50753769, -1.        , -0.54773869,  0.6080402 ,
          -0.49748744, -0.22613065,  0.9798995 ,  0.98994975,  0.5678392 ,
           0.32663317,  0.64824121, -0.52763819,  0.36683417,  0.81909548,
          -0.11557789,  0.31658291, -0.2160804 ,  0.95979899,  0.77889447,
          -0.73869347, -0.81909548, -0.79899497,  0.78894472,  0.88944724,
          -0.2361809 ,  0.37688442,  0.70854271,  0.22613065, -0.28643216,
          -0.38693467,  0.90954774, -0.91959799,  0.48743719, -0.42713568,
          -0.08542714,  0.11557789, -0.18592965,  0.47738693, -0.39698492,
          -0.34673367,  0.04522613,  0.05527638,  0.93969849, -0.77889447,
          -0.93969849, -0.06532663, -0.72864322,  0.29648241,  0.52763819,
          -0.76884422,  0.94974874,  0.82914573,  0.34673367, -0.90954774,
          -0.27638191, -0.15577889, -0.1758794 ,  0.14572864, -0.09547739,
           0.96984925,  0.67839196, -0.07537688,  0.89949749,  0.51758794])
# Assume the true model is Y = 0.5X + 2, plus Gaussian noise with standard deviation 0.05
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200,))
Y
   array([1.66851812, 2.12220988, 1.91611873, 2.38979647, 1.96473269,
          2.11662688, 1.58217043, 2.05326658, 1.95885373, 2.4277956 ,
          2.13544689, 1.68732448, 1.66384243, 2.2702853 , 2.03148986,
          2.14968674, 1.76442495, 2.10802586, 1.93269542, 1.81936289,
          2.15190248, 1.83941395, 1.71399197, 2.21820555, 1.97918099,
          1.79781646, 2.43645587, 2.31211201, 2.21764353, 1.71912829,
          1.64285239, 2.2663785 , 2.11081029, 2.09338152, 1.5614153 ,
          2.19655545, 1.72824772, 1.56444412, 1.72673075, 1.67311017,
          2.39817488, 2.12624087, 2.07791136, 2.40515644, 1.80701389,
          2.16050089, 2.30373845, 1.57656517, 1.52482139, 1.7639545 ,
          1.76787463, 2.01204511, 1.74877623, 1.86751173, 1.67509082,
          1.95941218, 2.0126989 , 2.31574759, 2.04672223, 1.73762178,
          1.97249596, 1.65257838, 1.98435822, 1.74193776, 2.05272917,
          2.41693508, 2.37609913, 2.24686996, 1.61790402, 2.37607665,
          1.82677368, 2.29512653, 1.52756173, 1.79404414, 2.08314   ,
          1.5209276 , 2.48034115, 1.7821867 , 1.60377021, 1.82345627,
          2.23840132, 1.50174227, 1.85127905, 1.92372432, 1.95433662,
          1.8146093 , 1.96513404, 2.0227501 , 1.97564664, 2.09893966,
          1.95392005, 2.2089975 , 2.26074219, 2.24742979, 1.75936195,
          1.69145596, 2.46801952, 2.40938521, 1.98369075, 2.37509171,
          1.53026033, 2.24305926, 2.33309562, 1.49913881, 1.48743005,
          1.54075518, 2.33130062, 2.37463005, 2.19387461, 2.20970603,
          2.04719149, 2.04105128, 1.48410805, 2.34714158, 1.95061571,
          1.89473245, 2.26596278, 2.22430597, 2.29984983, 1.7894671 ,
          1.85995514, 2.31688729, 1.53417344, 2.39777465, 2.12853793,
          1.47736812, 1.90180229, 1.73086567, 2.03772387, 1.67243511,
          2.10115733, 2.26944612, 2.37404859, 2.22042332, 2.4948031 ,
          1.80153666, 1.72069013, 1.44829544, 1.77678155, 2.24291992,
          1.73557503, 1.79249737, 2.52580388, 2.46810975, 2.34211232,
          2.22144569, 2.31945172, 1.72814133, 2.17318812, 2.43560932,
          1.9662451 , 2.14319385, 1.83150682, 2.48805089, 2.28374904,
          1.63645718, 1.57901687, 1.61041853, 2.40884706, 2.37339631,
          1.90728817, 2.09065413, 2.36836694, 2.05400262, 1.87764304,
          1.83547711, 2.45064964, 1.46324772, 2.2429919 , 1.75954149,
          1.97326923, 2.08379661, 2.04616096, 2.3161197 , 1.81470671,
          1.8188581 , 2.11349671, 2.05477704, 2.39622142, 1.61281075,
          1.56914576, 1.96947616, 1.56645219, 2.08002605, 2.2185357 ,
          1.54079134, 2.42384819, 2.41198434, 2.0570266 , 1.55142224,
          1.83396657, 1.92648666, 1.9143498 , 1.9372014 , 1.92794208,
          2.42698754, 2.29871021, 2.03266023, 2.42413239, 2.28286632])
# Plot the data set (X, Y)
plt.scatter(X, Y)
plt.show()
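
Because the data are generated from a known line, an ordinary least-squares fit gives a useful reference point for the Keras result obtained later. The sketch below is an addition to the original experiment: it rebuilds the data with the same calls and fits a degree-1 polynomial with `np.polyfit`. The names `slope` and `intercept` are illustrative only.

```python
import numpy as np

# Rebuild the data set with the same calls as above
np.random.seed(1337)
X = np.linspace(-1, 1, 200)
np.random.shuffle(X)
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200,))

# Closed-form least-squares fit of a straight line.
# np.polyfit returns coefficients from the highest degree down: [slope, intercept]
slope, intercept = np.polyfit(X, Y, 1)
print('least-squares slope     =', slope)
print('least-squares intercept =', intercept)
```

Both values should land close to the true parameters 0.5 and 2, which is also what the trained network is expected to approach.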

3. Divide the data set

# Split into training and test sets
X_train, Y_train = X[:160], Y[:160]
X_test, Y_test = X[160:], Y[160:]

4. Construct a neural network model

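# Define a model
# Keras has two types of models: the Sequential model and the functional model
# Sequential is the more commonly used one; it has a single input and a single output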
model = Sequential()

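# Add layers one by one with the add() method
# Dense is a fully connected layer; the first layer must define the input
# (units replaces the output_dim argument, which is deprecated in Keras 2)
model.add(Dense(units=1, input_dim=1))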

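# With the model defined, the next step is training, but first we need to specify some training parameters
# Use the compile() method to choose the loss function and the optimizer
# Here we use mean squared error as the loss function and stochastic gradient descent as the optimizer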
model.compile(loss='mse', optimizer='sgd')

5. Train the model

# Start training
print('Training ----------')

# Keras has several functions for training; here we use train_on_batch()
for step in range(301):
    cost = model.train_on_batch(X_train, Y_train)
    if step % 100 == 0:
        print('train cost: ', cost)
Training ----------
WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

train cost:  4.0225005
train cost:  0.073238626
train cost:  0.00386274
train cost:  0.002643449
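
For a single `Dense` unit with an MSE loss, each `train_on_batch()` call amounts to one gradient-descent update of the weight and bias on the whole training batch. The following NumPy sketch illustrates that update under stated assumptions: the zero initialization, the learning rate `lr = 0.1`, and the step count are chosen for illustration and do not mirror Keras's initializer or optimizer settings, so the printed costs will not match the log above.

```python
import numpy as np

# Rebuild the training data with the same calls as above
np.random.seed(1337)
X = np.linspace(-1, 1, 200)
np.random.shuffle(X)
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200,))
X_train, Y_train = X[:160], Y[:160]

w, b = 0.0, 0.0   # assumed starting point (not Keras's initializer)
lr = 0.1          # assumed learning rate (not Keras's default)

for step in range(501):
    err = w * X_train + b - Y_train          # prediction error per sample
    cost = np.mean(err ** 2)                 # mean squared error
    grad_w = 2.0 * np.mean(err * X_train)    # d(cost)/dw
    grad_b = 2.0 * np.mean(err)              # d(cost)/db
    w -= lr * grad_w                         # gradient-descent update
    b -= lr * grad_b
    if step % 100 == 0:
        print('train cost:', cost)

print('w =', w, 'b =', b)
```

The cost falls toward the noise variance of the data, and `w` and `b` approach 0.5 and 2.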

6. Test the model

# Test the trained model
print('Testing ----------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost: ', cost)
Testing ----------
40/40 [==============================] - 0s 508us/step
test cost:  0.0031367032788693905

7. Analyze the model

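# Inspect the trained network parameters
# Our network has only one layer, with a single input and a single output per sample,
# so the first layer learns the model Y = WX + b, where W and b are the trained parameters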
W, b = model.layers[0].get_weights()
print('Weights = ', W, '\nbiases = ', b)
Weights =  [[0.4922711]] 
biases =  [1.9995022]
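
With one input and one output, `model.predict()` and `model.evaluate()` reduce to simple array arithmetic. The sketch below plugs the weights printed above into Y = WX + b and computes the mean squared error on the last 40 samples. It regenerates the data with the article's seed and calls; whether that reproduces the exact arrays shown earlier is an assumption, so no expected value is quoted.

```python
import numpy as np

# Rebuild the test split with the same calls as above
np.random.seed(1337)
X = np.linspace(-1, 1, 200)
np.random.shuffle(X)
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200,))
X_test, Y_test = X[160:], Y[160:]

# Parameters copied from the printed output above
W = np.array([[0.4922711]])   # shape (1, 1): kernel of the Dense layer
b = np.array([1.9995022])     # shape (1,): bias of the Dense layer

# What a single Dense layer without activation computes: X @ W + b
Y_pred = X_test.reshape(-1, 1) @ W + b        # shape (40, 1)

# Flatten the predictions before subtracting, otherwise (40, 1) - (40,)
# would broadcast to a (40, 40) matrix
mse = np.mean((Y_pred.ravel() - Y_test) ** 2)
print('manual test cost:', mse)
```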
# Plot the predictions
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

# Evaluate the quality of the fit with the R² score
pred_acc = r2_score(Y_test, Y_pred)
print('pred_acc', pred_acc)
pred_acc 0.9591211310535933
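
Despite the variable name `pred_acc`, R² is a goodness-of-fit measure for regression rather than a classification accuracy: it is 1 minus the ratio of the residual sum of squares to the total sum of squares. A minimal sketch of that definition follows; the helper `r2_manual` is a hypothetical name introduced here and is not part of scikit-learn.

```python
import numpy as np

def r2_manual(y_true, y_pred):
    """R^2 = 1 - SS_res / SS_tot for one-dimensional targets."""
    # Flatten both inputs so a (n, 1) prediction array from a model
    # lines up with (n,) targets
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares
    return 1.0 - ss_res / ss_tot

# Two boundary cases: perfect predictions, and always predicting the mean
perfect = r2_manual([1, 2, 3], [1, 2, 3])
mean_only = r2_manual([1, 2, 3], [[2], [2], [2]])
print(perfect, mean_only)
```

A score near 1 means the model explains almost all of the variance in the test targets; a score of 0 means it does no better than predicting their mean.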
# Save the model
model.save('keras_linear.h5')

Appendix: article series

No. Article Link
1 Boston house price forecast https://want595.blog.csdn.net/article/details/132181950
2 Iris dataset analysis https://want595.blog.csdn.net/article/details/132182057
3 Feature processing https://want595.blog.csdn.net/article/details/132182165
4 Cross-validation https://want595.blog.csdn.net/article/details/132182238
5 Constructing a Neural Network Example https://want595.blog.csdn.net/article/details/132182341
6 Complete linear regression using TensorFlow https://want595.blog.csdn.net/article/details/132182417
7 Complete logistic regression using TensorFlow https://want595.blog.csdn.net/article/details/132182496
8 TensorBoard case https://want595.blog.csdn.net/article/details/132182584
9 Complete linear regression using Keras https://want595.blog.csdn.net/article/details/132182723
10 Complete logistic regression using Keras https://want595.blog.csdn.net/article/details/132182795
11 Complete cat and dog recognition using Keras pre-trained model https://want595.blog.csdn.net/article/details/132243928
12 Training models using PyTorch https://want595.blog.csdn.net/article/details/132243989
13 Use Dropout to suppress overfitting https://want595.blog.csdn.net/article/details/132244111
14 Using CNN to complete MNIST handwriting recognition (TensorFlow) https://want595.blog.csdn.net/article/details/132244499
15 Using CNN to complete MNIST handwriting recognition (Keras) https://want595.blog.csdn.net/article/details/132244552
16 Using CNN to complete MNIST handwriting recognition (PyTorch) https://want595.blog.csdn.net/article/details/132244641
17 Using GAN to generate handwritten digit samples https://want595.blog.csdn.net/article/details/132244764
18 Natural language processing https://want595.blog.csdn.net/article/details/132276591

Origin blog.csdn.net/m0_68111267/article/details/132182723