caffe源码 之 Solver类

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/lanxueCC/article/details/53219666

本文主要解析caffe源码文件/src/caffe/solver.cpp(以及对应的头文件/include/caffe/solver.hpp),这两个文件定义了caffe框架中优化求解器的基类Solver。

Solver这个类实现了优化函数的封装,其中有一个protected的成员:shared_ptr<Net<Dtype> > net_;,这个成员是一个指向Net类型的智能指针(shared_ptr),Solver正是通过这个指针与网络Net交互并完成模型的优化。不同的子类分别实现了不同的优化方法:SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver和AdamSolver。默认使用SGDSolver方法进行优化。

下面记录下我对Solver.cpp与Solver.hpp文件中关于Solver类的理解::::

Solver.hpp:::::

#ifndef CAFFE_SOLVER_HPP_
#define CAFFE_SOLVER_HPP_
#include <boost/function.hpp>
#include <string>
#include <vector>

#include "caffe/net.hpp"
#include "caffe/solver_factory.hpp"

namespace caffe {

/**
  * @brief Enumeration of actions that a client of the Solver may request by
  * implementing the Solver's action request function, which
  * a client may optionally provide in order to request early termination
  * or saving a snapshot without exiting. In the executable caffe, this
  * mechanism is used to allow the snapshot to be saved when stopping
  * execution with a SIGINT (Ctrl-C).
  */
  // The enum is wrapped in its own namespace so the values are spelled
  // SolverAction::NONE etc. and the enum type itself is SolverAction::Enum.
  namespace SolverAction {
    enum Enum {
      NONE = 0,  // Take no special action (nothing was requested).
      STOP = 1,  // Stop training. snapshot_after_train controls whether a
                 // snapshot is created.
      SNAPSHOT = 2  // Take a snapshot, and keep training.
    };
  }

/**
 * @brief Type of a function that returns a Solver Action enumeration.
 *
 * This is a boost::function wrapper (not a raw function pointer): any
 * callable taking no arguments and returning SolverAction::Enum can be
 * stored in it.
 */
typedef boost::function<SolverAction::Enum()> ActionCallback;

/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ApplyUpdate to compute a parameter update
 * given the current state of the Net parameters.
 */
template <typename Dtype>
class Solver {
 public:
  // Construct from an in-memory SolverParameter, or from the path of a text
  // file holding one. Both constructors end by calling Init().
  // root_solver is stored in root_solver_; Init() requires it to be non-NULL
  // whenever Caffe::root_solver() is false.
  explicit Solver(const SolverParameter& param,
      const Solver* root_solver = NULL);
  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
  // Copies param into param_, builds the train net (and, on the root solver,
  // the test nets) and resets the iteration counters.
  void Init(const SolverParameter& param);
  // Builds net_ (the training net) from the source named in param_.
  void InitTrainNet();
  // Builds test_nets_ from the sources named in param_.
  void InitTestNets();

  // Client of the Solver optionally may call this in order to set the function
  // that the solver uses to see what action it should take (e.g. snapshot or
  // exit training early).
  void SetActionFunction(ActionCallback func);
  // Invokes the action function if one was set; otherwise returns
  // SolverAction::NONE.
  SolverAction::Enum GetRequestedAction();
  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  // Runs `iters` training iterations starting from the current iter_; this is
  // the loop that Solve() drives.
  void Step(int iters);
  // The Restore method simply dispatches to one of the
  // RestoreSolverStateFrom___ protected methods. You should implement these
  // methods to restore the state from the appropriate snapshot type.
  // (Restore *loads* solver state from a snapshot; it does not write one.)
  void Restore(const char* resume_file);
  // The Solver::Snapshot function implements the basic snapshotting utility
  // that stores the learned net. You should implement the SnapshotSolverState()
  // function that produces a SolverState protocol buffer that needs to be
  // written to disk together with the learned net.
  void Snapshot();
  // Virtual destructor: the class is meant to be subclassed.
  virtual ~Solver() {}
  // Read-only access to the solver configuration.
  inline const SolverParameter& param() const { return param_; }
  // Shared pointer to the training net.
  inline shared_ptr<Net<Dtype> > net() { return net_; }
  // The vector of test nets built by InitTestNets().
  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
    return test_nets_;
  }
  // Current iteration counter (see iter_).
  int iter() { return iter_; }

