版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/lanxueCC/article/details/53219666
本文主要解析caffe源码文件/src/caffe/solver.cpp,该文件主要定义caffe框架中优化函数类的基类。
Solver这个类实现了优化函数的封装,其中有一个protected的成员:shared_ptr<Net<Dtype> > net_;,这个成员是一个指向Net类型的智能指针(shared_ptr),Solver正是通过这个指针与网络Net交互并完成模型的优化。不同的子类分别实现了不同的优化方法:SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver和AdamSolver。默认使用SGDSolver方法进行优化。
下面记录我对Solver.cpp与Solver.hpp文件中Solver类的理解:
Solver.hpp:
#ifndef CAFFE_SOLVER_HPP_
#define CAFFE_SOLVER_HPP_
#include <boost/function.hpp>
#include <string>
#include <vector>
#include "caffe/net.hpp"
#include "caffe/solver_factory.hpp"
namespace caffe {
/**
* @brief Enumeration of actions that a client of the Solver may request by
* implementing the Solver's action request function, which a
* a client may optionally provide in order to request early termination
* or saving a snapshot without exiting. In the executable caffe, this
* mechanism is used to allow the snapshot to be saved when stopping
* execution with a SIGINT (Ctrl-C).
*/
/* Actions a client can request from the Solver through its action callback. */
namespace SolverAction {
enum Enum {
NONE = 0, // Take no special action.
STOP = 1, // Stop training. snapshot_after_train controls whether a
// snapshot is created.
SNAPSHOT = 2 // Take a snapshot, and keep training.
};
}
/**
* @brief Type of a function that returns a Solver Action enumeration.
*/
/* ActionCallback: a boost::function wrapper for any callable that takes no
 * arguments and returns a SolverAction::Enum. */
typedef boost::function<SolverAction::Enum()> ActionCallback;
/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
/* Declaration of the Solver base class: it holds the training net and the
 * test nets and drives the optimization loop; concrete solvers supply
 * ApplyUpdate() and the solver-state snapshot/restore hooks. */
template <typename Dtype>
class Solver {
public:
/* Constructors: build the solver from a SolverParameter message or from a
 * solver parameter text file. Both end by calling Init(). */
explicit Solver(const SolverParameter& param,
const Solver* root_solver = NULL);
explicit Solver(const string& param_file, const Solver* root_solver = NULL);
/* Initializes the solver from param (called by the constructors). */
void Init(const SolverParameter& param);
/* Builds the training net (net_). */
void InitTrainNet();
/* Builds the test nets (test_nets_). */
void InitTestNets();
// Client of the Solver optionally may call this in order to set the function
// that the solver uses to see what action it should take (e.g. snapshot or
// exit training early).
/* Installs the callback that GetRequestedAction() invokes. */
void SetActionFunction(ActionCallback func);
/* Returns the action reported by the installed callback, or
 * SolverAction::NONE when no callback has been set. */
SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
/* Main training entry point. */
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
/* Runs `iters` training iterations; Solve() calls this to do the work. */
void Step(int iters);
// The Restore method simply dispatches to one of the
// RestoreSolverStateFrom___ protected methods. You should implement these
// methods to restore the state from the appropriate snapshot type.
// Restores the solver state from resume_file: a ".h5" suffix selects the
// HDF5 hook, anything else the binary proto hook.
void Restore(const char* resume_file);
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
// Basic snapshot utility: stores the learned net, then the solver state.
void Snapshot();
// Virtual destructor.
virtual ~Solver() {}
/* Returns the solver configuration. */
inline const SolverParameter& param() const { return param_; }
/* Returns the training net. */
inline shared_ptr<Net<Dtype> > net() { return net_; }
/* Returns the vector of test nets (test_nets_). */
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
return test_nets_;
}
/* Returns the current iteration count. */
int iter() { return iter_; }
// Invoked at specific points during an iteration
class Callback {
protected:
// Called by Step() at the start of an iteration, before forward/backward.
virtual void on_start() = 0;
// Called by Step() after forward/backward, right before ApplyUpdate().
virtual void on_gradients_ready() = 0;
template <typename T>
friend class Solver;
};
// Registered callbacks; the solver stores raw pointers and never deletes them
// in the code shown here.
const vector<Callback*>& callbacks() const { return callbacks_; }
void add_callback(Callback* value) {
callbacks_.push_back(value);
}
// Verifies that a file can be created under snapshot_prefix when periodic
// snapshotting is enabled on the root solver.
void CheckSnapshotWritePermissions();
/**
* @brief Returns the solver type.
*/
virtual inline const char* type() const { return ""; }
protected:
// Make and apply the update value for the current iteration.
// Implemented by each concrete solver.
virtual void ApplyUpdate() = 0;
// Builds a snapshot file name:
// snapshot_prefix + "_iter_" + <iteration> + extension.
string SnapshotFilename(const string extension);
// Writes the net to a binary proto (.caffemodel) file and returns its name.
string SnapshotToBinaryProto();
// Writes the net to an HDF5 (.caffemodel.h5) file and returns its name.
string SnapshotToHDF5();
// The test routine
// Runs Test() on every test net.
void TestAll();
// Tests a single net.
void Test(const int test_net_id = 0);
// Solver-state snapshot/restore hooks implemented by the concrete solver
// classes (e.g. SGDSolver).
virtual void SnapshotSolverState(const string& model_filename) = 0;
virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
// NOTE(review): no definition of this function appears in the solver.cpp
// listing that accompanies this header -- confirm whether it is dead.
void DisplayOutputBlobs(const int net_id);
// Updates the running average of the loss (smoothed_loss_).
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
// Solver configuration.
SolverParameter param_;
// Current training iteration; Step() increments it once per iteration, right
// after ApplyUpdate().
int iter_;
// Set to 0 in Init(); not otherwise used in the code shown here.
int current_step_;
shared_ptr<Net<Dtype> > net_; /* Smart pointer to the training net */
vector<shared_ptr<Net<Dtype> > > test_nets_; /* Smart pointers to the test nets */
vector<Callback*> callbacks_;
vector<Dtype> losses_; /* Most recent per-iteration losses (at most average_loss entries) */
Dtype smoothed_loss_; /* Running mean of the losses kept in losses_ */
// The root solver that holds root nets (actually containing shared layers)
// in data parallelism
const Solver* const root_solver_;
// A function that can be set by a client of the Solver to provide indication
// that it wants a snapshot saved and/or to exit early.
// Installed through SetActionFunction().
ActionCallback action_request_function_;
// True iff a request to stop early was received.
bool requested_early_exit_;
DISABLE_COPY_AND_ASSIGN(Solver);
};
/**
* @brief Solver that only computes gradients, used as worker
* for multi-GPU training.
*/
template <typename Dtype>
class WorkerSolver : public Solver<Dtype> {
 public:
  // Forwards both arguments to the Solver base constructor.
  explicit WorkerSolver(const SolverParameter& param,
                        const Solver<Dtype>* root_solver = NULL)
      : Solver<Dtype>(param, root_solver) {}

