版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
以前用过,又忘记了,今天再备忘一下。以数据层为例,说明如何添加自己的python层。
1:trainval.protxt
name: "mytest"
layer {
name: "data"
type: "Python"
top: "data"
top: "label"
python_param {
module:'my_image_data'
layer: 'MyImageDataLayer'
param_str: "{'feat_stride': 16, 'ratios': [0.4, 0.667, 1, 1.5, 2.5], 'scales': [2, 3, 6, 9, 16]}"
}
}
注意点:
1:type: 指定为Python
2:module: 指定为自己python文件名【我的是my_image_data.py,所以这里指定为my_image_data】
3:layer: 指定为自己写的层的类名【具体的可以参看my_image_data.py文件】
4:param_str:用于指定参数或者配置
2:solver.prototxt
net: "trainval.prototxt"
test_iter: 100
test_interval: 1000
base_lr: 0.000038
momentum: 0.9
weight_decay: 0.004
lr_policy: "step"
gamma: 0.1
stepsize: 30000
display: 500
max_iter: 200000
snapshot: 1000
snapshot_prefix: "model/mytest"
solver_mode: GPU
3:my_image_data.py
#!/usr/bin/env python
"""
"""
import os
import cv2
import glob
import time
import yaml
import numpy as np
import sys
sys.path.insert(0,'/home_enet1/Caffe-Frame/video-caffe/python')
import caffe
class MyImageDataLayer(caffe.Layer):
def setup(self,bottom,top):
layer_params = yaml.load(self.param_str)
self.scales = layer_params.get('scales', (8, 16, 32))
self.ratios = layer_params.get('ratios', ((0.5, 1, 2)))
self.feat_stride = layer_params['feat_stride']
top[0].reshape(1, 1, 100, 100)
# bbox_targets
# top[1].reshape(1, A * 4, height, width)
top[1].reshape(1, 1, 100, 100)
def reshape(self,bottom,top):
pass
def forward(self,bottom,top):
print "!!!!!!!!!!!!!!!!!!!!!!"
print self.scales, self.ratios,self.feat_stride
print "!!!!!!!!!!!!!!!!!!!!!!"
def backward(self,top,propagate_dowm,bottom):
pass
最重要的两点:
1:一定要正确指定自己的caffe路径【否则会提示No moudle named ***】
2:py文件和protoxt文件放在同一个目录【具体位置无所谓,在同一个目录就行】
import sys
sys.path.insert(0,'/home_enet1/Caffe-Frame/video-caffe/python')
import caffe
4:单GPU调用
/home_enet1/Caffe-Frame/video-caffe/build/tools/caffe train --solver=solver.prototxt --gpu 0 2>&1 | tee mytest.LOG
注意点:
这个脚本只能单GPU训练
5:多GPU调用
如果想多GPU训练,就用caffe里面自带的python/train.py文件进行调用。
python train.py --solver=solver.prototxt --snapshot= --gpu 0 1 2>&1 | tee mytest.LOG
watch -n 1 nvidia-smi
观察显存和GPU调用