  // Invoked at specific points during an iteration
  // Step() calls on_start() on every registered callback before the
  // forward/backward passes, and on_gradients_ready() after the loss has been
  // computed and before ApplyUpdate().
  class Callback {
   protected:
    virtual void on_start() = 0;
    virtual void on_gradients_ready() = 0;

    // Lets every Solver<T> instantiation call the protected hooks above.
    template <typename T>
    friend class Solver;
  };
  const vector<Callback*>& callbacks() const { return callbacks_; }
  // Registers a callback. Only the raw pointer is stored; nothing in this
  // class deletes it.
  void add_callback(Callback* value) {
    callbacks_.push_back(value);
  }

  // On the root solver with snapshotting enabled, verifies that a file can be
  // created under the snapshot prefix; aborts via LOG(FATAL) otherwise.
  void CheckSnapshotWritePermissions();
  /**
   * @brief Returns the solver type.
   */
  virtual inline const char* type() const { return ""; }

 protected:
  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;
  // Builds "<snapshot_prefix>_iter_<iter_><extension>".
  string SnapshotFilename(const string extension);
  // Writes the net as a binary proto (.caffemodel); returns the file name.
  string SnapshotToBinaryProto();
  // Writes the net as HDF5 (.caffemodel.h5); returns the file name.
  string SnapshotToHDF5();
  // The test routine
  // Runs Test() on every test net, stopping if an early exit was requested.
  void TestAll();
  // Evaluates a single test net and logs its averaged outputs.
  void Test(const int test_net_id = 0);
  // Snapshot/restore hooks for solver-specific state; concrete solvers
  // (e.g. SGDSolver) must implement them.
  virtual void SnapshotSolverState(const string& model_filename) = 0;
  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
  // NOTE(review): declared here, but no definition appears in the
  // accompanying solver.cpp listing.
  void DisplayOutputBlobs(const int net_id);
  // Folds `loss` into smoothed_loss_, a running mean over at most
  // `average_loss` recent losses.
  void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);

  // The solver configuration (copied in Init()).
  SolverParameter param_;
  // Current training iteration: set to 0 by Init() and incremented once per
  // pass of the loop in Step(), right after ApplyUpdate().
  int iter_;
  // Set to 0 by Init(); not otherwise used by the base class code shown here
  // (presumably consumed by subclasses' learning-rate schedules -- confirm).
  int current_step_;
  shared_ptr<Net<Dtype> > net_;  // The training net.
  vector<shared_ptr<Net<Dtype> > > test_nets_;  // The test nets.
  vector<Callback*> callbacks_;  // Registered callbacks (raw pointers).
  vector<Dtype> losses_;    // Recent per-iteration losses, used for smoothing.
  Dtype smoothed_loss_;     // Running mean of the entries in losses_.

  // The root solver that holds root nets (actually containing shared layers)
  // in data parallelism
  const Solver* const root_solver_;

  // A function that can be set by a client of the Solver to provide indication
  // that it wants a snapshot saved and/or to exit early.
  ActionCallback action_request_function_;

  // True iff a request to stop early was received.
  bool requested_early_exit_;

  // Macro defined outside this file; by its name it presumably disables the
  // copy constructor and assignment operator -- confirm.
  DISABLE_COPY_AND_ASSIGN(Solver);
};

/**
 * @brief Solver that only computes gradients, used as worker
 *        for multi-GPU training.
 *
 * ApplyUpdate() is a no-op here, and the snapshot/restore hooks abort with
 * LOG(FATAL) because they are not expected to run on a worker.
 */
