DL实战1:tensorflow在mnist上实现siamese net

一、写在前面

参考:https://www.cnblogs.com/bentuwuying/p/8186364.html

简介:全连接孪生网络(siamese network)是一种相似性度量方法,适用于类别数目多但是每类的样本数少的分类问题。

代码:https://github.com/Shicoder/DeepLearning_Demo/tree/master/siamese_tf_mnist

二、测试代码

按照github代码运行python run.py后报错:

【报错1】

AttributeError: module ‘tensorflow‘ has no attribute ‘sub‘

AttributeError: module ‘tensorflow‘ has no attribute ‘mul‘

【解决1】

将代码inference.py中的tf.sub替换为tf.subtract,tf.mul替换为tf.multiply

----------------------------------------------------------------------------------------------------------------

【报错2】

NotFoundError (see above for traceback): Failed to create a directory: ; No such file or directory

         [[Node: save/SaveV2 = SaveV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/SaveV2/tensor_names, save/SaveV2/shape_and_slices, siamese/fc1W, siamese/fc1b, siamese/fc2W, siamese/fc2b, siamese/fc3W, siamese/fc3b)]]

【解决2】

将代码run.py中的saver.save(sess, 'model.ckpt')以及saver.restore(sess, 'model.ckpt')改为saver.save(sess, './model.ckpt')以及saver.restore(sess, './model.ckpt')

----------------------------------------------------------------------------------------------------------------

【报错3】

name 'raw_input' is not defined

【解决3】

将代码中的raw_input改为input即可。

----------------------------------------------------------------------------------------------------------------

【报错4】

读取model.ckpt失败

【解决4】

将代码中的os.path.isfile(model_ckpt)改为判断当前路径下是否存在checkpoint这个文件os.path.exists('E:\pyProject\DeepLearning_Demo-master\siamese_tf_mnist\checkpoint'): 

----------------------------------------------------------------------------------------------------------------

【报错5】

读取embed.txt数据错误

【解决5】

embed.tofile("embed.txt")#保存为二进制文件,且不能保存当前数据的行列信息

用np.fromfile读取数据需要手动指定dtype,如果指定的格式与保存时的不一致,则读出来的就是错误的数据。

print(embed.shape)

print(embed.dtype)

embed=np.fromfile('embed.txt',dtype=np.float32) #读取数据

注意到读出来的数据是一维数组,需要利用np.reshape方法重新指定维数


三、部分代码解读

from __future__ import absolute_import #使用py3的绝对引入
from __future__ import division #使用py3的精确除法
from __future__ import print_function #使用py3的print()功能函数

__future__模块,把下一个新版本的特性导入到当前版本,从而能够在当前旧版本中测试一些新版本的特性。


四、实验结果








猜你喜欢

转载自blog.csdn.net/weixin_38493025/article/details/80576266
net