Designing Siamese Multi-Branch Networks with MatConvNet

In an earlier post we used vl_nndist as the feature-distance function of a multi-branch network to fuse the local outputs of the individual branches; see https://blog.csdn.net/shenziheng1/article/details/81263547. As many papers note, instead of an explicit distance metric we can also design the fusion with fully connected layers, and the key step there is concatenating the outputs of the branch networks into a single tensor. MatConvNet already provides this functionality: the dagnn.Concat layer and the vl_nnconcat function.

1. vl_nnconcat

function y = vl_nnconcat(inputs, dim, dzdy, varargin)
%VL_NNCONCAT CNN layer that concatenates multiple inputs.
%  Y = VL_NNCONCAT(INPUTS, DIM) concatenates the cell array of arrays
%  INPUTS along dimension DIM.
%
%  DZDINPUTS = VL_NNCONCAT(INPUTS, DIM, DZDY) computes the derivatives
%  of the block projected onto DZDY. DZDINPUTS has one element for
%  each element of INPUTS, each of which is an array that has the same
%  dimensions of the corresponding array in INPUTS.

opts.inputSizes = [] ;
opts = vl_argparse(opts, varargin, 'nonrecursive') ;

if nargin < 2, dim = 3; end;
if nargin < 3, dzdy = []; end;

if isempty(dzdy)
  y = cat(dim, inputs{:});
else
  if isempty(opts.inputSizes)
    opts.inputSizes = cellfun(@(inp) [size(inp,1),size(inp,2),size(inp,3),size(inp,4)], inputs, 'UniformOutput', false) ;
  end
  start = 1 ;
  y = cell(1, numel(opts.inputSizes)) ;
  s.type = '()' ;
  s.subs = {':', ':', ':', ':'} ;
  for i = 1:numel(opts.inputSizes)
    stop = start + opts.inputSizes{i}(dim) ;
    s.subs{dim} = start:stop-1 ;
    y{i} = subsref(dzdy,s) ;
    start = stop ;
  end
end
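
A minimal usage sketch (the shapes below are arbitrary examples): the forward call concatenates the inputs, and the backward call slices the upstream derivative back into one piece per input.

x1 = randn(8, 8, 64, 4, 'single');   % H x W x C1 x N
x2 = randn(8, 8, 128, 4, 'single');  % H x W x C2 x N
y  = vl_nnconcat({x1, x2}, 3);       % size(y) is [8 8 192 4]

dzdy = randn(size(y), 'single');
dzdx = vl_nnconcat({x1, x2}, 3, dzdy);  % dzdx{1} matches size(x1), dzdx{2} matches size(x2)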

2. dagnn.Concat

classdef Concat < dagnn.ElementWise
  properties
    dim = 3  % concatenate along the third (channel) dimension by default; set 'dim' at construction to change this
  end

  properties (Transient)
    inputSizes = {}
  end

  methods
    function outputs = forward(obj, inputs, params)
      outputs{1} = vl_nnconcat(inputs, obj.dim) ;
      obj.inputSizes = cellfun(@size, inputs, 'UniformOutput', false) ;
    end

    function [derInputs, derParams] = backward(obj, inputs, params, derOutputs)
      derInputs = vl_nnconcat(inputs, obj.dim, derOutputs{1}, 'inputSizes', obj.inputSizes) ;
      derParams = {} ;
    end

    function reset(obj)
      obj.inputSizes = {} ;
    end

    function outputSizes = getOutputSizes(obj, inputSizes)
      sz = inputSizes{1} ;
      for k = 2:numel(inputSizes)
        sz(obj.dim) = sz(obj.dim) + inputSizes{k}(obj.dim) ;
      end
      outputSizes{1} = sz ;
    end

    function rfs = getReceptiveFields(obj)
      numInputs = numel(obj.net.layers(obj.layerIndex).inputs) ;
      if obj.dim == 3 || obj.dim == 4
        rfs = getReceptiveFields@dagnn.ElementWise(obj) ;
        rfs = repmat(rfs, numInputs, 1) ;
      else
        for i = 1:numInputs
          rfs(i,1).size = [NaN NaN] ;
          rfs(i,1).stride = [NaN NaN] ;
          rfs(i,1).offset = [NaN NaN] ;
        end
      end
    end

    function load(obj, varargin)
      s = dagnn.Layer.argsToStruct(varargin{:}) ;
      % backward file compatibility
      if isfield(s, 'numInputs'), s = rmfield(s, 'numInputs') ; end
      load@dagnn.Layer(obj, s) ;
    end

    function obj = Concat(varargin)
      obj.load(varargin{:}) ;
    end
  end
end
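
To connect this back to the Siamese setting from the introduction, here is a minimal two-branch fusion sketch. The layer and variable names are illustrative, and the 512 input channels assume each branch ends in a 256-channel feature map (the branch layers themselves are elided):

net = dagnn.DagNN();
% each branch is assumed to produce a feature map 'feat1' / 'feat2'
net.addLayer('concatAB', dagnn.Concat('dim', 3), ...
             {'feat1', 'feat2'}, {'featAB'}, {});
% a 1x1 convolution over the concatenated map plays the role of the
% fully connected fusion layer, here producing two matching scores
fuse = dagnn.Conv('size', [1,1,512,2], 'pad', 0, 'stride', 1, 'hasBias', true);
net.addLayer('fuse', fuse, {'featAB'}, {'score'}, {'fuse_f','fuse_b'});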

An application example: a U-Net-style network in which dagnn.Concat implements the skip connections.

function net = initializeUnet()

net = dagnn.DagNN();
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                  STAGE I
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ----------------------------------------------
% Stage 1: 1st conv block : conv-batchnorm-relu
% ----------------------------------------------
conv1 = dagnn.Conv('size',[3,3,1,64], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv1', conv1, {'FBP'},{'conv_x1'},{'conv_f1','conv_b1'});
net.addLayer('bn1', dagnn.BatchNorm('numChannels', 64), {'conv_x1'}, {'bn_x1'}, {'bn1f', 'bn1b', 'bn1m'}); % note: the batch-norm channel count must match the conv output
relu1 = dagnn.ReLU();
net.addLayer('relu1', relu1, {'bn_x1'}, {'relu_x1'}, {});
% ----------------------------------------------
% Stage 1: 2nd conv block : conv-batchnorm-relu
% ----------------------------------------------
conv2 = dagnn.Conv('size',[3,3,64,64], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv2', conv2, {'relu_x1'},{'conv_x2'},{'conv_f2','conv_b2'});
net.addLayer('bn2', dagnn.BatchNorm('numChannels', 64), {'conv_x2'}, {'bn_x2'}, {'bn2f', 'bn2b', 'bn2m'});
relu2 = dagnn.ReLU();
net.addLayer('relu2', relu2, {'bn_x2'}, {'relu_x2'}, {});
% ----------------------------------------------
% Stage 1: pooling
% ----------------------------------------------
pool1 = dagnn.Pooling('method', 'max', 'poolSize', [2 2], 'stride', 2,'pad', 0);
net.addLayer('pool1', pool1, {'relu_x2'}, {'pool_x1'}, {});

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                  STAGE II
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ----------------------------------------------
% Stage 2: 1st conv block : conv-batchnorm-relu
% ----------------------------------------------
conv3 = dagnn.Conv('size',[3,3,64,128], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv3', conv3, {'pool_x1'},{'conv_x3'},{'conv_f3','conv_b3'});
net.addLayer('bn3', dagnn.BatchNorm('numChannels', 128), {'conv_x3'}, {'bn_x3'}, {'bn3f', 'bn3b', 'bn3m'});
relu3 = dagnn.ReLU();
net.addLayer('relu3', relu3, {'bn_x3'}, {'relu_x3'}, {});
% ----------------------------------------------
% Stage 2: 2nd conv block : conv-batchnorm-relu
% ----------------------------------------------
conv4 = dagnn.Conv('size',[3,3,128,128], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv4', conv4, {'relu_x3'},{'conv_x4'},{'conv_f4','conv_b4'});
net.addLayer('bn4', dagnn.BatchNorm('numChannels', 128), {'conv_x4'}, {'bn_x4'}, {'bn4f', 'bn4b', 'bn4m'});
relu4 = dagnn.ReLU();
net.addLayer('relu4', relu4, {'bn_x4'}, {'relu_x4'}, {});
% ----------------------------------------------
% Stage 2: pooling
% ----------------------------------------------
pool2 = dagnn.Pooling('method', 'max', 'poolSize', [2 2], 'stride', 2);
net.addLayer('pool2', pool2, {'relu_x4'}, {'pool_x2'}, {});


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                  STAGE III
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ----------------------------------------------
% Stage 3: 1st conv block : conv-batchnorm-relu
% ----------------------------------------------
conv5 = dagnn.Conv('size',[3,3,128,256], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv5', conv5, {'pool_x2'},{'conv_x5'},{'conv_f5','conv_b5'});
net.addLayer('bn5', dagnn.BatchNorm('numChannels', 256), {'conv_x5'}, {'bn_x5'}, {'bn5f', 'bn5b', 'bn5m'});
relu5 = dagnn.ReLU();
net.addLayer('relu5', relu5, {'bn_x5'}, {'relu_x5'}, {});
% ----------------------------------------------
% Stage 3: 2nd conv block : conv-batchnorm-relu
% ----------------------------------------------
conv6 = dagnn.Conv('size',[3,3,256,256], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv6', conv6, {'relu_x5'},{'conv_x6'},{'conv_f6','conv_b6'});
net.addLayer('bn6', dagnn.BatchNorm('numChannels', 256), {'conv_x6'}, {'bn_x6'}, {'bn6f', 'bn6b', 'bn6m'});
relu6 = dagnn.ReLU();
net.addLayer('relu6', relu6, {'bn_x6'}, {'relu_x6'}, {});
% ----------------------------------------------
% Stage 3: pooling 
% ----------------------------------------------
pool3 = dagnn.Pooling('method', 'max', 'poolSize', [2 2], 'stride', 2);
net.addLayer('pool3', pool3, {'relu_x6'}, {'pool_x3'}, {});


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                  STAGE IV
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ----------------------------------------------
% Stage 4: 1st conv block : conv-batchnorm-relu
% ----------------------------------------------
conv7 = dagnn.Conv('size',[3,3,256,512], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv7', conv7, {'pool_x3'},{'conv_x7'},{'conv_f7','conv_b7'});
net.addLayer('bn7', dagnn.BatchNorm('numChannels', 512), {'conv_x7'}, {'bn_x7'}, {'bn7f', 'bn7b', 'bn7m'});
relu7 = dagnn.ReLU();
net.addLayer('relu7', relu7, {'bn_x7'}, {'relu_x7'}, {});
% ----------------------------------------------
% Stage 4: 2nd conv block : conv-batchnorm-relu
% ----------------------------------------------
conv8 = dagnn.Conv('size',[3,3,512,512], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv8', conv8, {'relu_x7'},{'conv_x8'},{'conv_f8','conv_b8'});
net.addLayer('bn8', dagnn.BatchNorm('numChannels', 512), {'conv_x8'}, {'bn_x8'}, {'bn8f', 'bn8b', 'bn8m'});
relu8 = dagnn.ReLU();
net.addLayer('relu8', relu8, {'bn_x8'}, {'relu_x8'}, {});
% ----------------------------------------------
% Stage 4: pooling
% ----------------------------------------------
pool4 = dagnn.Pooling('method', 'max', 'poolSize', [2 2], 'stride', 2);
net.addLayer('pool4', pool4, {'relu_x8'}, {'pool_x4'}, {});

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                  STAGE V
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ----------------------------------------------
% Stage 5: 1st conv block : conv-batchnorm-relu
% ----------------------------------------------
conv9 = dagnn.Conv('size',[3,3,512,1024], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv9', conv9, {'pool_x4'},{'conv_x9'},{'conv_f9','conv_b9'});
net.addLayer('bn9', dagnn.BatchNorm('numChannels', 1024), {'conv_x9'}, {'bn_x9'}, {'bn9f', 'bn9b', 'bn9m'});
relu9 = dagnn.ReLU();
net.addLayer('relu9', relu9, {'bn_x9'}, {'relu_x9'}, {});
% ----------------------------------------------
% Stage 5: 2nd conv block : conv-batchnorm-relu
% ----------------------------------------------
conv10 = dagnn.Conv('size',[3,3,1024,512], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv10', conv10, {'relu_x9'},{'conv_x10'},{'conv_f10','conv_b10'});
net.addLayer('bn10', dagnn.BatchNorm('numChannels', 512), {'conv_x10'}, {'bn_x10'}, {'bn10f', 'bn10b', 'bn10m'});
relu10 = dagnn.ReLU();
net.addLayer('relu10', relu10, {'bn_x10'}, {'relu_x10'}, {});
% ----------------------------------------------
% Stage 5: unpooling : note! this is the upsampling (transposed-convolution) layer
% ----------------------------------------------
Upsample1=dagnn.ConvTranspose('size',[3,3,512,512],'hasBias',false,'upsample',[2,2],'crop',[0,1,0,1]);
net.addLayer('unpool1', Upsample1,{'relu_x10'},{'unpool_x1'},{'f1'});

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                              UPCONV STAGE IV : up-convolution
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ----------------------------------------------
% Stage 4: concat block : concatenate the skip and upsampled inputs
% ----------------------------------------------
concat1 = dagnn.Concat('dim', 3); % along the channel (depth) dimension
net.addLayer('concat1', concat1, {'relu_x8', 'unpool_x1'}, {'concat_x1'}, {});
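% relu_x8 carries 512 channels and unpool_x1 carries 512 channels, so
% concat_x1 has 512 + 512 = 1024 channels; this is why conv11 below is
% declared with 1024 input channels.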
% ----------------------------------------------
% Stage 4: 1st conv block : conv-batchnorm-relu
% ----------------------------------------------
conv11 = dagnn.Conv('size',[3,3,1024,512], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv11', conv11, {'concat_x1'}, {'conv_x11'}, {'conv_f11','conv_b11'});
net.addLayer('bn11', dagnn.BatchNorm('numChannels', 512), {'conv_x11'}, {'bn_x11'}, {'bn11f', 'bn11b', 'bn11m'});
relu11 = dagnn.ReLU();
net.addLayer('relu11', relu11, {'bn_x11'}, {'relu_x11'}, {});
% ----------------------------------------------
% Stage 4: 2nd conv block : conv-batchnorm-relu
% ----------------------------------------------
conv12 = dagnn.Conv('size',[3,3,512,256], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv12', conv12, {'relu_x11'}, {'conv_x12'},{'conv_f12','conv_b12'});
net.addLayer('bn12', dagnn.BatchNorm('numChannels', 256), {'conv_x12'}, {'bn_x12'}, {'bn12f', 'bn12b', 'bn12m'});
relu12 = dagnn.ReLU();
net.addLayer('relu12', relu12, {'bn_x12'}, {'relu_x12'}, {});
% ----------------------------------------------
% Stage 4: unpooling : continue upsampling
% ----------------------------------------------
Upsample2=dagnn.ConvTranspose('size',[3,3,256,256],'hasBias',false,'upsample',[2,2],'crop',[1,0,1,0]);
net.addLayer('unpool2', Upsample2,{'relu_x12'},{'unpool_x2'},{'f2'});

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                              UPCONV STAGE III : up-convolution
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ----------------------------------------------
% Stage 3: concat block : concatenate the skip and upsampled inputs
% ----------------------------------------------
concat2 = dagnn.Concat('dim', 3);
net.addLayer('concat2', concat2, {'relu_x6', 'unpool_x2'}, {'concat_x2'}, {});
% ----------------------------------------------
% Stage 3: 1st conv block : conv-batchnorm-relu
% ----------------------------------------------
conv13 = dagnn.Conv('size',[3,3,512,256], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv13', conv13, {'concat_x2'}, {'conv_x13'}, {'conv_f13','conv_b13'});
net.addLayer('bn13', dagnn.BatchNorm('numChannels', 256), {'conv_x13'}, {'bn_x13'}, {'bn13f', 'bn13b', 'bn13m'});
relu13 = dagnn.ReLU();
net.addLayer('relu13', relu13, {'bn_x13'}, {'relu_x13'}, {});
% ----------------------------------------------
% Stage 3: 2nd conv block : conv-batchnorm-relu
% ----------------------------------------------
conv14 = dagnn.Conv('size',[3,3,256,128], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv14', conv14, {'relu_x13'}, {'conv_x14'},{'conv_f14','conv_b14'});
net.addLayer('bn14', dagnn.BatchNorm('numChannels', 128), {'conv_x14'}, {'bn_x14'}, {'bn14f', 'bn14b', 'bn14m'});
relu14 = dagnn.ReLU();
net.addLayer('relu14', relu14, {'bn_x14'}, {'relu_x14'}, {});
% ----------------------------------------------
% Stage 3: unpooling : continue upsampling
% ----------------------------------------------
Upsample3=dagnn.ConvTranspose('size',[3,3,128,128],'hasBias',false,'upsample',[2,2],'crop',[0,1,0,1]);
net.addLayer('unpool3', Upsample3,{'relu_x14'},{'unpool_x3'},{'f3'});

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                              UPCONV STAGE II
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ----------------------------------------------
% Stage 2: concat block
% ----------------------------------------------
concat3 = dagnn.Concat('dim', 3);
net.addLayer('concat3', concat3, {'relu_x4', 'unpool_x3'}, {'concat_x3'}, {});
% ----------------------------------------------
% Stage 2: 1st conv block
% ----------------------------------------------
conv15 = dagnn.Conv('size',[3,3,256,128], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv15', conv15, {'concat_x3'}, {'conv_x15'}, {'conv_f15','conv_b15'});
net.addLayer('bn15', dagnn.BatchNorm('numChannels', 128), {'conv_x15'}, {'bn_x15'}, {'bn15f', 'bn15b', 'bn15m'});
relu15 = dagnn.ReLU();
net.addLayer('relu15', relu15, {'bn_x15'}, {'relu_x15'}, {});
% ----------------------------------------------
% Stage 2: 2nd conv block
% ----------------------------------------------
conv16 = dagnn.Conv('size',[3,3,128,64], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv16', conv16, {'relu_x15'}, {'conv_x16'},{'conv_f16','conv_b16'});
net.addLayer('bn16', dagnn.BatchNorm('numChannels', 64), {'conv_x16'}, {'bn_x16'}, {'bn16f', 'bn16b', 'bn16m'});
relu16 = dagnn.ReLU();
net.addLayer('relu16', relu16, {'bn_x16'}, {'relu_x16'}, {});
% ----------------------------------------------
% Stage 2: unpooling
% ----------------------------------------------
Upsample4=dagnn.ConvTranspose('size',[3,3,64,64],'hasBias',false,'upsample',[2,2],'crop',[0,1,0,1]);
net.addLayer('unpool4', Upsample4,{'relu_x16'},{'unpool_x4'},{'f4'});

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                              UPCONV STAGE I
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ----------------------------------------------
% Stage 1: concat block
% ----------------------------------------------
concat4 = dagnn.Concat('dim', 3);
net.addLayer('concat4', concat4, {'relu_x2', 'unpool_x4'}, {'concat_x4'}, {});
% ----------------------------------------------
% Stage 1: 1st conv block
% ----------------------------------------------
conv17 = dagnn.Conv('size',[3,3,128,64], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv17', conv17, {'concat_x4'}, {'conv_x17'}, {'conv_f17','conv_b17'});
net.addLayer('bn17', dagnn.BatchNorm('numChannels', 64), {'conv_x17'}, {'bn_x17'}, {'bn17f', 'bn17b', 'bn17m'});
relu17 = dagnn.ReLU();
net.addLayer('relu17', relu17, {'bn_x17'}, {'relu_x17'}, {});
% ----------------------------------------------
% Stage 1: 2nd conv block
% ----------------------------------------------
conv18 = dagnn.Conv('size',[3,3,64,64], 'pad', 1, 'stride', 1, 'hasBias', true);
net.addLayer('conv18', conv18, {'relu_x17'}, {'conv_x18'},{'conv_f18','conv_b18'});
net.addLayer('bn18', dagnn.BatchNorm('numChannels', 64), {'conv_x18'}, {'bn_x18'}, {'bn18f', 'bn18b', 'bn18m'});
relu18 = dagnn.ReLU();
net.addLayer('relu18', relu18, {'bn_x18'}, {'relu_x18'}, {});
% ----------------------------------------------
% Stage 0: Prediction block
% ----------------------------------------------
pred = dagnn.Conv('size',[1,1,64,1], 'pad', 0, 'stride', 1, 'hasBias', true);
net.addLayer('pred', pred, {'relu_x18'},{'Image_Pre'},{'pred_f1','pred_b1'});
SumBlock=dagnn.Sum();
net.addLayer('sum',SumBlock,{'Image_Pre','FBP'},{'Image'});


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                                  LOSS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
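% note: dagnn.Smooth and dagnn.PrjCompare below are the author's custom
% layer classes; they are not part of stock MatConvNet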
SmoothBlock=dagnn.Smooth();
net.addLayer('Smooth', SmoothBlock, {'Image'}, {'loss'}) ;
PrjCompareBlock=dagnn.PrjCompare();
net.addLayer('PrjCompare', PrjCompareBlock, {'Image','data','Hsys','weights'}, {'loss2'}) ;
net.initParams() ;


end
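
Since Concat's getOutputSizes simply sums the sizes along the concatenated dimension, the skip-connection bookkeeping above can be checked in isolation (the 64x64 spatial size below is an arbitrary example):

c = dagnn.Concat('dim', 3);
sz = c.getOutputSizes({[64 64 512 1], [64 64 512 1]});
disp(sz{1})  % [64 64 1024 1], matching conv11's filter size [3,3,1024,512]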

3. Additional notes

What should you watch out for when using dagnn.BatchNorm in MatConvNet for batch normalization?

classdef BatchNorm < dagnn.ElementWise
  properties
    numChannels
    epsilon = 1e-5
    opts = {'NoCuDNN'} % ours seems slightly faster
  end

  properties (Transient)
    moments
  end

  methods
    function outputs = forward(obj, inputs, params)
      if strcmp(obj.net.mode, 'test')
        outputs{1} = vl_nnbnorm(inputs{1}, params{1}, params{2}, ...
                                'moments', params{3}, ...
                                'epsilon', obj.epsilon, ...
                                obj.opts{:}) ;
      else
        [outputs{1},obj.moments] = ...
            vl_nnbnorm(inputs{1}, params{1}, params{2}, ...
                       'epsilon', obj.epsilon, ...
                       obj.opts{:}) ;
      end
    end

    function [derInputs, derParams] = backward(obj, inputs, params, derOutputs)
      [derInputs{1}, derParams{1}, derParams{2}, derParams{3}] = ...
        vl_nnbnorm(inputs{1}, params{1}, params{2}, derOutputs{1}, ...
                   'epsilon', obj.epsilon, ...
                   'moments', obj.moments, ...
                   obj.opts{:}) ;
      obj.moments = [] ;
      % multiply the moments update by the number of images in the batch
      % this is required to make the update additive for subbatches
      % and will eventually be normalized away
      derParams{3} = derParams{3} * size(inputs{1},4) ;
    end

    % ---------------------------------------------------------------------
    function obj = BatchNorm(varargin)
      obj.load(varargin{:}) ;
    end

    function params = initParams(obj)
      params{1} = ones(obj.numChannels,1,'single') ;
      params{2} = zeros(obj.numChannels,1,'single') ;
      params{3} = zeros(obj.numChannels,2,'single') ;
    end

    function attach(obj, net, index)
      attach@dagnn.Layer(obj, net, index) ;
      p = net.getParamIndex(net.layers(index).params{3}) ;
      net.params(p).trainMethod = 'average' ;
      net.params(p).learningRate = 0.1 ;
    end
  end
end

In fact, in most cases we can simply call the class directly: the internal parameter-matching mechanism will infer the number of channels for us.

net.addLayer('bn1', dagnn.BatchNorm(), {'input'}, {'output'}, {'bn1f', 'bn1b', 'bn1m'});

If code readability is a concern, you can also specify the 'numChannels' parameter explicitly:

net.addLayer('bn1', dagnn.BatchNorm('numChannels', 512), {'input'}, {'output'}, {'bn1f', 'bn1b', 'bn1m'});
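
One more caveat that follows from the forward method above: dagnn.BatchNorm only normalizes with the stored moments (the third parameter) when the network is in 'test' mode, so remember to switch modes for inference:

net.mode = 'test';    % forward now calls vl_nnbnorm with 'moments', params{3}
% ... run net.eval(...) here for inference ...
net.mode = 'normal';  % back to batch statistics for training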

Reposted from blog.csdn.net/shenziheng1/article/details/81332378