 protected:
  // Intentionally empty: a worker applies no parameter update.
  void ApplyUpdate() {}

  // The solver-state hooks below must never run on a worker; each one
  // terminates with a fatal log message.
  void SnapshotSolverState(const string& model_filename) {
    FailOnWorker();
  }
  void RestoreSolverStateFromBinaryProto(const string& state_file) {
    FailOnWorker();
  }
  void RestoreSolverStateFromHDF5(const string& state_file) {
    FailOnWorker();
  }

 private:
  // Shared body of the unsupported hooks.
  static void FailOnWorker() {
    LOG(FATAL) << "Should not be called on worker solver.";
  }
};
} // namespace caffe
#endif // CAFFE_SOLVER_HPP_
Solver.cpp:
#include <cstdio>
#include <string>
#include <vector>
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"
namespace caffe {
// Stores the client-supplied callback that GetRequestedAction() will query.
template <typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  this->action_request_function_ = func;
}
// Asks the external callback (when one has been installed) which action is
// requested; without a callback the answer is always SolverAction::NONE.
template <typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
  return action_request_function_ ? action_request_function_()
                                  : SolverAction::NONE;
}
// Constructs a solver from an already-parsed SolverParameter: default-
// initializes the net pointer and callback list, records the root solver,
// clears the early-exit flag, and hands the parameters to Init().
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
    : net_(),
      callbacks_(),
      root_solver_(root_solver),
      requested_early_exit_(false) {
  this->Init(param);
}
// Constructs a solver from a solver parameter text file: the file is parsed
// into a local SolverParameter (the process dies if it cannot be read, as the
// helper's name says) and the result is passed to Init().
//
// param_file:  path of the solver parameter text file.
// root_solver: root solver for non-root solvers; NULL otherwise.
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
    : net_(), callbacks_(), root_solver_(root_solver),
      requested_early_exit_(false) {
  SolverParameter param;
  // Pass the address of the local message so the reader can fill it in.
  ReadSolverParamsFromTextFileOrDie(param_file, &param);
  Init(param);
}
// Initializes the solver from `param`; called by both constructors.
//
// Steps: validate the root-solver setup, copy the parameters, reset the
// iteration counters, validate average_loss, verify snapshot write access,
// seed the RNG (root solver only, when random_seed >= 0), then build the
// training net and -- on the root solver -- the test nets.
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << std::endl << param.DebugString();
  param_ = param;
  // Reset the counters before anything can read them:
  // CheckSnapshotWritePermissions() builds its probe file name through
  // SnapshotFilename(), which formats iter_. Previously iter_ was still
  // uninitialized at that point.
  iter_ = 0;
  current_step_ = 0;
  // The check requires average_loss >= 1, so the message says exactly that.
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be at least 1.";
  // Fail early if snapshots could not be written later on.
  CheckSnapshotWritePermissions();
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    Caffe::set_random_seed(param_.random_seed());
  }
  // Scaffolding code
  InitTrainNet();
  if (Caffe::root_solver()) {
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
}
// Builds the training net (net_).
//
// Exactly one of net, net_param, train_net or train_net_param must be set in
// the solver parameters. The chosen definition is loaded into a NetParameter,
// its NetState is resolved, and the Net is constructed from it.
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const string field_names = "net, net_param, train_net, train_net_param";
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;

  // Load the net definition from whichever field is present.
  NetParameter net_param;
  const bool from_train_net_param = param_.has_train_net_param();
  const bool from_train_net_file = param_.has_train_net();
  if (from_train_net_param) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    net_param.CopyFrom(param_.train_net_param());
  } else if (from_train_net_file) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
  }
  if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param());
  }
  if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
  }

  // Resolve the NetState in increasing precedence: the TRAIN phase default,
  // then the state stored in the net definition, then the solver's
  // train_state.
  NetState net_state;
  net_state.set_phase(TRAIN);
  net_state.MergeFrom(net_param.state());
  net_state.MergeFrom(param_.train_state());
  net_param.mutable_state()->CopyFrom(net_state);

  // The root solver builds a stand-alone net; any other solver passes the
  // root solver's net to the Net constructor.
  Net<Dtype>* train_net = Caffe::root_solver()
      ? new Net<Dtype>(net_param)
      : new Net<Dtype>(net_param, root_solver_->net_.get());
  net_.reset(train_net);
}
/* Builds the test nets (test_nets_).
 *
 * Test nets come from test_net_param entries, test_net files, and -- for any
 * test_iter entries left over -- from the generic net/net_param definition.
 * Each definition gets its NetState resolved before the Net is constructed. */
