Caffe可视化MNIST错误识别样本

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/tianrolin/article/details/53542050

当基于LeNet网络的MNIST手写库训练完毕后,测试样本精度能达到99%以上。但是那错误的不到百分之一的样本是什么样子的呢?我们怎么才能把这些识别错误的样本可视化出来呢?

1.将测试错误样本打印出来

当运行测试时,最后的输出层为AccuracyLayer层。AccuracyLayer对前一层全连接层ip2的10个神经元输出结果进行排序,然后将最大值所对应的神经元序号与标签label进行比较,相等则判定预测正确;否则判定预测错误。所以,我们在这里添加输出判定错误的代码

void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  ...
  // check if true label is in top k predictions
  for (int k = 0; k < top_k_; k++) {
    if (bottom_data_vector[k].second == label_value) {
      // 预测正确
      ...
    }
    else
    {
      // 预测错误
      // index为batch中的图片序号(0~99),label为标签值,output为预测值
      LOG(INFO) << "index:" << i << " label:" << label_value << " output:" << bottom_data_vector[k].second;
    }
  }
}

这样我们就知道在一个batch中哪些图片被预测错误,以及它的标签值和预测值。测试样本总共有10000个,分为100个batch,每个batch大小为100个,所以我们还需要输出每个batch的序号。跳转到Slover::Test()函数中,

void Solver<Dtype>::Test(const int test_net_id) {
  ...
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    // 输出batch序号
    LOG(INFO) << "batch:" << i;
  }
}

2.将日志输出至文件

Caffe框架使用google开源日志glog,用来在控制台输出程序运行日志。其实日志不仅输出在控制台,glog还帮我们在硬盘临时文件中也保存了一份日志。它是如何做到的呢?在Caffe框架初始化时调用了初始化函数GlobalInit(),

void GlobalInit(int* pargc, char*** pargv) {
    ...
    // Google logging.
    ::google::InitGoogleLogging(*(pargv)[0]);
    ...
}

其中的::google::InitGoogleLogging()函数就是定位硬盘文件输出日志,参数*(pargv)[0]为caffe.exe的路径名,但是输出目录为系统的临时目录。你可以在临时目录下找到类似这样随机名字的文件:
caffe.exe.ComputeName.UserName.log.INFO.Data-194350.18356
这个文件可以直接用文本编辑器打开查看,内容和在控制台输出的信息一模一样。

3.用Matlab将错误样本可视化

下面我们来写段Matlab代码,用来读取上面的日志文件,以及将MNIST数据库可视化。

clear;clc;close all;

fid = fopen('caffe.exe.txt');   % 替换为日志文件名
tline = fgetl(fid);

C = [];     % 定义空矩阵用来存放结果

while ischar(tline)
    if ~isempty(strfind(tline, 'batch:'))  % 查找字符串
        indexline = fgetl(fid);
        if ~isempty(strfind(indexline, 'batch:'))
            tline = indexline;
        elseif isempty(strfind(indexline, 'index:'))
            tline = indexline;
        else
            % 在tline中解析batch
            idx1 = strfind(tline, 'batch:');
            batch = str2num(tline(idx1 + 6 : length(tline)));
            % 在indexline中解析index,label,output
            idx2 = strfind(indexline, 'index:');
            idx3 = strfind(indexline, 'label:');
            idx4 = strfind(indexline, 'output:');
            index = str2num(indexline(idx2 + 6 : idx3 - 2));
            label = str2num(indexline(idx3 + 6 : idx4 - 2));
            output = str2num(indexline(idx4 + 7 : length(indexline)));
            % 添加到数组中
            C = [C; batch, index, label, output];
        end
    else
        tline = fgetl(fid);
    end
end

fclose(fid);

% 可视化部分
image_file_name = 't10k-images.idx3-ubyte';
fid = fopen(image_file_name);
images_data = fread(fid, 'uint8');
fclose(fid);

images_data = images_data(17:end);
image_buffer = zeros(28, 28);

for k = 1:1:size(C,1)
    figure(size(C,1));
    index = C(k,1) * 100 + C(k,2);
    image_buffer = reshape(images_data((index) * 28 * 28 + 1 : (index + 1) * 28 * 28), 28, 28);
    subplot(10, 10, k);
    imshow(uint8(image_buffer)');    % 转置
    title(sprintf('%d->%d', C(k,3), C(k,4)));   % label -> output
end

4.可视化结果

结果如下所示,其中有些图片网络未能正确识别,还有些对于人眼来说都是模棱两可的,有点太难为机器了。。。

MNIST

猜你喜欢

转载自blog.csdn.net/tianrolin/article/details/53542050