import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #让TensorFlow少打印出一些信息
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
# Use a Chinese-capable font so the Chinese plot labels render, and disable the
# Unicode minus glyph so negative tick labels still display with that font.
plt.rcParams['font.size'] = 16
plt.rcParams['font.family'] = ['STKaiti']
plt.rcParams['axes.unicode_minus'] = False

# Load MNIST. (x, y) is the 60k-image training split; the 10k-image test split
# is not used here, so it is discarded into "_".
(x, y), _ = datasets.mnist.load_data()
# Convert pixels to a float32 tensor scaled into [0, 1].
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255
# Convert labels to an int32 tensor of class indices (0-9).
y = tf.convert_to_tensor(y, dtype=tf.int32)
# One-hot encode the labels into 10-wide vectors.
y = tf.one_hot(y, depth=10)
print('datasets:', x.shape, y.shape)  # x: 60k images of 28*28; y: 60k one-hot rows of length 10

# Build the dataset object; each element is one (image, one-hot label) pair.
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
print(train_dataset)
# Shuffle before batching so each epoch visits the samples in a fresh random
# order (tf.data reshuffles on every iteration by default). A buffer as large
# as the dataset gives a uniform shuffle. Without this, SGD would see the
# identical sequence of 200-sample batches in every epoch.
train_dataset = train_dataset.shuffle(buffer_size=x.shape[0]).batch(200)
print(train_dataset)