template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
CHECK(Caffe::root_solver());
// A "generic" net is one given through net or net_param; at most one of the
// two may be set.
const bool has_net_param = param_.has_net_param();
const bool has_net_file = param_.has_net();
const int num_generic_nets = has_net_param + has_net_file;
CHECK_LE(num_generic_nets, 1)
<< "Both net_param and net_file may not be specified.";
// Explicitly listed test nets (inline definitions plus files).
const int num_test_net_params = param_.test_net_param_size();
const int num_test_net_files = param_.test_net_size();
const int num_test_nets = num_test_net_params + num_test_net_files;
// With a generic net there may be more test_iter entries than explicit test
// nets; without one the counts must match exactly.
if (num_generic_nets) {
CHECK_GE(param_.test_iter_size(), num_test_nets)
<< "test_iter must be specified for each test network.";
} else {
CHECK_EQ(param_.test_iter_size(), num_test_nets)
<< "test_iter must be specified for each test network.";
}
// If we have a generic net (specified by net or net_param, rather than
// test_net or test_net_param), we may have an unlimited number of actual
// test networks -- the actual number is given by the number of remaining
// test_iters after any test nets specified by test_net_param and/or test_net
// are evaluated.
const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
const int num_test_net_instances = num_test_nets + num_generic_net_instances;
// test_state is optional, but when given there must be one per test net.
if (param_.test_state_size()) {
CHECK_EQ(param_.test_state_size(), num_test_net_instances)
<< "test_state must be unspecified or specified once per test net.";
}
// Having test nets requires a positive test_interval.
if (num_test_net_instances) {
CHECK_GT(param_.test_interval(), 0);
}
// Collect, per test net instance, a description of where it came from
// (used in the log line below) and its NetParameter. test_net_id is the
// running index across all three sources.
int test_net_id = 0;
vector<string> sources(num_test_net_instances);
vector<NetParameter> net_params(num_test_net_instances);
// 1) Inline test_net_param definitions.
for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
sources[test_net_id] = "test_net_param";
net_params[test_net_id].CopyFrom(param_.test_net_param(i));
}
// 2) test_net files.
for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
sources[test_net_id] = "test_net file: " + param_.test_net(i);
ReadNetParamsFromTextFileOrDie(param_.test_net(i),
&net_params[test_net_id]);
}
// 3) Remaining slots are filled with copies of the generic net definition.
const int remaining_test_nets = param_.test_iter_size() - test_net_id;
if (has_net_param) {
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
sources[test_net_id] = "net_param";
net_params[test_net_id].CopyFrom(param_.net_param());
}
}
if (has_net_file) {
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
sources[test_net_id] = "net file: " + param_.net();
ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
}
}
// Construct the nets.
test_nets_.resize(num_test_net_instances);
for (int i = 0; i < num_test_net_instances; ++i) {
// Set the correct NetState. We start with the solver defaults (lowest
// precedence); then, merge in any NetState specified by the net_param
// itself; finally, merge in any NetState specified by the test_state
// (highest precedence).
NetState net_state;
net_state.set_phase(TEST);
net_state.MergeFrom(net_params[i].state());
if (param_.test_state_size()) {
net_state.MergeFrom(param_.test_state(i));
}
net_params[i].mutable_state()->CopyFrom(net_state);
LOG(INFO)
<< "Creating test net (#" << i << ") specified by " << sources[i];
// NOTE(review): the CHECK at the top of this function already requires the
// root solver, so the else branch below looks unreachable -- confirm.
if (Caffe::root_solver()) {
test_nets_[i].reset(new Net<Dtype>(net_params[i]));
} else {
test_nets_[i].reset(new Net<Dtype>(net_params[i],
root_solver_->test_nets_[i].get()));
}
test_nets_[i]->set_debug_info(param_.debug_info());
}
}
// Runs `iters` training iterations starting at the current iter_. Each
// iteration optionally tests, runs forward/backward iter_size times, updates
// and optionally displays the smoothed loss, applies the parameter update,
// and then handles snapshot/stop requests.
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
// First and one-past-last iteration of this call. iter_ need not be zero on
// entry (Solve() calls Restore() before Step() when given a resume file).
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
// The reported loss is averaged over up to average_loss iterations:
// losses_ holds the individual values, smoothed_loss_ their mean.
int average_loss = this->param_.average_loss();
losses_.clear();
smoothed_loss_ = 0;
// Main training loop.
while (iter_ < stop_iter) {
// zero-init the params
// Clear the parameter gradients left from the previous iteration.
net_->ClearParamDiffs();
// Test every test_interval iterations on the root solver; iteration 0 is
// tested only when test_initialization is set.
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
TestAll();
// Test() sets requested_early_exit_ when a STOP request arrives.
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
}
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
// Whether this iteration should log the loss and the net outputs.
const bool display = param_.display() && iter_ % param_.display() == 0;
net_->set_debug_info(display && param_.debug_info());
// accumulate the loss and gradient
Dtype loss = 0;
// The loss of one iteration is the sum over iter_size ForwardBackward()
// passes divided by iter_size.
// NOTE(review): per the original author's note, this presumably lets
// gradients accumulate over iter_size mini-batches (effective batch size =
// iter_size * the net's batch size) when a larger batch does not fit in GPU
// memory -- confirm against the ApplyUpdate() implementations.
for (int i = 0; i < param_.iter_size(); ++i) {
// ForwardBackward() runs the forward pass (outputs and loss) followed by
// the backward pass (gradients) on the training net.
loss += net_->ForwardBackward();// returns the loss of this pass
}
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
// While fewer than average_loss values are stored the new loss is appended;
// afterwards it overwrites a previously stored value.
/*
* Smoothing note: a single iteration's loss is computed on only part of the
* data, so it can differ from the loss over all samples; averaging it with
* recent losses reduces the oscillation of the reported value.
*/
UpdateSmoothedLoss(loss, start_iter, average_loss);
// Log the smoothed loss and every element of every net output blob.
if (display) {
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< ", loss = " << smoothed_loss_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
const string& output_name =
net_->blob_names()[net_->output_blob_indices()[j]];
const Dtype loss_weight =
net_->blob_loss_weights()[net_->output_blob_indices()[j]];
for (int k = 0; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
// Outputs with a non-zero loss weight also show their weighted value.
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
// Apply the update. ApplyUpdate() is pure virtual in Solver, so the concrete
// solver's implementation runs here.
ApplyUpdate();
// Increment the internal iter_ counter -- its value should always indicate
// the number of times the weights have been updated.
++iter_;// one more completed iteration
// GetRequestedAction() invokes action_request_function_ (installed through
// SetActionFunction()) when one is set, and returns NONE otherwise.
// NOTE(review): per the original author's note, in the caffe executable the
// callback is a signal handler's CheckForSignals -- confirm.
SolverAction::Enum request = GetRequestedAction();
// Save a snapshot if needed.
// Either the periodic snapshot is due on the root solver, or a SNAPSHOT
// action was requested.
if ((param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
// A STOP request sets requested_early_exit_ and ends the loop.
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}
// Main training entry point (root solver only).
//
// Optionally restores a previous state from `resume_file`, trains until
// max_iter through Step(), takes a final snapshot when configured, and --
// unless training was stopped early -- runs a last display and test pass.
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // Initialize to false every time we start solving.
  requested_early_exit_ = false;

  // A non-NULL resume file means training continues from a saved state.
  if (resume_file != NULL) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }

  // Train for whatever is left up to max_iter.
  const int first_iter = iter_;
  const int iters_left = param_.max_iter() - iter_;
  Step(iters_left);

  // Snapshot after training when enabled, unless the final iteration falls
  // on a periodic snapshot boundary.
  const bool on_snapshot_boundary =
      param_.snapshot() && iter_ % param_.snapshot() == 0;
  if (param_.snapshot_after_train() && !on_snapshot_boundary) {
    Snapshot();
  }

  // A STOP request received during Step() ends the run here.
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }

  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively). Only a forward pass
  // is run on the train net: the parameters have already been updated and
  // this pass exists only to report the loss.
  const bool show_final_loss =
      param_.display() && iter_ % param_.display() == 0;
  if (show_final_loss) {
    Dtype final_loss;
    net_->Forward(&final_loss);
    UpdateSmoothedLoss(final_loss, first_iter, this->param_.average_loss());
    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  }

  const bool run_final_test =
      param_.test_interval() && iter_ % param_.test_interval() == 0;
  if (run_final_test) {
    TestAll();
  }
  LOG(INFO) << "Optimization Done.";
}
// Runs Test() on each test net in order, stopping as soon as an early exit
// has been requested.
template <typename Dtype>
void Solver<Dtype>::TestAll() {
  int net_index = 0;
  while (net_index < test_nets_.size() && !requested_early_exit_) {
    Test(net_index);
    ++net_index;
  }
}
/* Tests one net: runs test_iter(test_net_id) forward passes on it, sums every
 * output element across the passes, and logs the mean of each output (and
 * the mean loss when test_compute_loss is set). Pending SNAPSHOT/STOP
 * requests are handled before every pass. */
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
CHECK(Caffe::root_solver());
LOG(INFO) << "Iteration " << iter_
<< ", Testing net (#" << test_net_id << ")";
// Have the test net share the trained layers of the training net.
CHECK_NOTNULL(test_nets_[test_net_id].get())->
ShareTrainedLayersWith(net_.get());
// test_score[i] is the running sum of one output element;
// test_score_output_id[i] is the index of the output blob it belongs to.
vector<Dtype> test_score;
vector<int> test_score_output_id;
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
SolverAction::Enum request = GetRequestedAction();/* poll the requested action */
// Check to see if stoppage of testing/training has been requested.
while (request != SolverAction::NONE) {
if (SolverAction::SNAPSHOT == request) { /* SNAPSHOT: save a snapshot now */
Snapshot();
} else if (SolverAction::STOP == request) { /* STOP: remember to exit early */
requested_early_exit_ = true;
}
request = GetRequestedAction(); /* poll again until NONE is returned */
}
if (requested_early_exit_) {
// break out of test loop.
break;
}
Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(&iter_loss); /* forward pass on the test net */
if (param_.test_compute_loss()) {
loss += iter_loss; /* accumulate the loss for the mean computed below */
}
// The first pass creates one entry per output element; later passes add to
// the entries in the same order.
if (i == 0) {
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
for (int k = 0; k < result[j]->count(); ++k) {
test_score.push_back(result_vec[k]);
test_score_output_id.push_back(j); /* remember the owning output blob */
}
}
} else {
int idx = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
for (int k = 0; k < result[j]->count(); ++k) {
test_score[idx++] += result_vec[k];
}
}
}
}
/* Report the results, unless the test was interrupted. */
if (requested_early_exit_) {
LOG(INFO) << "Test interrupted.";
return;
}
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
LOG(INFO) << "Test loss: " << loss;
}
// Log the mean of every output element; outputs with a non-zero loss weight
// also show their weighted value.
for (int i = 0; i < test_score.size(); ++i) {
const int output_blob_index =
test_net->output_blob_indices()[test_score_output_id[i]];
const string& output_name = test_net->blob_names()[output_blob_index];
const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
ostringstream loss_msg_stream;
const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * mean_score << " loss)";
}
LOG(INFO) << " Test net output #" << i << ": " << output_name << " = "
<< mean_score << loss_msg_stream.str();
}
}
// Saves a snapshot (root solver only): writes the net in the configured
// snapshot format, then passes the resulting model file name to the concrete
// solver's SnapshotSolverState().
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  CHECK(Caffe::root_solver());
  string model_filename;
  if (param_.snapshot_format() ==
      caffe::SolverParameter_SnapshotFormat_BINARYPROTO) {
    // Binary protocol buffer file.
    model_filename = SnapshotToBinaryProto();
  } else if (param_.snapshot_format() ==
             caffe::SolverParameter_SnapshotFormat_HDF5) {
    // HDF5 file.
    model_filename = SnapshotToHDF5();
  } else {
    LOG(FATAL) << "Unsupported snapshot format.";
  }
  SnapshotSolverState(model_filename);
}
// Verifies that snapshot files can be written. Only runs on the root solver
// and only when periodic snapshots are enabled: it requires snapshot_prefix
// to be set, then creates and removes a probe file named from that prefix.
template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {
  if (!Caffe::root_solver() || !param_.snapshot()) {
    return;
  }
  CHECK(param_.has_snapshot_prefix())
      << "In solver params, snapshot is specified but snapshot_prefix is not";
  const string probe_filename = SnapshotFilename(".tempfile");
  std::ofstream probe_ofs(probe_filename.c_str());
  if (!probe_ofs.good()) {
    LOG(FATAL) << "Cannot write to snapshot prefix '"
        << param_.snapshot_prefix() << "'. Make sure "
        << "that the directory exists and is writeable.";
  } else {
    // The probe could be opened: clean it up again.
    probe_ofs.close();
    std::remove(probe_filename.c_str());
  }
}
// Builds a snapshot file name of the form
// <snapshot_prefix>_iter_<iter_><extension>.
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
  string filename = param_.snapshot_prefix();
  filename += "_iter_";
  filename += caffe::format_int(iter_);
  filename += extension;
  return filename;
}
// Serializes the training net into a NetParameter (including diffs when
// snapshot_diff is set), writes it to a binary ".caffemodel" file, and
// returns that file's name.
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
  const string snapshot_path = SnapshotFilename(".caffemodel");
  LOG(INFO) << "Snapshotting to binary proto file " << snapshot_path;
  // Collect the net into a proto message, then write the message to disk.
  NetParameter serialized_net;
  net_->ToProto(&serialized_net, param_.snapshot_diff());
  WriteProtoToBinaryFile(serialized_net, snapshot_path);
  return snapshot_path;
}
// Writes the training net to a ".caffemodel.h5" file through Net::ToHDF5
// (including diffs when snapshot_diff is set) and returns the file name.
template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
  const string snapshot_path = SnapshotFilename(".caffemodel.h5");
  LOG(INFO) << "Snapshotting to HDF5 file " << snapshot_path;
  net_->ToHDF5(snapshot_path, param_.snapshot_diff());
  return snapshot_path;
}
// Restores the solver state from `state_file` (root solver only). The file
// extension picks the hook: names ending in ".h5" go to
// RestoreSolverStateFromHDF5(), everything else to
// RestoreSolverStateFromBinaryProto(). Both hooks are implemented by the
// concrete solver.
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  CHECK(Caffe::root_solver());
  const string state_filename(state_file);
  const string hdf5_suffix(".h5");
  const bool is_hdf5 =
      state_filename.size() >= hdf5_suffix.size() &&
      state_filename.compare(state_filename.size() - hdf5_suffix.size(),
                             hdf5_suffix.size(), hdf5_suffix) == 0;
  if (is_hdf5) {
    RestoreSolverStateFromHDF5(state_filename);
  } else {
    RestoreSolverStateFromBinaryProto(state_filename);
  }
}
// Folds `loss` into the running average smoothed_loss_.
//
// While fewer than `average_loss` values are stored, the new loss is appended
// and the mean is recomputed over the larger set. Once the window is full,
// the loss replaces the entry at (iter_ - start_iter) % average_loss and the
// mean is adjusted by the difference.
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
                                       int average_loss) {
  if (losses_.size() >= average_loss) {
    // Window full: swap the stored value in this slot for the new one.
    const int slot = (iter_ - start_iter) % average_loss;
    smoothed_loss_ += (loss - losses_[slot]) / average_loss;
    losses_[slot] = loss;
    return;
  }
  // Window still filling up: extend it and update the mean incrementally.
  losses_.push_back(loss);
  const int count = losses_.size();
  smoothed_loss_ = (smoothed_loss_ * (count - 1) + loss) / count;
}
INSTANTIATE_CLASS(Solver);
} // namespace caffe