[Deep Learning] Experiment 05: Building a Neural Network Example

Building a neural network

Note: this code trains a neural network that fits y = x^2 - 0.5 + noise. The network structure is: an input layer with one neuron, a hidden layer with ten neurons, and an output layer with one neuron.

1. Import the required libraries

# Import the required libraries
import tensorflow as tf  # used to build the neural network
import numpy as np  # used to build and process the data
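
Note that this code targets the TensorFlow 1.x API (tf.placeholder and tf.Session are used later). If only TensorFlow 2.x is installed, a commonly used workaround is to import the compat.v1 module instead; the lines below are a minimal sketch under that assumption, not part of the original post:

import tensorflow.compat.v1 as tf  # assumption: TensorFlow 2.x environment
tf.disable_eager_execution()  # restore the graph/session behavior used in this tutorial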

2. Define a layer

# Define a layer
def add_layer(inputs, in_size, out_size, activation_function=None):
    # inputs is the layer input, in_size is the number of neurons in the
    # previous layer, out_size is the number of neurons in this layer,
    # and activation_function is the activation function
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    # random initial weights work well; [in_size, out_size] is the weight matrix shape
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    # biases, initialized slightly above zero
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    # tf.matmul performs the matrix multiplication of inputs and Weights
    if activation_function is None:
        outputs = Wx_plus_b  # no activation function: keep the linear output
    else:
        outputs = activation_function(Wx_plus_b)
        # otherwise apply the activation function
    return outputs  # return the layer output
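
As a quick sanity check of the shapes involved (an illustrative aside, not part of the original code): the layer computes inputs * Weights + biases, so a batch of N one-feature inputs passing through the ten-neuron hidden layer yields an [N, 10] matrix. A small numpy sketch:

import numpy as np

X = np.random.randn(5, 1)    # N = 5 samples, in_size = 1
W = np.random.randn(1, 10)   # Weights shape: [in_size, out_size]
b = np.zeros((1, 10)) + 0.1  # biases broadcast over the batch
print((X.dot(W) + b).shape)  # (5, 10)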

3. Build the dataset

# Build some samples to train the neural network
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
# 300 values evenly spaced between -1 and 1, reshaped to a column vector
noise = np.random.normal(0, 0.05, x_data.shape)
x_data

array([[-1.        ],
       [-0.99331104],
       [-0.98662207],
       ...,
       [ 0.98662207],
       [ 0.99331104],
       [ 1.        ]])

(output truncated; the full array has 300 rows)

# Adding noise makes the data closer to a real-world case; the noise is drawn from
# a normal distribution with mean 0 and standard deviation 0.05, with the same shape as x_data
y_data = np.square(x_data) - 0.5 + noise
# inspect y_data
y_data

array([[ 0.59535036],
       [ 0.46017998],
       [ 0.47144478],
       ...,
       [ 0.54435941],
       [ 0.49840622],
       [ 0.51627957]])

(output truncated; the full array has 300 rows)

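It can help to look at the samples before training. A minimal visualization sketch, assuming matplotlib is installed (it is not used in the original post):

import matplotlib.pyplot as plt

plt.scatter(x_data, y_data, s=5)
plt.title("Training samples: y = x^2 - 0.5 + noise")
plt.xlabel("x_data")
plt.ylabel("y_data")
plt.show()
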
4. Define the basic model

# Define placeholders used to feed data into the network; the 1 means each sample has a single feature
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# Loss function: reduce_sum sums the squared error within each sample
# (reduction_indices gives the axis), and reduce_mean averages over the batch
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))

# Minimize the loss with gradient descent at a learning rate of 0.1;
# this training op updates the weights and biases
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
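
To make the reduction order concrete, here is a tiny numpy sketch with made-up error values (an illustration only, not part of the original code): the sum over axis 1 is each sample's squared error, and the mean averages over the batch.

import numpy as np

err = np.array([[0.1], [0.3], [-0.2]])       # pretend values of ys - prediction
per_sample = np.sum(np.square(err), axis=1)  # what reduction_indices=[1] computes
print(np.mean(per_sample))                   # 0.04666..., the mean squared error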

5. Initialize the variables

# important step
# tf.initialize_all_variables() is no longer valid as of
# 2017-03-02 if using tensorflow >= 0.12
# Variable initialization
if int((tf.__version__).split('.')[1]) < 12:
    init = tf.initialize_all_variables()
else:
    init = tf.global_variables_initializer()
sess = tf.Session()  # open a TensorFlow session
sess.run(init)  # run the variable initialization

6. Start training

for i in range(1000):  # run 1000 gradient-descent iterations
    # training
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
    # run one gradient-descent step, feeding the samples into the graph
    if i % 50 == 0:
        # print the loss every 50 iterations
        print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
0.18214862
0.010138167
0.0071248626
0.0069830194
0.0068635535
0.0067452225
0.006626569
0.0065121166
0.0064035906
0.006295418
0.0061897114
0.0060903295
0.005990808
0.0058959606
0.0058057955
0.0057200184
0.005637601
0.0055605737
0.0054863705
0.005413457
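
As an optional follow-up (not in the original post), the fitted curve can be plotted against the samples once training finishes. This sketch assumes matplotlib is available and reuses sess, prediction, xs, x_data, and y_data from the code above:

import matplotlib.pyplot as plt

prediction_value = sess.run(prediction, feed_dict={xs: x_data})
plt.scatter(x_data, y_data, s=5, label="training samples")
plt.plot(x_data, prediction_value, "r-", lw=2, label="network fit")
plt.legend()
plt.show()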

Source: blog.csdn.net/m0_68111267/article/details/132182341