[Deep Learning] Experiment 05 Constructing a Neural Network Example

construct neural network

Note: This code is used to train a neural network. The network fits y = x^2-0.5+noise. The structure of the neural network is that the input layer is one neuron, the hidden layer is ten neurons, and the output layer is one Neurons

1. Import related libraries

# 导入相关库
import tensorflow as tf  # 用来构造神经网络
import numpy as np  # 用来构造数据结构和处理数据模块

2. Define a layer

# 定义一个层
def add_layer(inputs, in_size, out_size, activation_function=None):
    # 定义一个层,其中inputs为输入,in_size为上一层神经元数,out_size为该层神经元数
    # activation_function为激励函数
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    # 初始权重随机生成比较好,in_size,out_size为该权重维度
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    # 偏置
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    # matmul为矩阵里的函数相乘
    if activation_function is None:
        outputs = Wx_plus_b  # 如果激活函数为空,则不激活,保持数据
    else:
        outputs = activation_function(Wx_plus_b)
        # 如果激活函数不为空,则激活,并且返回激活后的值
    return outputs  # 返回激活后的值

3. Construct the dataset

# 构造一些样本,用来训练神经网络
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
# 值为(-1,1)之间的数,有300个
noise = np.random.normal(0, 0.05, x_data.shape)
x_data

