【ディープラーニング】実験05 ニューラルネットワークの構築例

ニューラルネットワークを構築する

注: このコードはニューラル ネットワークのトレーニングに使用されます。ネットワークは y = x^2-0.5+noise に適合します。ニューラル ネットワークの構造は、入力層が 1 つのニューロン、隠れ層が 10 個のニューロン、出力がレイヤーは 1 つのニューロンです

1. 関連ライブラリをインポートする

# 导入相关库
import tensorflow as tf  # 用来构造神经网络
import numpy as np  # 用来构造数据结构和处理数据模块

2. レイヤーを定義する

# 定义一个层
def add_layer(inputs, in_size, out_size, activation_function=None):
    # 定义一个层,其中inputs为输入,in_size为上一层神经元数,out_size为该层神经元数
    # activation_function为激励函数
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    # 初始权重随机生成比较好,in_size,out_size为该权重维度
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    # 偏置
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    # matmul为矩阵里的函数相乘
    if activation_function is None:
        outputs = Wx_plus_b  # 如果激活函数为空,则不激活,保持数据
    else:
        outputs = activation_function(Wx_plus_b)
        # 如果激活函数不为空,则激活,并且返回激活后的值
    return outputs  # 返回激活后的值

3. データセットを構築する

# 构造一些样本,用来训练神经网络
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
# 值为(-1,1)之间的数,有300个
noise = np.random.normal(0, 0.05, x_data.shape)
x_data

array([[-1. ],
[-0.99331104],
[-0.98662207],
[-0.97993311], [
-0.97324415], [-0.96655518]
, [
-0.95986622
], [-0.95317726],
[-0.94648829],
[ -0.93979933]、
[-0.93311037]、
[-0.9264214 ]、
[-0.91973244]、[
-0.91304348]、[ - 0.90635452
] 、[-0.89966555]、[- 0.89297659]、[- 0.8862876] 3]、 [-0.87959866]、[ -0.8729097 ]、[-0.86622074]、[-0.85953177]、[-0.85284281]、[-0.84615385]、[ -0.83946488 ] 、[-0.83277592]、[- 0.82608696]、[-0.8193979] 9]、[-0.81270903]、[ -0.80602007]、















[-0.7993311 ]、
[-0.79264214]、
[-0.78595318]、
[-0.77926421]、
[-0.77257525]、 [-0.76588629
] 、[-
0.75919732 ]、[-
0.75250836]、
[-0.745819] 4 ]、
[-0.73913043]、
[-0.73244147]、
[-0.72575251]、
[-0.71906355]、
[-0.71237458]、
[-0.70568562]、
[
-
0.69899666]、[-0.69230769]、[-0.68561873]、
[-0.678929] 77]、
[-0.6722408]、
[-0.66555184]、
[-0.65886288]、
[-0.65217391]、
[-0.64548495]、
[-0.63879599]、[ -
0.63210702] 、[-0.62541806]、[-0.6187291 ]、[-0.612040] 13]、[-0.60535117]、




[-0.59866221]、
[-0.59197324]、
[-0.58528428]、
[-0.57859532]、
[-0.57190635] 、 [-0.56521739
]、[-0.55852843
]
、[-0.55183946]、 [-0.545150] 5 ]、 [-0.53846154]、[-0.53177258]、[-0.52508361]、[-0.51839465]、[-0.51170569]、[-0.50501672]、[ -0.49832776]、[ -0.4916388 ]、[-0.48494983]、[-0.478260] 87]、[-0.47157191]、[-0.46488294]、[-0.45819398]、[-0.45150502]、[-0.44481605]、[-0.43812709]、 [- 0.43143813] 、[ -0.42474916 ]、[ -0.4180602 ]、[-0.411371] 24]、[-0.40468227]、






















[-0.39799331]、
[-0.39130435]、
[-0.38461538]、
[-0.37792642]、
[-0.37123746]
[-0.36454849]、[-0.35785953]、[-0.35117057]、[-0.344481] 61]、
[
- 0.33779264]、[-0.33110368]、[-0.32441472]、[-0.31772575]、[-0.31103679]、[-0.30434783]、[-0.29765886] [- 0.2909699 ]、[ -0.28428094]、[-0.277591 97]、[-0.27090301]、[-0.26421405]、[-0.25752508]、[-0.25083612]、[-0.24414716]、[-0.23745819]、 [- 0.23076923] 、[- 0.22408027 ]、[-0.2173913 ]、[-0.210702] 34]、[-0.20401338]、






