template <typename Dtype>
class WorkerSolver : public Solver<Dtype> {
 public:
  // Forwards both arguments unchanged to the Solver<Dtype> constructor.
  explicit WorkerSolver(const SolverParameter& param,
      const Solver<Dtype>* root_solver = NULL)
      : Solver<Dtype>(param, root_solver) {}

 protected:
  // Intentionally empty: a worker does not apply parameter updates itself.
  void ApplyUpdate() {}
  // The three hooks below must never be reached on a worker solver.
  void SnapshotSolverState(const string& model_filename) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
  void RestoreSolverStateFromBinaryProto(const string& state_file) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
  void RestoreSolverStateFromHDF5(const string& state_file) {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
};

}  // namespace caffe

#endif  // CAFFE_SOLVER_HPP_

Solver.cpp:::::

#include <cstdio>

#include <string>
#include <vector>

#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"

namespace caffe {

// Installs the callable that GetRequestedAction() polls to learn whether the
// client wants a snapshot or an early stop. Replaces any previous callable.
template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  this->action_request_function_ = func;
}

// Returns the action currently requested by the client: the result of the
// installed action function, or SolverAction::NONE when none was installed.
template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
  // An empty ActionCallback tests false, so it is never invoked.
  return action_request_function_ ? action_request_function_()
                                  : SolverAction::NONE;
}

// Constructs a solver from an already-populated SolverParameter.
// All real setup work is delegated to Init().
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
    : net_(),
      callbacks_(),
      root_solver_(root_solver),
      requested_early_exit_(false) {
  Init(param);
}

// Constructs a solver from the path of a text file holding a SolverParameter.
// The file is parsed first (the helper's name indicates it aborts on
// failure), then setup is delegated to Init().
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
    : net_(),
      callbacks_(),
      root_solver_(root_solver),
      requested_early_exit_(false) {
  SolverParameter parsed_param;
  // Reads the *solver* configuration (not the net definition).
  ReadSolverParamsFromTextFileOrDie(param_file, &parsed_param);
  Init(parsed_param);
}

// Shared initialization routine, called by both constructors.
//
// Copies `param` into param_, validates it, optionally seeds the RNG, builds
// the training net and (on the root solver only) the test nets.
//
// Aborts (glog CHECK / LOG(FATAL)) when:
//   - this is a non-root solver and no root_solver_ was supplied;
//   - average_loss < 1;
//   - snapshotting is enabled but the snapshot location is not writable.
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
    << std::endl << param.DebugString();
  param_ = param;
  // Reset the counters BEFORE anything can read them. The constructors do not
  // initialize iter_, and CheckSnapshotWritePermissions() below formats iter_
  // into the probe filename via SnapshotFilename(); initializing afterwards
  // meant reading an indeterminate value.
  iter_ = 0;
  current_step_ = 0;
  // The check requires at least one loss in the smoothing window; the message
  // now states that (it used to claim "non-negative", which 0 would satisfy).
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be at least 1.";
  // Fail fast if snapshots could not be written later on.
  CheckSnapshotWritePermissions();
  // A negative random_seed leaves the RNG seed untouched.
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    Caffe::set_random_seed(param_.random_seed());
  }
  // Scaffolding code
  InitTrainNet();
  if (Caffe::root_solver()) {
    // Test nets are only created on the root solver.
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
}