array([[-1. ],
[-0.99331104],
[-0.98662207],
[-0.97993311],
[-0.97324415],
[-0.96655518],
[-0.95986622],
[-0.95317726],
[-0.94648829],
[-0.93979933],
[-0.93311037],
[-0.9264214 ],
[-0.91973244],
[-0.91304348],
[-0.90635452],
[-0.89966555],
[-0.89297659],
[-0.88628763],
[-0.87959866],
[-0.8729097 ],
[-0.86622074],
[-0.85953177],
[-0.85284281],
[-0.84615385],
[-0.83946488],
[-0.83277592],
[-0.82608696],
[-0.81939799],
[-0.81270903],
[-0.80602007],
[-0.7993311 ],
[-0.79264214],
[-0.78595318],
[-0.77926421],
[-0.77257525],
[-0.76588629],
[-0.75919732],
[-0.75250836],
[-0.7458194 ],
[-0.73913043],
[-0.73244147],
[-0.72575251],
[-0.71906355],
[-0.71237458],
[-0.70568562],
[-0.69899666],
[-0.69230769],
[-0.68561873],
[-0.67892977],
[-0.6722408 ],
[-0.66555184],
[-0.65886288],
[-0.65217391],
[-0.64548495],
[-0.63879599],
[-0.63210702],
[-0.62541806],
[-0.6187291 ],
[-0.61204013],
[-0.60535117],
[-0.59866221],
[-0.59197324],
[-0.58528428],
[-0.57859532],
[-0.57190635],
[-0.56521739],
[-0.55852843],
[-0.55183946],
[-0.5451505 ],
[-0.53846154],
[-0.53177258],
[-0.52508361],
[-0.51839465],
[-0.51170569],
[-0.50501672],
[-0.49832776],
[-0.4916388 ],
[-0.48494983],
[-0.47826087],
[-0.47157191],
[-0.46488294],
[-0.45819398],
[-0.45150502],
[-0.44481605],
[-0.43812709],
[-0.43143813],
[-0.42474916],
[-0.4180602 ],
[-0.41137124],
[-0.40468227],
[-0.39799331],
[-0.39130435],
[-0.38461538],
[-0.37792642],
[-0.37123746],
[-0.36454849],
[-0.35785953],
[-0.35117057],
[-0.34448161],
[-0.33779264],
[-0.33110368],
[-0.32441472],
[-0.31772575],
[-0.31103679],
[-0.30434783],
[-0.29765886],
[-0.2909699 ],
[-0.28428094],
[-0.27759197],
[-0.27090301],
[-0.26421405],
[-0.25752508],
[-0.25083612],
[-0.24414716],
[-0.23745819],
[-0.23076923],
[-0.22408027],
[-0.2173913 ],
[-0.21070234],
[-0.20401338],
[-0.19732441],
[-0.19063545],
[-0.18394649],
[-0.17725753],
[-0.17056856],
[-0.1638796 ],
[-0.15719064],
[-0.15050167],
[-0.14381271],
[-0.13712375],
[-0.13043478],
[-0.12374582],
[-0.11705686],
[-0.11036789],
[-0.10367893],
[-0.09698997],
[-0.090301 ],
[-0.08361204],
[-0.07692308],
[-0.07023411],
[-0.06354515],
[-0.05685619],
[-0.05016722],
[-0.04347826],
[-0.0367893 ],
[-0.03010033],
[-0.02341137],
[-0.01672241],
[-0.01003344],
[-0.00334448],
[ 0.00334448],
[ 0.01003344],
[ 0.01672241],
[ 0.02341137],
[ 0.03010033],
[ 0.0367893 ],
[ 0.04347826],
[ 0.05016722],
[ 0.05685619],
[ 0.06354515],
[ 0.07023411],
[ 0.07692308],
[ 0.08361204],
[ 0.090301 ],
[ 0.09698997],
[ 0.10367893],
[ 0.11036789],
[ 0.11705686],
[ 0.12374582],
[ 0.13043478],
[ 0.13712375],
[ 0.14381271],
[ 0.15050167],
[ 0.15719064],
[ 0.1638796 ],
[ 0.17056856],
[ 0.17725753],
[ 0.18394649],
[ 0.19063545],
[ 0.19732441],
[ 0.20401338],
[ 0.21070234],
[ 0.2173913 ],
[ 0.22408027],
[ 0.23076923],
[ 0.23745819],
[ 0.24414716],
[ 0.25083612],
[ 0.25752508],
[ 0.26421405],
[ 0.27090301],
[ 0.27759197],
[ 0.28428094],
[ 0.2909699 ],
[ 0.29765886],
[ 0.30434783],
[ 0.31103679],
[ 0.31772575],
[ 0.32441472],
[ 0.33110368],
[ 0.33779264],
[ 0.34448161],
[ 0.35117057],
[ 0.35785953],
[ 0.36454849],
[ 0.37123746],
[ 0.37792642],
[ 0.38461538],
[ 0.39130435],
[ 0.39799331],
[ 0.40468227],
[ 0.41137124],
[ 0.4180602 ],
[ 0.42474916],
[ 0.43143813],
[ 0.43812709],
[ 0.44481605],
[ 0.45150502],
[ 0.45819398],
[ 0.46488294],
[ 0.47157191],
[ 0.47826087],
[ 0.48494983],
[ 0.4916388 ],
[ 0.49832776],
[ 0.50501672],
[ 0.51170569],
[ 0.51839465],
[ 0.52508361],
[ 0.53177258],
[ 0.53846154],
[ 0.5451505 ],
[ 0.55183946],
[ 0.55852843],
[ 0.56521739],
[ 0.57190635],
[ 0.57859532],
[ 0.58528428],
[ 0.59197324],
[ 0.59866221],
[ 0.60535117],
[ 0.61204013],
[ 0.6187291 ],
[ 0.62541806],
[ 0.63210702],
[ 0.63879599],
[ 0.64548495],
[ 0.65217391],
[ 0.65886288],
[ 0.66555184],
[ 0.6722408 ],
[ 0.67892977],
[ 0.68561873],
[ 0.69230769],
[ 0.69899666],
[ 0.70568562],
[ 0.71237458],
[ 0.71906355],
[ 0.72575251],
[ 0.73244147],
[ 0.73913043],
[ 0.7458194 ],
[ 0.75250836],
[ 0.75919732],
[ 0.76588629],
[ 0.77257525],
[ 0.77926421],
[ 0.78595318],
[ 0.79264214],
[ 0.7993311 ],
[ 0.80602007],
[ 0.81270903],
[ 0.81939799],
[ 0.82608696],
[ 0.83277592],
[ 0.83946488],
[ 0.84615385],
[ 0.85284281],
[ 0.85953177],
[ 0.86622074],
[ 0.8729097 ],
[ 0.87959866],
[ 0.88628763],
[ 0.89297659],
[ 0.89966555],
[ 0.90635452],
[ 0.91304348],
[ 0.91973244],
[ 0.9264214 ],
[ 0.93311037],
[ 0.93979933],
[ 0.94648829],
[ 0.95317726],
[ 0.95986622],
[ 0.96655518],
[ 0.97324415],
[ 0.97993311],
[ 0.98662207],
[ 0.99331104],
[ 1. ]])