[-0.19732441]、
[-0.19063545]、
[-0.18394649]、
[-0.17725753]、
[-0.17056856]
、 [-0.1638796 ]、[-
0.15719064 ]、[
-0.15050167]、
[-0.143812] 71]、
[-0.13712375]、
[-0.13043478]、
[-0.12374582]、
[-0.11705686]、
[-0.11036789]、
[-0.10367893]
[ - 0.09698997]、[-0.090301 ]、[ -0.08361204]、 [-0.0769230] 8]、 [-0.07023411]、[-0.06354515]、[-0.05685619]、[-0.05016722]、[-0.04347826]、[-0.0367893 ] 、 [-0.03010033]、[- 0.02341137 ]、[- 0.01672241]、 [-0.010033 ] 44]、[-0.00334448]、














[ 0.00334448]、
[ 0.01003344]、
[ 0.01672241]、
[ 0.02341137]、[ 0.03010033]、
[ 0.0367893 ]、[
0.04347826]

[ 0.05016722 ]、[ 0.056856 19]、 [ 0.06354515]、[ 0.07023411]
[ 0.07692308]、[ 0.08361204 ]、[ 0.090301 ]、[ 0.09698997]、[ 0.10367893]、[ 0.11036789]、[ 0.11705686]、[ 0.12374582]、[ 0.13043478]、[ 0.13712375] [ 0.1438 1271]、[ 0.15050167 ]、[0.15719064]、[0.1638796]、[ 0.17056856]、[ 0.17725753]、[ 0.18394649]、[ 0.19063545]、[ 0.19732441]、





















[ 0.20401338]、
[ 0.21070234]、
[ 0.2173913 ]、
[ 0.22408027]、[
0.23076923]、[ 0.23745819]、[
0.24414716]
、[ 0.25083612
] 、 [ 0.257525 08]、[ 0.26421405]、[ 0.27090301]
[ 0.27759197]、[ 0.28428094 ]、[ 0.2909699 ]、[ 0.29765886]、[ 0.30434783]、[ 0.31103679]、[ 0.31772575]、[ 0.32441472]、[ 0.33110368] 、[ 0.33779264 ]、 [ 0.344 48161]、[ 0.35117057]、 [ 0.35785953 ]、[0.36454849]、[ 0.37123746]、[ 0.37792642]、[ 0.38461538]、[ 0.39130435]、[ 0.39799331]、





















[ 0.40468227]、
[ 0.41137124]、
[ 0.4180602 ]、
[ 0.42474916]、[
0.43143813]、[ 0.43812709]、[
0.44481605]
、[
0.45150502 ]、[ 0.458193 98]、[ 0.46488294]、[ 0.47157191
] [ 0.47826087]、[ 0.48494983 ]、[ 0.4916388 ]、[ 0.49832776]、[ 0.50501672]、[ 0.51170569]、[ 0.51839465]、[ 0.52508361]、[ 0.53177258] [ 0.53846154]、 [ 0.545 1505 ]、[ 0.55183946 ]、[ 0.55852843]、[ 0.56521739]、[ 0.57190635]、[ 0.57859532]、[ 0.58528428]、[ 0.59197324]、[ 0.59866221]、





















[ 0.60535117]、
[ 0.61204013]、[
0.6187291 ]、
[ 0.62541806]、[ 0.63210702]、[
0.63879599]、[
0.64548495] 、[ 0.65217391]
、[ 0.658862 88]、[0.66555184]、[0.6722408]、[ 0.67892977 ]、[0.68561873] ]、[ 0.69230769]、[ 0.69899666]、[ 0.70568562]、[ 0.71237458]、[ 0.71906355]、[ 0.72575251]、[ 0.73244147] 、[ 0.73913043 ]、 [ 0.745 8194 ]、[ 0.75250836 ]、[ 0.75919732]、[ 0.76588629]、[ 0.77257525]、[ 0.77926421]、[ 0.78595318]、[ 0.79264214]、[ 0.7993311 ]、























[ 0.80602007]、
[ 0.81270903]、
[ 0.81939799]、
[ 0.82608696]、[ 0.83277592]、
[ 0.83946488]、[
0.84615385]
、[ 0.85284281
] 、 [ 0.859531 77]、[0.86622074]、
[0.8729097]、[ 0.87959866]、[0.88628763] ]、[ 0.89297659]、[ 0.89966555]、[ 0.90635452]、[ 0.91304348]、[ 0.91973244]、[ 0.9264214 ]、[ 0.93311037]、[ 0.93979933] [ 0.946 48829]、[ 0.95317726]、 [ 0.95986622 ]、[0.96655518]、[ 0.97324415]、[ 0.97993311]、[ 0.98662207]、[ 0.99331104]、[ 1. ]])




















