HaaS AI之手写数字识别快速实践,在VSCode中搭建TensorFlow 2.0简单神经网络

本文将介绍如何在VSCode里面搭建TensorFlow的开发环境,并跑一个简单的神经网络来进行手写数据的识别。

1、Conda环境安装

参考HaaS AI之VSCode中搭建Python虚拟环境

2、创建TensorFlow Python虚拟环境

conda维护到TensorFlow2.0版本,基于Python3.7版本,因此线创建一个TensorFlow的Python虚拟环境,命名为tf2。

conda create --name tf2 python=3.7

2.1、激活环境

(tf2)$conda activate tf2

2.2、安装TensorFlow2.0

(tf2)$conda install tensorflow

2.3、安装Matplotlib

matplotlib,风格类似 Matlab 的基于 Python 的图表绘图系统。

matplotlib 是 Python最著名的绘图库,它提供了一整套和 matlab 相似的命 API,十分适合交互式地进行制图。而且也可以方便地将它作为绘图控件,嵌入 GUI 应用程序中,在模型训练中常常用来绘制图形。

(tf2)$conda install matplotlib

3、TensorFlow之初体验

TensorFlow是Google开源的深度学习框架,是一个端到端平台,无论您是专家还是初学者,它都可以让您轻松地构建和部署机器学习模型。

3.1、简单手写数字识别网络

在VSCode中训练一个简单的手写数字识别网络模型:

1. 加载TensorFlow
In [1]:
#Mac OS KMP设置
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# 安装 TensorFlow
import tensorflow as tf

2. 载入并准备好 MNIST 数据集。将样本从整数转换为浮点数:
In [2]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

3. 将模型的各层堆叠起来,以搭建 tf.keras.Sequential 模型。为训练选择优化器和损失函数:
In [3]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

4. 训练并验证模型:
In [4]:
model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test,  y_test, verbose=2)

# 输出结果
Out[4]:
Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 9s 154us/sample - loss: 0.3008 - accuracy: 0.9120
Epoch 2/5
60000/60000 [==============================] - 9s 147us/sample - loss: 0.1444 - accuracy: 0.9579
Epoch 3/5
60000/60000 [==============================] - 10s 170us/sample - loss: 0.1073 - accuracy: 0.9676
Epoch 4/5
60000/60000 [==============================] - 10s 174us/sample - loss: 0.0890 - accuracy: 0.9726
Epoch 5/5
60000/60000 [==============================] - 11s 180us/sample - loss: 0.0765 - accuracy: 0.9764
10000/1 - 1s - loss: 0.0379 - accuracy: 0.9777
[0.0705649911917746, 0.9777]

3.2、模型保存

model.save('tf_mnist_simple_net.h5')

3.3、模型预测

3.3.1、显示待测图片

从测试集中选择索引号为image_index的图片进行测试。

扫描二维码关注公众号,回复: 12465721 查看本文章
5. 模型预测
# 定义plot_image函数,查看指定个数数据图像

import matplotlib.pyplot as plt #导入matplotlib.pyplot
def plot_image(image):                  #输入参数为image
    pic=plt.gcf()                       #获取当前图像
    pic.set_size_inches(2,2)            ##设置图片大
    
    plt.imshow(image, cmap='binary')    #使用plt.imshow显示图片
    plt.show()                          #设置图片大
    
# 测试集中图片索引 0~10000
In [1]:
image_index=23

# 显示待预测值
plot_image(x_test[image_index])

3.3.2、打印测试结果

pred = model.predict_classes(x_test)

#打印预测结果
print(pred)
print("测试数字结果:")
print(pred[image_index])

# 输出结果
Out [1]:
[7 2 1 ... 4 5 6]
测试数字结果:
5

为了节省训练时间,把eporch迭代次数改为1,创建一个Jupyter notebook执行1次迭代训练上述模型:

https://v.youku.com/v_show/id_XNTA5Mzk2NzU2NA==.html

注意:

在创建*.ipynb和*.py文件的名称不能是tensorflow.ipynb/tensorflow.py,否则会出现各种库找不到的情形。

3.3.3、测试代码

将以上代码合在同一个文件中(去掉输出结果部分)就可以进行测试了。

4、FQA

Q1: Mac OS上在执行模型训练时出现错误

OMP: Error #15: Initializing libiomp5.dylib, but found libiomp5.dylib already initialized.

OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.

Abort trap: 6

A1:

大概意思就是初始化libiomp5.dylib时发现已经初始化过了。

经过Google发现这似乎是一个Mac OS 才存在的特殊问题,在代码头部加入:

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

5、开发者技术支持

如需更多技术支持,可加入钉钉开发者群,或者关注微信公众号

更多技术与解决方案介绍,请访问阿里云AIoT首页https://iot.aliyun.com/

猜你喜欢

转载自blog.csdn.net/HaaSTech/article/details/113546062