# Sequential container wrapping 3 dense layers; each layer's output is the
# next layer's input. The last layer emits 10 raw scores (no activation).
model = tf.keras.Sequential([
    layers.Dense(512, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(10),
])
optimizer = optimizers.SGD(learning_rate=0.001)
def train_epoch(epoch):
    """Run one pass of minibatch gradient descent over ``train_dataset``.

    Uses the module-level ``model``, ``optimizer`` and ``train_dataset``.
    Progress is printed every 50 steps.

    Args:
        epoch: index of the current epoch; used only in the progress output.

    Returns:
        The mean of the per-batch losses over the whole epoch, as a float.

    Raises:
        ValueError: if ``train_dataset`` yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    for step, (xb, yb) in enumerate(train_dataset):
        # Flatten each image: [b, 28, 28] => [b, 784]. The input needs no
        # gradient, so this can happen outside the tape.
        xb = tf.reshape(xb, (-1, 784))
        # Record the forward pass so gradients can be taken afterwards.
        with tf.GradientTape() as tape:
            # Forward pass: [b, 784] => [b, 10]
            out = model(xb)
            # Sum of squared errors over the 10 outputs, averaged over the batch.
            loss = tf.reduce_sum(tf.square(yb - out)) / xb.shape[0]
        # Gradients of the loss w.r.t. every trainable weight and bias.
        grads = tape.gradient(loss, model.trainable_variables)
        # Parameter update: w' = w - lr * grad
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        total_loss += float(loss.numpy())
        num_batches += 1
        if step % 50 == 0:
            print("epoch:{0} step: {1} loss:{2}".format(epoch, step, loss.numpy()))

    if num_batches == 0:
        raise ValueError("train_dataset yielded no batches")
    # Average over the epoch rather than reporting only the last (noisy) batch.
    return total_loss / num_batches
def train(epochs=50):
    """Train the model for ``epochs`` epochs, then plot the per-epoch loss.

    Args:
        epochs: number of passes over the training data (default 50).

    Returns:
        List holding the value ``train_epoch`` returned for each epoch.
    """
    losses = []
    for epoch in range(epochs):
        losses.append(train_epoch(epoch))
    # One point per epoch; the x-axis is the epoch index.
    plt.plot(range(epochs), losses, color='blue', label="训练误差")
    plt.xlabel('EPOCH')
    plt.ylabel('MSE(均方误差)')
    plt.legend()
    plt.show()
    return losses
if __name__ == '__main__':
    # Entry point: start training only when run as a script, not on import.
    train()
输出:
epoch:0 step: 0 loss:1.7199777364730835 epoch:0 step: 50 loss:0.9313701391220093 epoch:0 step: 100 loss:0.9785110354423523 epoch:0 step: 150 loss:0.8963364958763123 epoch:0 step: 200 loss:0.748157799243927 epoch:0 step: 250 loss:0.7246544361114502 epoch:1 step: 0 loss:0.6489404439926147 epoch:1 step: 50 loss:0.6167117357254028 epoch:1 step: 100 loss:0.7173779010772705 epoch:1 step: 150 loss:0.7042184472084045 epoch:1 step: 200 loss:0.5774279832839966 epoch:1 step: 250 loss:0.5801646709442139 epoch:2 step: 0 loss:0.5330718159675598 epoch:2 step: 50 loss:0.5203613042831421 epoch:2 step: 100 loss:0.6196319460868835 epoch:2 step: 150 loss:0.6243830919265747 epoch:2 step: 200 loss:0.5045279860496521 epoch:2 step: 250 loss:0.5140419006347656 epoch:3 step: 0 loss:0.4760633707046509 epoch:3 step: 50 loss:0.4684053659439087 epoch:3 step: 100 loss:0.5630412101745605 epoch:3 step: 150 loss:0.5761500597000122 epoch:3 step: 200 loss:0.46055930852890015 epoch:3 step: 250 loss:0.47381022572517395 epoch:4 step: 0 loss:0.43960344791412354 epoch:4 step: 50 loss:0.43345627188682556 epoch:4 step: 100 loss:0.5239875316619873 epoch:4 step: 150 loss:0.5420948266983032 epoch:4 step: 200 loss:0.42897552251815796 epoch:4 step: 250 loss:0.4450124502182007 epoch:5 step: 0 loss:0.4126030206680298 epoch:5 step: 50 loss:0.4073203206062317 epoch:5 step: 100 loss:0.4947015643119812 epoch:5 step: 150 loss:0.5155577659606934 epoch:5 step: 200 loss:0.4044381082057953 epoch:5 step: 250 loss:0.42251351475715637 epoch:6 step: 0 loss:0.3913921117782593 epoch:6 step: 50 loss:0.3864135444164276 epoch:6 step: 100 loss:0.4714548885822296 epoch:6 step: 150 loss:0.4936997592449188 epoch:6 step: 200 loss:0.38460472226142883 epoch:6 step: 250 loss:0.40422606468200684 epoch:7 step: 0 loss:0.374055951833725 epoch:7 step: 50 loss:0.3689883351325989 epoch:7 step: 100 loss:0.45241469144821167 epoch:7 step: 150 loss:0.47510436177253723 epoch:7 step: 200 loss:0.36825084686279297 epoch:7 step: 250 
loss:0.38895943760871887 epoch:8 step: 0 loss:0.3592884838581085 epoch:8 step: 50 loss:0.3542090952396393 epoch:8 step: 100 loss:0.43616557121276855 epoch:8 step: 150 loss:0.4589134156703949 epoch:8 step: 200 loss:0.35442519187927246 epoch:8 step: 250 loss:0.3759245276451111 epoch:9 step: 0 loss:0.3465850055217743 epoch:9 step: 50 loss:0.34148329496383667 epoch:9 step: 100 loss:0.4219712018966675 epoch:9 step: 150 loss:0.4446076452732086 epoch:9 step: 200 loss:0.34259235858917236 epoch:9 step: 250 loss:0.3645056188106537 epoch:10 step: 0 loss:0.3353864252567291 epoch:10 step: 50 loss:0.33033058047294617 epoch:10 step: 100 loss:0.40937140583992004 epoch:10 step: 150 loss:0.4317995309829712 epoch:10 step: 200 loss:0.33232390880584717 epoch:10 step: 250 loss:0.35430702567100525 epoch:11 step: 0 loss:0.3254787027835846 epoch:11 step: 50 loss:0.3204648494720459 epoch:11 step: 100 loss:0.39805054664611816 epoch:11 step: 150 loss:0.42023971676826477 epoch:11 step: 200 loss:0.32330557703971863 epoch:11 step: 250 loss:0.34521380066871643 epoch:12 step: 0 loss:0.31646206974983215 epoch:12 step: 50 loss:0.31164249777793884 epoch:12 step: 100 loss:0.38795238733291626 epoch:12 step: 150 loss:0.4097321331501007 epoch:12 step: 200 loss:0.3151874244213104 epoch:12 step: 250 loss:0.3370356857776642 epoch:13 step: 0 loss:0.3082444369792938 epoch:13 step: 50 loss:0.30371588468551636 epoch:13 step: 100 loss:0.37882816791534424 epoch:13 step: 150 loss:0.4001832604408264 epoch:13 step: 200 loss:0.3078741431236267 epoch:13 step: 250 loss:0.329569011926651 epoch:14 step: 0 loss:0.3007279634475708 epoch:14 step: 50 loss:0.2965642511844635 epoch:14 step: 100 loss:0.37047308683395386 epoch:14 step: 150 loss:0.3914809823036194 epoch:14 step: 200 loss:0.3012174367904663 epoch:14 step: 250 loss:0.32272788882255554 epoch:15 step: 0 loss:0.293812096118927 epoch:15 step: 50 loss:0.29001760482788086 epoch:15 step: 100 loss:0.36286717653274536 epoch:15 step: 150 loss:0.38342535495758057 epoch:15 
step: 200 loss:0.29502418637275696 epoch:15 step: 250 loss:0.31643640995025635 epoch:16 step: 0 loss:0.28750962018966675 epoch:16 step: 50 loss:0.28403240442276 epoch:16 step: 100 loss:0.35584595799446106 epoch:16 step: 150 loss:0.37602511048316956 epoch:16 step: 200 loss:0.2893253564834595 epoch:16 step: 250 loss:0.3106449246406555 epoch:17 step: 0 loss:0.2817070782184601 epoch:17 step: 50 loss:0.27858632802963257 epoch:17 step: 100 loss:0.3493804633617401 epoch:17 step: 150 loss:0.36918506026268005 epoch:17 step: 200 loss:0.2840654253959656 epoch:17 step: 250 loss:0.305281400680542 epoch:18 step: 0 loss:0.2763095200061798 epoch:18 step: 50 loss:0.2735699713230133 epoch:18 step: 100 loss:0.34329670667648315 epoch:18 step: 150 loss:0.3628040552139282 epoch:18 step: 200 loss:0.2791421413421631 epoch:18 step: 250 loss:0.3002479076385498 epoch:19 step: 0 loss:0.27124762535095215 epoch:19 step: 50 loss:0.26893100142478943 epoch:19 step: 100 loss:0.33756858110427856 epoch:19 step: 150 loss:0.35681116580963135 epoch:19 step: 200 loss:0.2745401859283447 epoch:19 step: 250 loss:0.2955167889595032 epoch:20 step: 0 loss:0.26650431752204895 epoch:20 step: 50 loss:0.26464757323265076 epoch:20 step: 100 loss:0.33222049474716187 epoch:20 step: 150 loss:0.3511451780796051 epoch:20 step: 200 loss:0.27024975419044495 epoch:20 step: 250 loss:0.2911587953567505 epoch:21 step: 0 loss:0.26208287477493286 epoch:21 step: 50 loss:0.2606638967990875 epoch:21 step: 100 loss:0.32723042368888855 epoch:21 step: 150 loss:0.3457576632499695 epoch:21 step: 200 loss:0.26619842648506165 epoch:21 step: 250 loss:0.28707775473594666 epoch:22 step: 0 loss:0.2578853368759155 epoch:22 step: 50 loss:0.25689783692359924 epoch:22 step: 100 loss:0.32248175144195557 epoch:22 step: 150 loss:0.34066954255104065 epoch:22 step: 200 loss:0.26236942410469055 epoch:22 step: 250 loss:0.28327441215515137 epoch:23 step: 0 loss:0.25392603874206543 epoch:23 step: 50 loss:0.25330886244773865 epoch:23 step: 100 
loss:0.31797733902931213 epoch:23 step: 150 loss:0.33589282631874084 epoch:23 step: 200 loss:0.2587423622608185 epoch:23 step: 250 loss:0.2796774208545685 epoch:24 step: 0 loss:0.2501923441886902 epoch:24 step: 50 loss:0.24991759657859802 epoch:24 step: 100 loss:0.3137132525444031 epoch:24 step: 150 loss:0.3313567042350769 epoch:24 step: 200 loss:0.2552911341190338 epoch:24 step: 250 loss:0.27629896998405457 epoch:25 step: 0 loss:0.24663478136062622 epoch:25 step: 50 loss:0.24671924114227295 epoch:25 step: 100 loss:0.3096030354499817 epoch:25 step: 150 loss:0.32707393169403076 epoch:25 step: 200 loss:0.25201329588890076 epoch:25 step: 250 loss:0.27306386828422546 epoch:26 step: 0 loss:0.24330151081085205 epoch:26 step: 50 loss:0.2436951994895935 epoch:26 step: 100 loss:0.3057035803794861 epoch:26 step: 150 loss:0.3230525553226471 epoch:26 step: 200 loss:0.2489604949951172 epoch:26 step: 250 loss:0.26997336745262146 epoch:27 step: 0 loss:0.24017933011054993 epoch:27 step: 50 loss:0.24086245894432068 epoch:27 step: 100 loss:0.3019971549510956 epoch:27 step: 150 loss:0.3192491829395294 epoch:27 step: 200 loss:0.24607966840267181 epoch:27 step: 250 loss:0.2670409381389618 epoch:28 step: 0 loss:0.23718757927417755 epoch:28 step: 50 loss:0.23816055059432983 epoch:28 step: 100 loss:0.29846417903900146 epoch:28 step: 150 loss:0.3156398832798004 epoch:28 step: 200 loss:0.2433423101902008 epoch:28 step: 250 loss:0.2642318904399872 epoch:29 step: 0 loss:0.23433589935302734 epoch:29 step: 50 loss:0.23557190597057343 epoch:29 step: 100 loss:0.2951340973377228 epoch:29 step: 150 loss:0.31218165159225464 epoch:29 step: 200 loss:0.24074004590511322 epoch:29 step: 250 loss:0.26153624057769775 epoch:30 step: 0 loss:0.2316044420003891 epoch:30 step: 50 loss:0.2330857217311859 epoch:30 step: 100 loss:0.2919342815876007 epoch:30 step: 150 loss:0.30888286232948303 epoch:30 step: 200 loss:0.23822054266929626 epoch:30 step: 250 loss:0.2589412033557892 epoch:31 step: 0 
loss:0.22902503609657288 epoch:31 step: 50 loss:0.2306947112083435 epoch:31 step: 100 loss:0.2888508141040802 epoch:31 step: 150 loss:0.3056912124156952 epoch:31 step: 200 loss:0.23580516874790192 epoch:31 step: 250 loss:0.2564700245857239 epoch:32 step: 0 loss:0.22656036913394928 epoch:32 step: 50 loss:0.22840747237205505 epoch:32 step: 100 loss:0.2858718931674957 epoch:32 step: 150 loss:0.30260971188545227 epoch:32 step: 200 loss:0.2334989607334137 epoch:32 step: 250 loss:0.2540746033191681 epoch:33 step: 0 loss:0.2242344319820404 epoch:33 step: 50 loss:0.226209357380867 epoch:33 step: 100 loss:0.2829955220222473 epoch:33 step: 150 loss:0.2996683418750763 epoch:33 step: 200 loss:0.23128774762153625 epoch:33 step: 250 loss:0.25178077816963196 epoch:34 step: 0 loss:0.2220241129398346 epoch:34 step: 50 loss:0.22408610582351685 epoch:34 step: 100 loss:0.2802210748195648 epoch:34 step: 150 loss:0.29684245586395264 epoch:34 step: 200 loss:0.2291480004787445 epoch:34 step: 250 loss:0.24959160387516022 epoch:35 step: 0 loss:0.21992824971675873 epoch:35 step: 50 loss:0.22202911972999573 epoch:35 step: 100 loss:0.2775717079639435 epoch:35 step: 150 loss:0.29409655928611755 epoch:35 step: 200 loss:0.22706937789916992 epoch:35 step: 250 loss:0.24749860167503357 epoch:36 step: 0 loss:0.217921182513237 epoch:36 step: 50 loss:0.22006523609161377 epoch:36 step: 100 loss:0.27507728338241577 epoch:36 step: 150 loss:0.291427880525589 epoch:36 step: 200 loss:0.22507411241531372 epoch:36 step: 250 loss:0.24548709392547607 epoch:37 step: 0 loss:0.21601375937461853 epoch:37 step: 50 loss:0.21818655729293823 epoch:37 step: 100 loss:0.2726643681526184 epoch:37 step: 150 loss:0.28883031010627747 epoch:37 step: 200 loss:0.2231394648551941 epoch:37 step: 250 loss:0.24356289207935333 epoch:38 step: 0 loss:0.2142055481672287 epoch:38 step: 50 loss:0.21635133028030396 epoch:38 step: 100 loss:0.27033886313438416 epoch:38 step: 150 loss:0.28634345531463623 epoch:38 step: 200 
loss:0.22129690647125244 epoch:38 step: 250 loss:0.24170459806919098 epoch:39 step: 0 loss:0.21246902644634247 epoch:39 step: 50 loss:0.2145748883485794 epoch:39 step: 100 loss:0.26811257004737854 epoch:39 step: 150 loss:0.28393447399139404 epoch:39 step: 200 loss:0.21950072050094604 epoch:39 step: 250 loss:0.23988944292068481 epoch:40 step: 0 loss:0.2107982635498047 epoch:40 step: 50 loss:0.21286651492118835 epoch:40 step: 100 loss:0.26596397161483765 epoch:40 step: 150 loss:0.2815980017185211 epoch:40 step: 200 loss:0.21776540577411652 epoch:40 step: 250 loss:0.23814569413661957 epoch:41 step: 0 loss:0.20919860899448395 epoch:41 step: 50 loss:0.2112150937318802 epoch:41 step: 100 loss:0.2638876438140869 epoch:41 step: 150 loss:0.2793533205986023 epoch:41 step: 200 loss:0.21608133614063263 epoch:41 step: 250 loss:0.2364545464515686 epoch:42 step: 0 loss:0.20766673982143402 epoch:42 step: 50 loss:0.20960120856761932 epoch:42 step: 100 loss:0.2618783116340637 epoch:42 step: 150 loss:0.277192622423172 epoch:42 step: 200 loss:0.2144385129213333 epoch:42 step: 250 loss:0.23480544984340668 epoch:43 step: 0 loss:0.2062055617570877 epoch:43 step: 50 loss:0.20803329348564148 epoch:43 step: 100 loss:0.2599318325519562 epoch:43 step: 150 loss:0.275102436542511 epoch:43 step: 200 loss:0.21282997727394104 epoch:43 step: 250 loss:0.23318250477313995 epoch:44 step: 0 loss:0.20478737354278564 epoch:44 step: 50 loss:0.2065192013978958 epoch:44 step: 100 loss:0.25803014636039734 epoch:44 step: 150 loss:0.2730681598186493 epoch:44 step: 200 loss:0.21127662062644958 epoch:44 step: 250 loss:0.2316143959760666 epoch:45 step: 0 loss:0.20340576767921448 epoch:45 step: 50 loss:0.20504623651504517 epoch:45 step: 100 loss:0.25620490312576294 epoch:45 step: 150 loss:0.27108025550842285 epoch:45 step: 200 loss:0.2097698450088501 epoch:45 step: 250 loss:0.2300863116979599 epoch:46 step: 0 loss:0.2020624577999115 epoch:46 step: 50 loss:0.20360326766967773 epoch:46 step: 100 
loss:0.25442180037498474 epoch:46 step: 150 loss:0.26915132999420166 epoch:46 step: 200 loss:0.20831821858882904 epoch:46 step: 250 loss:0.22861045598983765 epoch:47 step: 0 loss:0.2007581740617752 epoch:47 step: 50 loss:0.20221735537052155 epoch:47 step: 100 loss:0.2526772618293762 epoch:47 step: 150 loss:0.26727959513664246 epoch:47 step: 200 loss:0.20690889656543732 epoch:47 step: 250 loss:0.22717107832431793 epoch:48 step: 0 loss:0.19950994849205017 epoch:48 step: 50 loss:0.20086681842803955 epoch:48 step: 100 loss:0.25099068880081177 epoch:48 step: 150 loss:0.26547253131866455 epoch:48 step: 200 loss:0.20553314685821533 epoch:48 step: 250 loss:0.2257540076971054 epoch:49 step: 0 loss:0.19829510152339935 epoch:49 step: 50 loss:0.19956813752651215 epoch:49 step: 100 loss:0.24935641884803772 epoch:49 step: 150 loss:0.26369330286979675 epoch:49 step: 200 loss:0.2042004019021988 epoch:49 step: 250 loss:0.22436237335205078