# 加入噪声会更贴近真实情况,噪声的值为(0,0.05)之间,结构为x_data一样
y_data = np.square(x_data) - 0.5 + noise
# y的结构
y_data

array([[ 0.59535036],
[ 0.46017998],
[ 0.47144478],
[ 0.45083795],
[ 0.58438217],
[ 0.38570118],
[ 0.43550029],
[ 0.40597571],
[ 0.3357524 ],
[ 0.35784864],
[ 0.34530231],
[ 0.32509701],
[ 0.25554733],
[ 0.32300801],
[ 0.2299959 ],
[ 0.35472568],
[ 0.31227671],
[ 0.30385068],
[ 0.29413844],
[ 0.18437787],
[ 0.28132819],
[ 0.25605309],
[ 0.23126361],
[ 0.23492797],
[ 0.18381621],
[ 0.10392937],
[ 0.13415913],
[ 0.14043649],
[ 0.11756826],
[ 0.12142749],
[ 0.12400694],
[ 0.08926307],
[ 0.15581832],
[ 0.16541106],
[-0.02582895],
[ 0.05924725],
[-0.04037454],
[ 0.03799003],
[ 0.09030832],
[ 0.05984324],
[-0.06569464],
[ 0.07973773],
[ 0.04297837],
[ 0.05169557],
[-0.00096191],
[-0.02049573],
[-0.03125322],
[-0.04545588],
[-0.02168901],
[ 0.01657517],
[-0.04315181],
[-0.09123519],
[-0.03292835],
[-0.1110189 ],
[-0.08212792],
[-0.10089535],
[-0.17406672],
[-0.10380731],
[-0.10774072],
[-0.21283138],
[-0.09788435],
[-0.10196452],
[-0.16439081],
[-0.15431978],
[-0.17778307],
[-0.18428537],
[-0.17874028],
[-0.10490738],
[-0.25076832],
[-0.16078044],
[-0.21572183],
[-0.15624353],
[-0.19591988],
[-0.31560742],
[-0.29593726],
[-0.26686787],
[-0.2999804 ],
[-0.30631065],
[-0.35305224],
[-0.31295125],
[-0.22996255],
[-0.22837061],
[-0.27266253],
[-0.31290802],
[-0.37188479],
[-0.20765034],
[-0.33860431],
[-0.31135236],
[-0.25249981],
[-0.26041048],
[-0.31486205],
[-0.30253306],
[-0.41624795],
[-0.40053837],
[-0.29939676],
[-0.32615377],
[-0.37377787],
[-0.32222027],
[-0.3158838 ],
[-0.43880087],
[-0.37510637],
[-0.46702321],
[-0.27058091],
[-0.52885151],
[-0.4061462 ],
[-0.4486374 ],
[-0.37819628],
[-0.34701947],
[-0.32454364],
[-0.3901839 ],
[-0.43293107],
[-0.47881173],
[-0.45280819],
[-0.49676541],
[-0.48955669],
[-0.45898691],
[-0.37473462],
[-0.43801531],
[-0.44793655],
[-0.57343047],
[-0.45262969],
[-0.40719677],
[-0.45423461],
[-0.45053051],
[-0.51046881],
[-0.41584096],
[-0.53328545],
[-0.44766406],
[-0.50158463],
[-0.42676031],
[-0.50552613],
[-0.36832989],
[-0.48699296],
[-0.41614151],
[-0.6175621 ],
[-0.48304532],
[-0.46115021],
[-0.40948908],
[-0.42017024],
[-0.50411757],
[-0.44530626],
[-0.46895275],
[-0.52127771],
[-0.50064585],
[-0.42210169],
[-0.58582837],
[-0.52049198],
[-0.45332091],
[-0.53465815],
[-0.5385712 ],
[-0.5654201 ],
[-0.54471377],
[-0.48109194],
[-0.44565732],
[-0.48112022],
[-0.46471786],
[-0.5452149 ],
[-0.52115601],
[-0.50234928],
[-0.54885558],
[-0.5279981 ],
[-0.53893795],
[-0.44286416],
[-0.45371406],
[-0.44633111],
[-0.57535678],
[-0.62918947],
[-0.41877124],
[-0.56263956],
[-0.51201705],
[-0.35016007],
[-0.49188897],
[-0.55766056],
[-0.38963378],
[-0.5038024 ],
[-0.51949984],
[-0.45229896],
[-0.49193029],
[-0.53472883],
[-0.48957523],
[-0.35561181],
[-0.4622668 ],
[-0.39177781],
[-0.43448445],
[-0.49854629],
[-0.49843105],
[-0.47704375],
[-0.36618194],
[-0.45177012],
[-0.41497222],
[-0.42152064],
[-0.48996608],
[-0.43010878],
[-0.42599962],
[-0.2841197 ],
[-0.38992082],
[-0.43802592],
[-0.42448799],
[-0.29514676],
[-0.37154091],
[-0.25426219],
[-0.44610678],
[-0.37120566],
[-0.3531599 ],
[-0.34606119],
[-0.29637877],
[-0.3693284 ],
[-0.36651142],
[-0.30025118],
[-0.31443603],
[-0.40824064],
[-0.31734053],
[-0.40807378],
[-0.33792031],
[-0.22414921],
[-0.37707072],
[-0.26776417],
[-0.29152204],
[-0.34066934],
[-0.19037511],
[-0.23552614],
[-0.2144995 ],
[-0.27628531],
[-0.27329725],
[-0.23910513],
[-0.30009859],
[-0.30192088],
[-0.16403744],
[-0.32546893],
[-0.25686912],
[-0.12515146],
[-0.21483097],
[-0.12779443],
[-0.28748063],
[-0.23782354],
[-0.16024807],
[-0.19062672],
[-0.15066097],
[-0.19043274],
[-0.16583211],
[-0.11201314],
[-0.05612149],
[-0.00847256],
[-0.1429705 ],
[-0.09595988],
[-0.09583441],
[-0.01372838],
[-0.04818834],
[-0.11840653],
[ 0.02184166],
[-0.07153294],
[-0.11556547],
[-0.04731049],
[-0.10774914],
[-0.014642 ],
[-0.01470962],
[-0.03259555],
[-0.04194347],
[ 0.08987345],
[-0.02027899],
[ 0.02418433],
[ 0.04298611],
[ 0.04130101],
[ 0.18010436],
[ 0.15480307],
[ 0.02719993],
[ 0.11508363],
[ 0.04309794],
[ 0.14060578],
[ 0.09377926],
[ 0.13887198],
[ 0.16148276],
[ 0.11398259],
[ 0.27887578],
[ 0.22775177],
[ 0.20749998],
[ 0.22107721],
[ 0.20854961],
[ 0.25411644],
[ 0.26561906],
[ 0.27540788],
[ 0.26946028],
[ 0.2390275 ],
[ 0.26051795],
[ 0.34424064],
[ 0.3240088 ],
[ 0.38040554],
[ 0.35717078],
[ 0.31357911],
[ 0.43825368],
[ 0.35709739],
[ 0.48101049],
[ 0.36024364],
[ 0.43253108],
[ 0.39268334],
[ 0.41942572],
[ 0.41196584],
[ 0.54435941],
[ 0.49840622],
[ 0.51627957]])