// Builds the training net (net_).
//
// Exactly one of the SolverParameter fields net, net_param, train_net or
// train_net_param must name the net definition; anything else aborts.
// The resulting NetParameter gets a TRAIN-phase NetState before the Net is
// constructed.
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const string kTrainNetFields = "net, net_param, train_net, train_net_param";
  // Each has_*() is a bool, so the sum counts how many sources were given.
  const int train_source_count = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();
  CHECK_GE(train_source_count, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << kTrainNetFields;
  CHECK_LE(train_source_count, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << kTrainNetFields;

  const bool is_root = Caffe::root_solver();
  NetParameter train_net_def;
  // Train-specific sources.
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, is_root)
        << "Creating training net specified in train_net_param.";
    train_net_def.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, is_root)
        << "Creating training net from train_net file: " << param_.train_net();
    // Parses the net definition file named in the solver configuration.
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &train_net_def);
  }
  // Generic sources (shared between train and test nets).
  if (param_.has_net_param()) {
    LOG_IF(INFO, is_root)
        << "Creating training net specified in net_param.";
    train_net_def.CopyFrom(param_.net_param());
  }
  if (param_.has_net()) {
    LOG_IF(INFO, is_root)
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &train_net_def);
  }

  // NetState precedence, lowest to highest: solver default (phase TRAIN),
  // then the state embedded in the net definition, then train_state from the
  // solver configuration.
  NetState merged_state;
  merged_state.set_phase(TRAIN);
  merged_state.MergeFrom(train_net_def.state());
  merged_state.MergeFrom(param_.train_state());
  train_net_def.mutable_state()->CopyFrom(merged_state);

  // Non-root solvers hand the root solver's net to the Net constructor.
  if (!is_root) {
    net_.reset(new Net<Dtype>(train_net_def, root_solver_->net_.get()));
  } else {
    net_.reset(new Net<Dtype>(train_net_def));
  }
}

// Builds the test nets (test_nets_). Root solver only.
//
// Test nets come from, in this order: every test_net_param entry, every
// test_net file, and then -- if a generic net/net_param is given -- one
// instance of the generic net for each test_iter entry that is left over.
template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
  CHECK(Caffe::root_solver());
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  // At most one generic net source (net_param or net) may be given.
  const int num_generic_nets = has_net_param + has_net_file;
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";
  const int num_test_net_params = param_.test_net_param_size();
  const int num_test_net_files = param_.test_net_size();
  const int num_test_nets = num_test_net_params + num_test_net_files;
  // With a generic net, surplus test_iter entries are allowed (each one
  // creates another generic instance); without one the counts must match.
  if (num_generic_nets) {
      CHECK_GE(param_.test_iter_size(), num_test_nets)
          << "test_iter must be specified for each test network.";
  } else {
      CHECK_EQ(param_.test_iter_size(), num_test_nets)
          << "test_iter must be specified for each test network.";
  }
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or test_net
  // are evaluated.
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;
  // test_state is optional, but if present needs one entry per test net.
  if (param_.test_state_size()) {
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";
  }
  // Having any test net requires a positive test_interval.
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);
  }
  // test_net_id is the running slot index shared by all the loops below.
  int test_net_id = 0;
  // sources[i] is a human-readable description used only for logging.
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  // 1) Inline test net definitions.
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
      sources[test_net_id] = "test_net_param";
      net_params[test_net_id].CopyFrom(param_.test_net_param(i));
  }
  // 2) Test net definition files.
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
      sources[test_net_id] = "test_net file: " + param_.test_net(i);
      ReadNetParamsFromTextFileOrDie(param_.test_net(i),
          &net_params[test_net_id]);
  }
  // 3) Fill the remaining slots with copies of the generic net.
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;
  if (has_net_param) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net_param";
      net_params[test_net_id].CopyFrom(param_.net_param());
    }
  }
  if (has_net_file) {
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
    }
  }
  test_nets_.resize(num_test_net_instances);
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState.  We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    NetState net_state;
    net_state.set_phase(TEST);
    net_state.MergeFrom(net_params[i].state());
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));
    }
    net_params[i].mutable_state()->CopyFrom(net_state);
    LOG(INFO)
        << "Creating test net (#" << i << ") specified by " << sources[i];
    // NOTE(review): the CHECK at the top of this function already requires
    // the root solver, so the else branch looks unreachable.
    if (Caffe::root_solver()) {
      test_nets_[i].reset(new Net<Dtype>(net_params[i]));
    } else {
      test_nets_[i].reset(new Net<Dtype>(net_params[i],
          root_solver_->test_nets_[i].get()));
    }
    test_nets_[i]->set_debug_info(param_.debug_info());
  }
}