# 加入噪声会更贴近真实情况,噪声的值为(0,0.05)之间,结构为x_data一样
y_data = np.square(x_data) - 0.5 + noise
# y的结构
y_data

配列([[ 0.59535036],
[ 0.46017998], [
0.47144478],
[ 0.45083795], [
0.58438217], [ 0.38570118], [ 0.43550029]
,
[
0.40597571], [ 0.335 7524 ]
、[ 0.35784864
]、
[ 0.34530231]、
[ 0.32509701] 、
[ 0.25554733]、
[ 0.32300801]、
[ 0.2299959 ]、
[ 0.35472568]、[ 0.31227671]、[
0.30385068]、[ 0.29413844]
、[ 0.18437787]、 [ 0.2813 2819]、[ 0.25605309]、 [ 0.23126361 ]、[ 0.23492797]、[ 0.18381621]、[ 0.10392937]、[ 0.13415913]、[ 0.14043649]、[ 0.11756826]、[ 0.12142749]、












[ 0.12400694]、
[ 0.08926307]、
[ 0.15581832]、
[ 0.16541106]、
[-0.02582895]、[ 0.05924725]、[- 0.04037454
]、[ 0.03799003]、[ 0.090308 32] 、[0.05984324]、 [- 0.06569464 ]、[0.07973773] 、[ 0.04297837]、[ 0.05169557]、[-0.00096191]、[-0.02049573]、[- 0.03125322]、[-0.04545588] [ - 0.02168901]、[ 0.01657517]、[-0.0431] 5181]、[-0.09123519]、[- 0.03292835]、[-0.1110189 ]、[-0.08212792]、[-0.10089535]、[- 0.17406672]、 [-0.10380731 ]、[-0.10774072]、[-0.21283138]、
























[-0.09788435]、
[-0.10196452]、
[-0.16439081]、
[-0.15431978]、
[-0.17778307] [-0.18428537]、[-0.17874028]、[-0.10490738]、[-0.250768] 32]、
[ - 0.16078044]、[-0.21572183]、[-0.15624353]、[-0.19591988]、[-0.31560742]、[-0.29593726] 、 [-0.26686787]、[ -0.2999804 ]、[ -0.30631065]、[-0.353052] 24]、[-0.31295125]、[-0.22996255]、[-0.22837061]、[-0.27266253]、[-0.31290802]、[-0.37188479]、 [- 0.20765034] 、[ -0.33860431] 、[-0.31135236]、[-0.252499] 81]、[-0.26041048]、
























[-0.31486205]、
[-0.30253306]、
[-0.41624795]、
[-0.40053837]、
[-0.29939676]、[
- 0.32615377
]、[-0.37377787] 、[
-0.32222027]、[-0.315883] 8 ] 、 [-0.43880087]、[-0.37510637]、[-0.46702321]、[-0.27058091]、[-0.52885151]、[-0.4061462 ]、[ - 0.4486374 ] 、[- 0.37819628 ]、[-0.34701947]、 [-0.324543] 64]、[-0.3901839 ]、[-0.43293107]、[-0.47881173]、[-0.45280819]、[-0.49676541]、[-0.48955669] 、[- 0.45898691] 、[ -0.37473462 ]、[-0.43801531]、[-0.447936] 55]、[-0.57343047]、






















[-0.45262969]、
[-0.40719677]、
[-0.45423461]、
[-0.45053051]、
[-0.51046881]、[-0.41584096]、[-0.53328545
]

[-0.44766406]、
[-0.501584] 63]、
[-0.42676031]、
[-0.50552613]、
[-0.36832989]、
[-0.48699296]、
[-0.41614151]、
[-0.6175621 ]、[ -
0.48304532] 、[-0.46115021]、[ -0.40948908]、[-0.420170] 24]、[-0.50411757]、[-0.44530626]、[-0.46895275]、[-0.52127771]、[-0.50064585]、[-0.42210169] [-0.58582837]、[-0.52049198]、[-0.45332091]、[-0.534658] 15]、[ - 0.5385712 ]、














[-0.5654201 ]、
[-0.54471377]、
[-0.48109194]、
[-0.44565732]、[-0.48112022
]、[-0.46471786]
、[ -
0.5452149
]、[-0.52115601]、
[-0.502349 28]、
[-0.54885558]、
[-0.5279981 ]、
[-0.53893795]、
[-0.44286416]、
[-0.45371406]、
[-0.44633111]、[ -0.57535678
] 、[-0.62918947]、[ -0.41877124]、[-0.562639] 56]、[-0.51201705]、[-0.35016007]、[-0.49188897]、[-0.55766056]、[-0.38963378]、[-0.5038024 ]、[-0.51949984] 、[-0.45229896] [-0.49193029]、[-0.534728] 83]、[-0.48957523]、