4. Define the base model

# 定义placeholder用来输入数据到神经网络,其中1表只有一个特征,也就是维度为一维数据
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# 代价函数,reduce_mean为求均值,reduce_sum为求和,reduction_indices为数据处理的维度
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))

# 将代价函数传到梯度下降,学习速率为0.1,这里包含权重的训练,会更新权重
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

5. Variable initialization

# important step
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
# 变量初始化
if int((tf.__version__).split('.')[1]) < 12:
    init = tf.initialize_all_variables()
else:
    init = tf.global_variables_initializer()
sess = tf.Session()  # 打开TensorFlow
sess.run(init)  # 执行变量初始化

6. Start training

for i in range(1000):  # 梯度下降迭代一千次
    # training
    sess.run(train_step, feed_dict={
    
    xs: x_data, ys: y_data})
    # 执行梯度下降算法,并且将样本喂给损失函数
    if i % 50 == 0:
        # 每50次迭代输出代价函数的值
        print(sess.run(loss, feed_dict={
    
    xs: x_data, ys: y_data}))
0.18214862
0.010138167
0.0071248626
0.0069830194
0.0068635535
0.0067452225
0.006626569
0.0065121166
0.0064035906
0.006295418
0.0061897114
0.0060903295
0.005990808
0.0058959606
0.0058057955
0.0057200184
0.005637601
0.0055605737
0.0054863705
0.005413457

Guess you like

Origin blog.csdn.net/m0_68111267/article/details/132182341