Caffe官方教程翻译(3):Siamese Network Training with Caffe

前言

最近打算重新跟着官方教程学习一下caffe,顺便也自己翻译了一下官方的文档。自己也做了一些标注,都用斜体标记出来了。中间可能额外还加了自己遇到的问题或是运行结果之类的。欢迎交流指正,拒绝喷子!
官方教程的原文链接:http://caffe.berkeleyvision.org/gathered/examples/siamese.html

Siamese Network Training with Caffe

这个示例将会展示给你如何使用权重分享和对比损失函数,来学习在Caffe中使用一个siamese网络。

我们默认你已经成功编译了caffe的源码。如果没有,请查Installation page。这个例子是在MNIST tutorial的基础之上做的,所以在阅读这篇教程之前最好先看下那篇教程。

我们指定所有的路径并假设所有的命令都是在caffe的根目录下的。

准备数据集

首先你需要从MNIST网站下载并转换数据集格式。为了做到这个,运行下面的命令:

./data/mnist/get_mnist.sh
./examples/siamese/create_mnist_siamese.sh

在运行这些脚本后你会看到多出来了两个数据集:./examples/siamese/mnist_siamese_train_leveldb./examples/siamese/mnist_siamese_test_leveldb

模型

首先,我们要定义后面在训练想要使用的siamese网络的模型。我们会使用定义在./examples/siamese/mnist_siamese.prototxt中的卷积神经网络。这个模型基本上与LeNet model中的模型是一样的,唯一的区别就是我们将顶层对应10个手写数字类别概率的输出更换成了线性“特征”层,只有2个输出了(补充:siamese与LeNet的不同在于,输入变成了一对图片,不是预测单个样本对应的标签,而是判断这一对样本是否是来自同一个类,是则结果为0,不是则结果为1)。

layer {
  name: "feat"
  type: "InnerProduct"
  bottom: "ip2"
  top: "feat"
  param {
    name: "feat_w"
    lr_mult: 1
  }
  param {
    name: "feat_b"
    lr_mult: 2
  }
  inner_product_param {
    num_output: 2
  }
}

定义Siamese网络

在这个部分我们将要定义siamese网络,并用于训练。网络定义在./examples/siamese/mnist_siamese_train_test.prototxt

读入一对数据

我们最开始需要定义一个data层,而data层会从我们之前创建的LevelDB数据库读取数据。这个数据库的每个条目都包含了一对图像(pair_data)和一个二进制标签,表示它们是否是来自同一个类或是不同的类(sim)。

layer {
  name: "pair_data"
  type: "Data"
  top: "pair_data"
  top: "sim"
  include { phase: TRAIN }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/siamese/mnist_siamese_train_leveldb"
    batch_size: 64
  }
}

为了从数据库中取出一对图片打包送到同一个blob中,我们对每一个通道都打包一个图像。我们想要分别处理这两个图像,所以在data层之后添加了一个slice层。它会接受pair_data的数据,并将其根据通道维度切分开来,然后我们会在data和它成对的图像data_p上得到一个单一的图像。

layer {
  name: "slice_pair"
  type: "Slice"
  bottom: "pair_data"
  top: "data"
  top: "data_p"
  slice_param {
    slice_dim: 1
    slice_point: 1
  }
}

建立Siamese网络的第一边

现在我们可以指定siamese网络的第一边了。这一边的网络处理的是data(补充:输入是data层),生成的是feat(输出是feat层)。刚开始,我们在./examples/siamese/mnist_siamese.prototxt中加入了默认的权重初始值。然后,我们给卷积层和内积层(全连接层)进行命名。对参数进行命名,就相当于告诉Caffe这些层上的参数在两边的siamese网络上共享。就像这样定义:

...
param { name: "conv1_w" ...  }
param { name: "conv1_b" ...  }
...
param { name: "conv2_w" ...  }
param { name: "conv2_b" ...  }
...
param { name: "ip1_w" ...  }
param { name: "ip1_b" ...  }
...
param { name: "ip2_w" ...  }
param { name: "ip2_b" ...  }
...

建立Siamese网络的第二边

现在我们需要创建第二边网络,而这一边的网络处理的是data_p,生成的是feat_p。这一边跟第一边基本上是一模一样的。因此,直接复制粘贴就行。然后我们要改一下每一层的名字、输入、输出,在名字后面加上“_p”来跟原始的区分一下。

添加对比损失函数

为了训练网络,我们要对一个对比损失函数进行优化,由Raia Hadsell, Sumit ChopraYann LeCun在“Dimensionality Reduction by Learning an Invariant Mapping”中提出。这个损失函数会使得在特征空间中相互匹配的样本更加接近,同时也会使得不匹配的样本更远。这个损耗函数在CONTRASTIVE_LOSS层声明了:

layer {
  name: "loss"
  type: "ContrastiveLoss"
  contrastive_loss_param {
    margin: 1.0
  }
  bottom: "feat"
  bottom: "feat_p"
  bottom: "sim"
  top: "loss"
}

定义解决方案

除了给解决方案指定正确的模型文件之外,没有什么特别的需要做的了。解决方案定义在:./examples/siamese/mnist_siamese_solver.prototxt

训练和测试模型

在你已经写好了网络定义的protobuf和解决方案的protobuf之后,训练模型变得很简单了。运行:./examples/siamese/train_mnist_siamese.sh

./examples/siamese/train_mnist_siamese.sh

绘制结果

首先,通过运行下面的命令,我们可以画出模型和siamese网络,画出来.prototxt文件中定义的DAGs

./python/draw_net.py \
    ./examples/siamese/mnist_siamese.prototxt \
    ./examples/siamese/mnist_siamese.png

./python/draw_net.py \
    ./examples/siamese/mnist_siamese_train_test.prototxt \
    ./examples/siamese/mnist_siamese_train_test.png

接着,我们可以在iPython notebook中导入训练好的模型并画出特征:

ipython notebook ./examples/siamese/mnist_siamese.ipynb

补充:ipython notebook太老了,可以直接换成jupyter notebook。指令如下:

jupyter notebook ./examples/siamese/mnist_siamese.ipynb

运行结果截图

这些是我在自己笔记本上运行的结果,仅供参考。
运行代码:

./python/draw_net.py \
    ./examples/siamese/mnist_siamese.prototxt \
    ./examples/siamese/mnist_siamese.png

生成的网络结构的图片:
这里写图片描述

运行代码:

./python/draw_net.py \
    ./examples/siamese/mnist_siamese_train_test.prototxt \
    ./examples/siamese/mnist_siamese_train_test.png

生成的网络结构的图片:
这里写图片描述

运行代码:

jupyter notebook ./examples/siamese/mnist_siamese.ipynb

直接看图就知道了,每种颜色分别对应一个手写数字,正好10种吧。
这里写图片描述

猜你喜欢

转载自blog.csdn.net/hongbin_xu/article/details/79363485