// Runs up to `iters` training iterations, starting at the current iter_.
//
// Each pass: optionally test, run the on_start callbacks, accumulate loss
// over iter_size forward/backward passes, update the smoothed loss,
// optionally log, run the on_gradients_ready callbacks, ApplyUpdate(),
// increment iter_, then handle snapshot/stop requests.
// Returns early (with requested_early_exit_ set) when a STOP is requested.
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  // First and one-past-last iteration of this call. iter_ is non-zero on
  // entry when state was restored from a snapshot or Step ran before.
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;

  // The reported loss is a mean over the most recent `average_loss` losses
  // (solver parameter). losses_ holds that window; smoothed_loss_ its mean.
  // Both are reset at the start of every Step() call.
  int average_loss = this->param_.average_loss();
  losses_.clear();
  smoothed_loss_ = 0;

  // Main training loop.
  while (iter_ < stop_iter) {
    // zero-init the params
    // (clears the parameter diffs left over from the previous iteration)
    net_->ClearParamDiffs();

    // Test every test_interval iterations; at iteration 0 only when
    // test_initialization is set. Testing happens on the root solver only.
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())
        && Caffe::root_solver()) {
      TestAll();

      // A STOP request may have arrived while testing.
      if (requested_early_exit_) {
        // Break out of the while loop because stop was requested while testing.
        break;
      }
    }

    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }

    // Whether this iteration's loss and outputs should be logged.
    const bool display = param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(display && param_.debug_info());
    // accumulate the loss and gradient
    Dtype loss = 0;

    // The loss of one solver iteration is the mean over iter_size
    // forward/backward passes. The blog author's reading: this makes the
    // effective batch size iter_size * (net batch size), which helps when
    // the desired batch does not fit in GPU memory (e.g. batch 32 with
    // iter_size 4 instead of batch 128).
    for (int i = 0; i < param_.iter_size(); ++i) {
      // ForwardBackward() returns the loss of one pass.
      loss += net_->ForwardBackward();
    }
    loss /= param_.iter_size();

    // average the loss across iterations for smoothed reporting
    // Until losses_ holds average_loss entries the new loss is appended;
    // afterwards it overwrites the oldest entry. Averaging reduces the
    // oscillation of the reported per-mini-batch loss.
    UpdateSmoothedLoss(loss, start_iter, average_loss);

    // Log the smoothed loss and every element of every output blob.
    if (display) {
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << ", loss = " << smoothed_loss_;
      const vector<Blob<Dtype>*>& result = net_->output_blobs();
      // Running index over all logged output values.
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[j]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[j]];
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          // Only blobs with a non-zero loss weight get the "(* w = ...)" note.
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();
        }
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }

    // Apply the parameter update. Pure virtual in Solver; each concrete
    // solver supplies its own rule.
    ApplyUpdate();

    // Increment the internal iter_ counter -- its value should always indicate
    // the number of times the weights have been updated.
    ++iter_;

    // Poll the client-supplied action function (installed through
    // SetActionFunction). In the caffe executable this is how a SIGINT ends
    // up requesting a snapshot or a stop.
    SolverAction::Enum request = GetRequestedAction();

    // Save a snapshot if needed.
    // Either the periodic schedule fires (root solver only) or the client
    // explicitly requested one.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
         (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }

    // A STOP request sets requested_early_exit_ and ends the loop.
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }
  }
}

// Top-level training entry point (root solver only).
//
// Optionally restores state from `resume_file`, runs Step() until max_iter,
// takes a final snapshot if configured, and -- unless stopped early -- does
// one last loss display and test pass.
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {

  // Only the root solver may drive training.
  CHECK(Caffe::root_solver());

  // Log the net name and the learning-rate policy.
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // Initialize to false every time we start solving.
  requested_early_exit_ = false;

  // A non-NULL resume_file names a snapshot to restore the state from.
  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }

  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;
  // Run the remaining iterations; this is where the actual training happens.
  Step(param_.max_iter() - iter_);
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  // (the second condition skips it when Step() just took a periodic snapshot
  // at this very iteration)
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }

  // Step() sets requested_early_exit_ when a STOP action was requested; in
  // that case skip the final display/test pass.
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively).  Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  if (param_.display() && iter_ % param_.display() == 0) {
    int average_loss = this->param_.average_loss();
    Dtype loss;
    net_->Forward(&loss);

    // Fold the final loss into the smoothed loss before logging it.
    UpdateSmoothedLoss(loss, start_iter, average_loss);

    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  }

  // Final test pass, if the last iteration falls on a test interval.
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
  LOG(INFO) << "Optimization Done.";
}

