Adding your own Python layer to Caffe

I am a beginner myself; writing these posts helps reinforce my own understanding. If anything here is wrong or off the mark, please point it out.
Environment setup
Please look up a tutorial for installing Caffe and building its Python interface. To use your own Python layers, Makefile.config needs one extra change before building pycaffe:

# Uncomment to support layers written in Python (will link against Python libs)
WITH_PYTHON_LAYER := 1  # remove the leading # to uncomment this line

After the change, rebuild with: make pycaffe -j4
Once the build succeeds, start Python from the command line and type import caffe. If you see the following, both caffe and its Python layer support were built successfully.

zhangguoxiong@zhangguoxiong-Inspiron-5448:~$ python
Python 2.7.13 |Anaconda 4.4.0 (64-bit)| (default, Dec 20 2016, 23:09:15) 
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)] on linux2
Type "help", "copyright", "credits" or "license" for more information.
Anaconda is brought to you by Continuum Analytics.
Please check out: http://continuum.io/thanks and https://anaconda.org
>>> import caffe
>>>

Checking/adding the path
Run vim ~/.bashrc on the command line to check the system path for caffe's Python bindings; mine is as follows:
export PYTHONPATH=~/caffe/python
The files defining our own Python layers must live on this Python path.
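Caffe imports the layer module with a plain import statement, so the directory holding the layer file must be importable. A quick sanity-check sketch (the path below matches the PYTHONPATH export above; adding it to sys.path manually is only needed in a session that did not source ~/.bashrc):

```python
import os
import sys

# Directory that must contain Mylayer.py, matching the PYTHONPATH export above.
layer_dir = os.path.expanduser("~/caffe/python")

# If the shell did not export PYTHONPATH, the path can be added at runtime:
if layer_dir not in sys.path:
    sys.path.insert(0, layer_dir)
```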
Defining your own Python layer
In the Python file you define a class, and that class must implement the following methods:

setup( ): checks the input parameters for problems and performs one-time initialization.
reshape( ): also initialization; sets the sizes of the top blobs.
forward( ): the forward pass.
backward( ): the backward pass.

Here we implement a simple identity layer as a test; the code is as follows.

Mylayer.py

import caffe

class MyLayer(caffe.Layer):
    def setup(self, bottom, top):
        pass

    def reshape(self, bottom, top):
        # Give the top blob the same shape as the bottom blob.
        top[0].reshape(*bottom[0].data.shape)

    def forward(self, bottom, top):
        # Identity: copy the input straight through.
        top[0].data[...] = bottom[0].data[:]

    def backward(self, top, propagate_down, bottom):
        # Pass the gradient back unchanged where requested.
        for i in range(len(propagate_down)):
            if not propagate_down[i]:
                continue
            bottom[i].diff[...] = top[i].diff[:]
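To see what the identity layer does, here is a numpy-only sketch of the same reshape/forward/backward steps. FakeBlob is a hypothetical stand-in for caffe's blob objects (holding a .data array for activations and a .diff array for gradients), not part of caffe itself:

```python
import numpy as np

# Hypothetical stand-in for a caffe blob: .data holds activations,
# .diff holds gradients.
class FakeBlob:
    def __init__(self, shape):
        self.data = np.zeros(shape)
        self.diff = np.zeros(shape)

    def reshape(self, *shape):
        self.data = np.zeros(shape)
        self.diff = np.zeros(shape)

bottom = [FakeBlob((2, 3))]
top = [FakeBlob((1,))]
bottom[0].data[...] = np.arange(6).reshape(2, 3)

# reshape(): give top the same shape as bottom
top[0].reshape(*bottom[0].data.shape)
# forward(): identity copy of the data
top[0].data[...] = bottom[0].data[:]
# backward(): the gradient flows through unchanged
top[0].diff[...] = 1.0
bottom[0].diff[...] = top[0].diff[:]
```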

Once the Python layer is defined, we need to use it in a network. I created a new network under examples/mnist for handwritten digit recognition, and split the SoftmaxWithLoss layer into a Softmax layer plus a MultinomialLogisticLoss layer: according to Caffe's official documentation, SoftmaxWithLoss is the combination of Softmax and MultinomialLogisticLoss.
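That equivalence is easy to check numerically. A numpy-only sketch (no caffe involved): Softmax followed by MultinomialLogisticLoss, i.e. minus the log of the true-class probability, matches a fused SoftmaxWithLoss computed straight from the raw scores:

```python
import numpy as np

scores = np.array([2.0, 1.0, 0.1])
label = 0

# Softmax layer (max subtracted for numerical stability)
e = np.exp(scores - scores.max())
probs = e / e.sum()

# MultinomialLogisticLoss layer: -log of the true-class probability
loss_split = -np.log(probs[label])

# Fused SoftmaxWithLoss, computed directly from the scores
loss_fused = np.log(e.sum()) - (scores[label] - scores.max())
```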
The network file is as follows:

lenet_lr.prototxt

name: "LeNet"
layer {
  name: "mnist"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/mnist/mnist_train_lmdb"
    batch_size: 64
    backend: LMDB
  }
}
layer {
  name: "mnist"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TEST
  }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/mnist/mnist_test_lmdb"
    batch_size: 100
    backend: LMDB
  }
}
layer {
  name: "ip"
  type: "InnerProduct"
  bottom: "data"
  top: "ip"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "accuracy"
  type: "Accuracy"
  bottom: "ip"
  bottom: "label"
  top: "accuracy"
  include {
    phase: TEST
  }
}
layer {
  name: "output_1"
  type: "Softmax"
  bottom: "ip"
  top: "output_1"
}
layer {
  type: "Python"
  name: "output"
  bottom: "output_1"
  top: "output"
  python_param {
    module: "Mylayer"
    layer: "MyLayer"
    param_str: '{"x1": 1, "x2": 2 }'
  }
}
layer {
  name: "loss" 
  type: "MultinomialLogisticLoss" 
  bottom: "output"  
  bottom: "label"  
  top: "loss"
}
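The python_param block above hands param_str to the layer; MyLayer ignores it, but inside setup() it arrives as the plain string self.param_str and could be decoded with the json module. A sketch (parse_layer_params is a hypothetical helper, not a caffe API):

```python
import json

# Hypothetical helper: caffe exposes param_str to the layer as a plain
# string in self.param_str; here it is parsed as JSON.
def parse_layer_params(param_str):
    return json.loads(param_str) if param_str else {}

params = parse_layer_params('{"x1": 1, "x2": 2 }')
x1, x2 = params["x1"], params["x2"]
```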

The solver file is as follows:

lenet_lr_solver.prototxt

# The train/test net protocol buffer definition
net: "examples/mnist/lenet_lr.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_mode: CPU
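With lr_policy "inv", Caffe decays the learning rate as base_lr * (1 + gamma * iter)^(-power). A quick sketch using the solver values above shows how gently it decays over the 10000 iterations:

```python
# "inv" learning-rate policy with the solver's values:
#   lr(iter) = base_lr * (1 + gamma * iter) ** (-power)
base_lr, gamma, power = 0.01, 0.0001, 0.75

def inv_lr(it):
    return base_lr * (1.0 + gamma * it) ** (-power)

lr_start = inv_lr(0)      # 0.01
lr_end = inv_lr(10000)    # roughly 0.0059 at max_iter
```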

Since this network is small, it also trains quickly on the CPU, so CPU mode is used here.
Next, run the following from the Caffe root directory:

Inspiron-5448:~/caffe$ ./build/tools/caffe train --solver=examples/mnist/lenet_lr_solver.prototxt

Training finishes in ten-odd seconds, and the accuracy is quite good.
For obtaining the MNIST handwritten-digit data, see Day 6 of 《21天实战caffe》: running the handwritten digit recognition example.
References
https://www.jianshu.com/p/e05d1b210fcb?utm_campaign=hugo&utm_medium=reader_share&utm_content=note&utm_source=weibo

Reposted from blog.csdn.net/qq_41648043/article/details/82556017