Caffe Source Code: SGDSolver's Update Code - ApplyUpdate()

In an earlier post on Solver, I described how Step() drives training, testing, logging results, snapshotting intermediate models, and resuming training. The Solver base class, however, declares a pure virtual function ApplyUpdate(), which performs the weight update; each optimizer implements it differently (their update rules are sketched after the list):
(1) Vanilla stochastic gradient descent (SGD), the most basic optimizer: the weight update is simply the negative gradient scaled by the learning rate;
(2) SGD with momentum: the momentum accumulated from previous steps is combined with the current gradient to form the momentum at the current step, which is then used to update the weights;
(3) AdaGrad, gradient descent scaled by the history of squared gradients: each weight's gradient is divided by the L2 norm of that weight's accumulated gradient history;
(4) RMSprop, gradient descent with a smoothed gradient history: the accumulated squared-gradient history and the current squared gradient are combined with different weights, i.e., an exponential moving average;
(5) Adam, which combines RMSprop with momentum: it maintains first- and second-moment estimates of the gradient, applies bias correction, and then updates the weights.
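A sketch of these update rules in standard textbook notation (the generic formulations, not Caffe's variable names; \eta is the learning rate, \mu the momentum coefficient, g_t the gradient at step t, \varepsilon a small numerical-stability constant):

% Vanilla SGD
W_{t+1} = W_t - \eta\, g_t

% SGD with momentum (the variant SGDSolver implements)
V_{t+1} = \mu V_t + \eta\, g_t, \qquad W_{t+1} = W_t - V_{t+1}

% AdaGrad
W_{t+1} = W_t - \frac{\eta}{\sqrt{\sum_{i=1}^{t} g_i^2} + \varepsilon}\, g_t

% RMSprop
E[g^2]_t = \delta\, E[g^2]_{t-1} + (1-\delta)\, g_t^2, \qquad W_{t+1} = W_t - \frac{\eta}{\sqrt{E[g^2]_t} + \varepsilon}\, g_t

% Adam
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2
\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \qquad W_{t+1} = W_t - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \varepsilon}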
These are the optimizers in common use today; which one to pick depends on the task. In my experience, SGD with momentum is the usual choice for large natural-image datasets; for reinforcement learning (e.g., DQN), RMSprop is common (I have not used Adam there); and for medical images, Adam works well when fine-tuning from a model pretrained on natural images.

ApplyUpdate() Source Code

Below, SGDSolver is used as the example.

//The code is fairly straightforward; it does four things:
//1. Update the learning rate
//2. Average the gradients accumulated within a single iteration
//3. Apply weight regularization
//4. Update the weights

//Regarding step 2: in Caffe one iteration may consist of several ForwardBackward() calls, controlled by iter_size.
//Each ForwardBackward() runs one forward and one backward pass, accumulating the gradient in each Blob's diff buffer.
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  Dtype rate = GetLearningRate(); // get the learning rate for the current iteration
  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
    LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << this->iter_
        << ", lr = " << rate;
  }
  ClipGradients();  // clip gradients so they do not grow too large
  // main loop: iterate over every learnable parameter blob in the net
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    Normalize(param_id);  // average gradients accumulated within this iteration
    Regularize(param_id);  // add the regularization term to the gradient
    ComputeUpdateValue(param_id, rate);  // apply learning rate and momentum
  }
  this->net_->Update();  // performs the actual update: data = data - diff

  // Increment the internal iter_ counter -- its value should always indicate
  // the number of times the weights have been updated.
  ++this->iter_;
}
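Net::Update() simply walks the learnable blobs and calls Blob::Update() on each, which subtracts diff from data elementwise (via caffe_axpy with alpha = -1 on the CPU path). A minimal standalone sketch of that per-blob step (blob_update is a hypothetical helper, not Caffe's function name):

#include <cstddef>

// Sketch of the per-blob CPU update performed by Blob<Dtype>::Update():
// data = data - diff, i.e., axpy with alpha = -1.
template <typename Dtype>
void blob_update(std::size_t count, const Dtype* diff, Dtype* data) {
  for (std::size_t i = 0; i < count; ++i) {
    data[i] -= diff[i];
  }
}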

GetLearningRate(): Obtaining the Current Global Learning Rate

//There are several learning-rate decay policies, matching the branches below:
//fixed, step, exponential (exp), inverse (inv), multistep, polynomial (poly), and sigmoid.
template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {
  Dtype rate;
  const string& lr_policy = this->param_.lr_policy();
  if (lr_policy == "fixed") {
    rate = this->param_.base_lr();
  } else if (lr_policy == "step") {
    CHECK_GT(this->param_.stepsize(), 0);
    this->current_step_ = this->iter_ / this->param_.stepsize();
    CHECK_GE(this->param_.gamma(), 0);
    rate = this->param_.base_lr() *
        pow(this->param_.gamma(), this->current_step_);
  } else if (lr_policy == "exp") {
    CHECK_GE(this->param_.gamma(), 0);
    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
  } else if (lr_policy == "inv") {
    CHECK_GE(this->param_.gamma(), 0);
    rate = this->param_.base_lr() *
        pow(Dtype(1) + this->param_.gamma() * this->iter_,
            - this->param_.power());
  } else if (lr_policy == "multistep") {
    if (this->current_step_ < this->param_.stepvalue_size() &&
          this->iter_ >= this->param_.stepvalue(this->current_step_)) {
      this->current_step_++;
      LOG(INFO) << "MultiStep Status: Iteration " <<
      this->iter_ << ", step = " << this->current_step_;
    }
    CHECK_GE(this->param_.gamma(), 0);
    rate = this->param_.base_lr() *
        pow(this->param_.gamma(), this->current_step_);
  } else if (lr_policy == "poly") {
    rate = this->param_.base_lr() * pow(Dtype(1.) -
        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
        this->param_.power());
  } else if (lr_policy == "sigmoid") {
    CHECK_GE(this->param_.gamma(), 0);
    CHECK_GT(this->param_.stepsize(), 0);
    rate = this->param_.base_lr() * (Dtype(1.) /
        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
          Dtype(this->param_.stepsize())))));
  } else {
    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
  }
  return rate;
}
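To make the step policy concrete: with base_lr = 0.01, gamma = 0.1, and stepsize = 100000, the rate is 0.01 for iterations 0 through 99999, 0.001 for the next 100000 iterations, and so on. A standalone sketch of that branch (step_lr is a hypothetical helper; the parameter values are chosen purely for illustration):

#include <cmath>
#include <cstdio>

// Sketch of the "step" policy branch above:
// rate = base_lr * gamma^(iter / stepsize), using integer division.
double step_lr(double base_lr, double gamma, int stepsize, int iter) {
  int current_step = iter / stepsize;  // number of decays applied so far
  return base_lr * std::pow(gamma, current_step);
}

int main() {
  std::printf("%g\n", step_lr(0.01, 0.1, 100000, 50000));   // 0.01
  std::printf("%g\n", step_lr(0.01, 0.1, 100000, 150000));  // 0.001
  std::printf("%g\n", step_lr(0.01, 0.1, 100000, 250000));  // 0.0001
  return 0;
}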

Normalize(param_id): Averaging Gradients Within One Iteration

template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
  // if iter_size is 1, gradients are not accumulated, so there is nothing to average
  if (this->param_.iter_size() == 1) { return; }
  // Scale gradient to counterbalance accumulation.
  // container of pointers to the net's learnable parameter blobs
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  // reciprocal of the number of forward/backward passes per iteration
  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    caffe_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_cpu_diff());  // rescale to obtain the average gradient over the accumulated passes
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_gpu_diff());
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
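To see why this scaling produces an average: with iter_size = 4, the diff buffer holds the sum of four gradients by the time ApplyUpdate() runs, so multiplying by 1/4 turns the sum into a mean, mimicking a single larger batch. A minimal sketch of the same operation (normalize_diff is a hypothetical stand-in for the caffe_scal call):

#include <vector>

// Averaging gradients that were summed over iter_size forward/backward
// passes, as Normalize() does with caffe_scal (x = alpha * x).
void normalize_diff(std::vector<float>& diff, int iter_size) {
  const float accum_normalization = 1.0f / iter_size;
  for (float& g : diff) {
    g *= accum_normalization;  // sum of iter_size gradients -> their mean
  }
}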

Regularize(): Regularization

template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();
  // global weight-decay coefficient
  Dtype weight_decay = this->param_.weight_decay();
  // regularization type: L1 or L2
  string regularization_type = this->param_.regularization_type();
  // local decay for this parameter: the global decay scaled by the per-parameter multiplier
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    if (local_decay) {
      if (regularization_type == "L2") {
        // add weight decay: accumulate local_decay * data
        // into the corresponding Blob's diff buffer
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      } else if (regularization_type == "L1") {
        // L1 is analogous: take the sign of the weights, then accumulate into the Blob's diff buffer
        caffe_cpu_sign(net_params[param_id]->count(),
            net_params[param_id]->cpu_data(),
            temp_[param_id]->mutable_cpu_data());
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      } else {
        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
      }
    }
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    if (local_decay) {
      if (regularization_type == "L2") {
        // add weight decay
        caffe_gpu_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->gpu_data(),
            net_params[param_id]->mutable_gpu_diff());
      } else if (regularization_type == "L1") {
        caffe_gpu_sign(net_params[param_id]->count(),
            net_params[param_id]->gpu_data(),
            temp_[param_id]->mutable_gpu_data());
        caffe_gpu_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->gpu_data(),
            net_params[param_id]->mutable_gpu_diff());
      } else {
        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
      }
    }
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
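In equation form, Regularize() adds the derivative of the penalty term to the gradient, with \lambda standing for local_decay (these are the standard weight-decay gradients):

\frac{\partial}{\partial W}\Big(\frac{\lambda}{2}\lVert W\rVert_2^2\Big) = \lambda W \quad\Rightarrow\quad \text{diff} \leftarrow \text{diff} + \lambda W \qquad (\text{L2})

\frac{\partial}{\partial W}\big(\lambda\lVert W\rVert_1\big) = \lambda\,\operatorname{sign}(W) \quad\Rightarrow\quad \text{diff} \leftarrow \text{diff} + \lambda\,\operatorname{sign}(W) \qquad (\text{L1})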

ComputeUpdateValue(): Applying Learning Rate and Momentum

template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  Dtype momentum = this->param_.momentum();  // momentum coefficient
  Dtype local_rate = rate * net_params_lr[param_id];  // local learning rate: the global rate times the per-parameter lr multiplier
  // Compute the update to history, then copy it to the parameter diff.
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    // update the momentum history for the current iteration:
    // history_ = local_rate * diff + momentum * history_
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());
    // copy the updated momentum into the diff buffer; Net::Update() then subtracts it from the weights
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    sgd_update_gpu(net_params[param_id]->count(),
        net_params[param_id]->mutable_gpu_diff(),
        history_[param_id]->mutable_gpu_data(),
        momentum, local_rate);
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}
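Putting the CPU branch together, one momentum-SGD step per parameter blob looks like the sketch below (sgd_compute_update is a hypothetical stand-in, not Caffe's code; the GPU path fuses the same two steps into the single sgd_update_gpu kernel):

#include <cstddef>
#include <vector>

// One momentum-SGD step for a single parameter blob, mirroring the CPU
// branch above. Afterwards Net::Update() performs data -= diff, which
// completes W = W - V.
void sgd_compute_update(std::vector<float>& diff,     // gradient in, update out
                        std::vector<float>& history,  // momentum buffer V
                        float momentum, float local_rate) {
  for (std::size_t i = 0; i < diff.size(); ++i) {
    // V = momentum * V + local_rate * dW   (caffe_cpu_axpby)
    history[i] = momentum * history[i] + local_rate * diff[i];
    // diff <- V                            (caffe_copy)
    diff[i] = history[i];
  }
}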

Reprinted from blog.csdn.net/charel_chen/article/details/81262186