// Runs Test() on each test net in index order. Stops as soon as an early
// exit has been requested (Test() may set requested_early_exit_).
template <typename Dtype>
void Solver<Dtype>::TestAll() {
  int net_idx = 0;
  while (net_idx < test_nets_.size() && !requested_early_exit_) {
    Test(net_idx);
    ++net_idx;
  }
}

// Evaluates test net `test_net_id` (root solver only).
//
// Runs test_iter(test_net_id) forward passes, sums every output value across
// passes, and logs the per-output means (plus the mean loss when
// test_compute_loss is set). Returns early, logging "Test interrupted.", if
// a STOP action is requested in between passes.
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";
  // Have the test net share the layers trained so far in the training net.
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  // test_score[i] accumulates one output value across all test iterations;
  // test_score_output_id[i] records which output blob it came from.
  vector<Dtype> test_score;
  vector<int> test_score_output_id;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();
    // Check to see if stoppage of testing/training has been requested.
    // Keep polling until the action function reports NONE, handling every
    // request it returns in the meantime.
    while (request != SolverAction::NONE) {
        if (SolverAction::SNAPSHOT == request) {
          Snapshot();
        } else if (SolverAction::STOP == request) {
          requested_early_exit_ = true;
        }
        request = GetRequestedAction();
    }
    if (requested_early_exit_) {
      // break out of test loop.
      break;
    }

    Dtype iter_loss;
    // One forward pass over the test net.
    const vector<Blob<Dtype>*>& result =
        test_net->Forward(&iter_loss);
    if (param_.test_compute_loss()) {
      loss += iter_loss;
    }
    if (i == 0) {
      // First pass: size the accumulators and record each value's blob.
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score.push_back(result_vec[k]);
          test_score_output_id.push_back(j);
        }
      }
    } else {
      // Later passes: add into the matching accumulator slots, in the same
      // (blob, element) order as the first pass.
      int idx = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score[idx++] += result_vec[k];
        }
      }
    }
  }
  // Nothing is reported for an interrupted test.
  if (requested_early_exit_) {
    LOG(INFO)     << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {
    loss /= param_.test_iter(test_net_id);
    LOG(INFO) << "Test loss: " << loss;
  }
  // Log the mean of every accumulated output value.
  for (int i = 0; i < test_score.size(); ++i) {
    const int output_blob_index =
        test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
    ostringstream loss_msg_stream;
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
    // Outputs with a non-zero loss weight also show their weighted value.
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
              << mean_score << loss_msg_stream.str();
  }
}

// Writes a snapshot (root solver only): first the learned net, in the format
// selected by snapshot_format, then the solver-specific state through
// SnapshotSolverState(). Any other format value aborts via LOG(FATAL).
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string weights_path;
  if (param_.snapshot_format() ==
      caffe::SolverParameter_SnapshotFormat_BINARYPROTO) {
    // Binary protobuf snapshot.
    weights_path = SnapshotToBinaryProto();
  } else if (param_.snapshot_format() ==
      caffe::SolverParameter_SnapshotFormat_HDF5) {
    // HDF5 snapshot.
    weights_path = SnapshotToHDF5();
  } else {
    LOG(FATAL) << "Unsupported snapshot format.";
  }

  // The subclass stores its own state alongside the model file.
  SnapshotSolverState(weights_path);
}