[-0.35561181]、
[-0.4622668 ]、
[-0.39177781]、
[-0.43448445]、 [- 0.49854629
]、 [-
0.49843105] 、[-0.47704375]、[ -0.36618194]、[-0.451770] 12]、[-0.41497222]、[-0.42152064]、[-0.48996608]、[-0.43010878]、[-0.42599962]、[-0.2841197 ] [ - 0.38992082]、[-0.43802592]、[-0.42448799]、[-0.295146] 76]、[-0.37154091]、[-0.25426219]、[-0.44610678]、[-0.37120566]、[-0.3531599 ]、[-0.34606119]、[-0.29637877]、[-0.3693284 ] 、 [ -0.36651142]、[-0.300251 18]、[-0.31443603]、
























[-0.40824064]、
[-0.31734053]、
[-0.40807378]、
[-0.33792031]、
[-0.22414921] 、[-0.37707072]、[
-0.26776417
]
、[-0.29152204]、
[-0.340669] 34]、
[-0.19037511]、
[-0.23552614]、
[-0.2144995 ]、
[-0.27628531]、
[-0.27329725]、
[-0.23910513]

[ -0.30009859]、[-0.30192088]、[-
0.16403744]、
[-0.325468] 93]、
[-0.25686912]、
[-0.12515146]、
[-0.21483097]、
[-0.12779443]、
[-0.28748063]、
[-0.23782354]
、 [-0.16024807]、[-
0.19062672 ]、[-
0.15066097]、
[-0.190432] 74]、
[-0.16583211]、
[-0.11201314]、
[-0.05612149]、
[-0.00847256]、
[-0.1429705 ]、
[-0.09595988]、 [-0.09583441]、[
-0.01372838 ]、[
-0.04818834
]、
[-0.118406] 53]、
[0.02184166]、
[ -0.07153294]、
[-0.11556547]、
[-0.04731049]、
[-0.10774914]、
[-0.014642 ]、[-0.01470962
] 、[-0.03259555]、[- 0.04194347]、[ 0.08987345 ]、
[ -0.02027899]、[0.02418433 ]、[ 0.04298611]、[ 0.04130101]、[ 0.18010436]、[ 0.15480307]、[ 0.02719993]、[ 0.11508363]、[ 0.04309794]、[ 0.14060578]、[ 0.093 77926]、













[ 0.13887198]、
[ 0.16148276]、
[ 0.11398259]、
[ 0.27887578]、[
0.22775177]、[ 0.20749998]、
[ 0.22107721
]
、[ 0.20854961]、 [ 0.254116 44]、[ 0.26561906]、
[ 0.27540788 ]、[ 0.26946028]、[ 0.2390275 ]、[ 0.26051795]、[ 0.34424064]、[ 0.3240088 ]、[ 0.38040554]、[ 0.35717078]、[ 0.31357911]、[ 0.43825368] [ 0.35709739]、 [ 0.481 01049 ]、[0.36024364]、 [ 0.43253108 ]、[0.39268334]、[ 0.41942572]、[ 0.41196584]、[ 0.54435941]、[ 0.49840622]、[ 0.51627957]])




















4. 基本モデルを定義する

# 定义placeholder用来输入数据到神经网络,其中1表只有一个特征,也就是维度为一维数据
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# 代价函数,reduce_mean为求均值,reduce_sum为求和,reduction_indices为数据处理的维度
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))

# 将代价函数传到梯度下降,学习速率为0.1,这里包含权重的训练,会更新权重
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

5. 変数の初期化

# important step
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
# 变量初始化
if int((tf.__version__).split('.')[1]) < 12:
    init = tf.initialize_all_variables()
else:
    init = tf.global_variables_initializer()
sess = tf.Session()  # 打开TensorFlow
sess.run(init)  # 执行变量初始化

6. トレーニングを開始する

for i in range(1000):  # 梯度下降迭代一千次
    # training
    sess.run(train_step, feed_dict={
    
    xs: x_data, ys: y_data})
    # 执行梯度下降算法,并且将样本喂给损失函数
    if i % 50 == 0:
        # 每50次迭代输出代价函数的值
        print(sess.run(loss, feed_dict={
    
    xs: x_data, ys: y_data}))
0.18214862
0.010138167
0.0071248626
0.0069830194
0.0068635535
0.0067452225
0.006626569
0.0065121166
0.0064035906
0.006295418
0.0061897114
0.0060903295
0.005990808
0.0058959606
0.0058057955
0.0057200184
0.005637601
0.0055605737
0.0054863705
0.005413457

おすすめ

転載: blog.csdn.net/m0_68111267/article/details/132182341