一、写在前面
参考:https://www.cnblogs.com/bentuwuying/p/8186364.html
简介:全连接孪生网络(siamese network)是一种相似性度量方法,适用于类别数目多但是每类的样本数少的分类问题。
代码:https://github.com/Shicoder/DeepLearning_Demo/tree/master/siamese_tf_mnist
二、测试代码
按照github代码运行python run.py后报错:
【报错1】
AttributeError: module ‘tensorflow‘ has no attribute ‘sub‘
AttributeError: module ‘tensorflow‘ has no attribute ‘mul‘
【解决1】
将代码inference.py中的tf.sub替换为tf.subtract,tf.mul替换为tf.multiply
----------------------------------------------------------------------------------------------------------------
【报错2】
NotFoundError (see above for traceback): Failed to create a directory: ; No such file or directory
[[Node: save/SaveV2 = SaveV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/SaveV2/tensor_names, save/SaveV2/shape_and_slices, siamese/fc1W, siamese/fc1b, siamese/fc2W, siamese/fc2b, siamese/fc3W, siamese/fc3b)]]
【解决2】
将代码run.py中的saver.save(sess, 'model.ckpt')以及saver.restore(sess, 'model.ckpt')改为saver.save(sess, './model.ckpt')以及saver.restore(sess, './model.ckpt')
----------------------------------------------------------------------------------------------------------------
【报错3】
name 'raw_input' is not defined
【解决3】
将代码中的raw_input改为input即可。
----------------------------------------------------------------------------------------------------------------
【报错4】
读取model.ckpt失败
【解决4】
将代码中的os.path.isfile(model_ckpt)改为判断当前路径下是否存在checkpoint这个文件os.path.exists('E:\pyProject\DeepLearning_Demo-master\siamese_tf_mnist\checkpoint'):
----------------------------------------------------------------------------------------------------------------
【报错5】
读取embed.txt数据错误
【解决5】
embed.tofile("embed.txt")#保存为二进制文件,且不能保存当前数据的行列信息
用np.fromfile读取数据需要手动指定dtype,如果指定的格式与保存时的不一致,则读出来的就是错误的数据。
print(embed.shape)
print(embed.dtype)
embed=np.fromfile('embed.txt',dtype=np.float32) #读取数据
注意到读出来的数据是一维数组,需要利用np.reshape方法重新指定维数
三、部分代码解读
from __future__ import absolute_import #使用py3的绝对引入 from __future__ import division #使用py3的精确除法 from __future__ import print_function #使用py3的print()功能函数
__future__
模块,把下一个新版本的特性导入到当前版本,从而能够在当前旧版本中测试一些新版本的特性。