使用caffe的python layer自定义数据增强层DataAugmentationLayer

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/zhongqianli/article/details/85601588

项目地址:https://github.com/zhongqianli/caffe_python_layer
caffe自定义网络层的一种方式是使用python layer,这种方式需要使用pycaffe运行,命令行的方式运行会报错。

编写DataAugmentationLayer

这个类的基类是caffe.Layer,需要编写setup,reshape,forward,backward四个方法,每个方法都有top和bottom参数,可以通过top[0].data和bottom[0].data获取一个4维的数据,分别是batch_size、通道数、高、宽。

import caffe
import json
import cv2
import numpy as np
import random

# 4 pixel pad, random crop
# img: 64x3x32x32
def zeropadding_and_crop(data):
    # # cifar10
    # # padding_img = np.pad(img, ((4, 4), (4, 4), (4, 4)), "constant", padder=0)
    padding_img = np.zeros((np.shape(data)[0], 3, 40, 40), dtype=np.uint8)
    padding_img[..., 4:36, 4:36] = data[...]
    # #
    # cv2.imshow("pad", data[0][0])
    row_rand_num = random.randrange(9)
    col_rand_num = random.randrange(9)
    croped_img = padding_img[..., row_rand_num : row_rand_num + 32, col_rand_num : col_rand_num + 32]

    return croped_img

class DataAugmentationLayer(caffe.Layer):
    def setup(self, bottom, top):
        pass
    def reshape(self, bottom, top):
        top[0].reshape(*bottom[0].data.shape)
        pass

    def forward(self, bottom, top):
        top[0].data[...] = zeropadding_and_crop(bottom[0].data)
        pass

    def backward(self, top, propagate_down, bottom):
        pass

使用自定义的网络层DataAugmentationLayer

pycaffe最好使用net.xxx的方式创建网络,因为第二种方式会自动命名,可能会出现一下意想不到的问题。

# 第一种方式,推荐使用
net.data = L.Python(net.data_temp,
                            python_param=dict(module="custom_data_augmentation",
                                              layer="DataAugmentationLayer"),
                            include=dict(phase=caffe_pb2.Phase.Value('TRAIN')))

# 第二种方式,不推荐这种方式
data = L.Python(data_temp,
                            python_param=dict(module="custom_data_augmentation",
                                              layer="DataAugmentationLayer"),
                            include=dict(phase=caffe_pb2.Phase.Value('TRAIN')))

猜你喜欢

转载自blog.csdn.net/zhongqianli/article/details/85601588