// Verifies that snapshot files can be created under snapshot_prefix.
// Does nothing unless this is the root solver and snapshotting is enabled.
// The check creates a temporary probe file and removes it again; failure to
// open it aborts via LOG(FATAL).
template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {
  if (!Caffe::root_solver() || !param_.snapshot()) {
    return;
  }
  CHECK(param_.has_snapshot_prefix())
      << "In solver params, snapshot is specified but snapshot_prefix is not";
  const string probe_path = SnapshotFilename(".tempfile");
  std::ofstream probe_stream(probe_path.c_str());
  if (!probe_stream.good()) {
    LOG(FATAL) << "Cannot write to snapshot prefix '"
        << param_.snapshot_prefix() << "'.  Make sure "
        << "that the directory exists and is writeable.";
  } else {
    // The probe worked; clean it up.
    probe_stream.close();
    std::remove(probe_path.c_str());
  }
}

// Builds a snapshot file name of the form
// "<snapshot_prefix>_iter_<iter_><extension>".
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
  string filename = param_.snapshot_prefix();
  filename += "_iter_";
  filename += caffe::format_int(iter_);
  filename += extension;
  return filename;
}

// Saves the training net as a binary protobuf file (".caffemodel") and
// returns the name of the file written.
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  const string out_path = SnapshotFilename(".caffemodel");
  LOG(INFO) << "Snapshotting to binary proto file " << out_path;
  // Serialize the net into a NetParameter message; snapshot_diff is passed
  // through as the second argument of Net::ToProto.
  NetParameter serialized_net;
  net_->ToProto(&serialized_net, param_.snapshot_diff());
  // Write the message to disk.
  WriteProtoToBinaryFile(serialized_net, out_path);
  return out_path;
}

// Saves the training net as an HDF5 file (".caffemodel.h5") and returns the
// name of the file written.
template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
  const string out_path = SnapshotFilename(".caffemodel.h5");
  LOG(INFO) << "Snapshotting to HDF5 file " << out_path;
  // The net writes itself; snapshot_diff is passed through as the second
  // argument of Net::ToHDF5.
  net_->ToHDF5(out_path, param_.snapshot_diff());
  return out_path;
}

// Restores (loads) solver state from a snapshot file; root solver only.
// Files whose name ends in ".h5" are handed to RestoreSolverStateFromHDF5(),
// everything else to RestoreSolverStateFromBinaryProto(). Both are
// implemented by the concrete solver.
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  CHECK(Caffe::root_solver());
  string state_path(state_file);
  const string h5_suffix(".h5");
  const bool is_hdf5 = state_path.size() >= h5_suffix.size() &&
      state_path.compare(state_path.size() - h5_suffix.size(),
                         h5_suffix.size(), h5_suffix) == 0;
  if (!is_hdf5) {
    // Binary protobuf solver state.
    RestoreSolverStateFromBinaryProto(state_path);
    return;
  }
  // HDF5 solver state.
  RestoreSolverStateFromHDF5(state_path);
}

// Folds `loss` into smoothed_loss_, the mean of the most recent losses.
//
// While fewer than `average_loss` losses are stored, the new one is appended
// and the mean is updated incrementally. Once the window is full, the slot
// at (iter_ - start_iter) % average_loss is overwritten and the mean is
// adjusted by the difference.
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
    int average_loss) {
  if (losses_.size() >= average_loss) {
    // Window full: replace one stored loss, ring-buffer style.
    const int slot = (iter_ - start_iter) % average_loss;
    smoothed_loss_ += (loss - losses_[slot]) / average_loss;
    losses_[slot] = loss;
    return;
  }
  // Window still filling up.
  losses_.push_back(loss);
  const int count = losses_.size();
  smoothed_loss_ = (smoothed_loss_ * (count - 1) + loss) / count;
}

// NOTE(review): INSTANTIATE_CLASS is a macro defined outside this file;
// presumably it emits the explicit instantiations of Solver<Dtype> that let
// the member definitions above live in this .cpp -- confirm.
INSTANTIATE_CLASS(Solver);

}  // namespace caffe

猜你喜欢

转载自blog.csdn.net/lanxueCC/article/details/53219666