当基于LeNet网络的MNIST手写库训练完毕后,测试样本精度能达到99%以上。但是那错误的不到百分之一的样本是什么样子的呢?我们怎么才能把这些识别错误的样本可视化出来呢?
1.将测试错误样本打印出来
当运行测试时,最后的输出层为AccuracyLayer层。AccuracyLayer对前一层全连接层ip2的10个神经元输出结果进行排序,然后将最大值所对应的神经元序号与标签label进行比较,相等则判定预测正确;否则判定预测错误。所以,我们在这里添加输出判定错误的代码
void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
...
// check if true label is in top k predictions
for (int k = 0; k < top_k_; k++) {
if (bottom_data_vector[k].second == label_value) {
// 预测正确
...
}
else
{
// 预测错误
// index为batch中的图片序号(0~99),label为标签值,output为预测值
LOG(INFO) << "index:" << i << " label:" << label_value << " output:" << bottom_data_vector[k].second;
}
}
}
这样我们就知道在一个batch中哪些图片被预测错误,以及它的标签值和预测值。测试样本总共有10000个,分为100个batch,每个batch大小为100个,所以我们还需要输出每个batch的序号。跳转到Slover::Test()函数中,
void Solver<Dtype>::Test(const int test_net_id) {
...
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
// 输出batch序号
LOG(INFO) << "batch:" << i;
}
}
2.将日志输出至文件
Caffe框架使用google开源日志glog,用来在控制台输出程序运行日志。其实日志不仅输出在控制台,glog还帮我们在硬盘临时文件中也保存了一份日志。它是如何做到的呢?在Caffe框架初始化时调用了初始化函数GlobalInit(),
void GlobalInit(int* pargc, char*** pargv) {
...
// Google logging.
::google::InitGoogleLogging(*(pargv)[0]);
...
}
其中的::google::InitGoogleLogging()函数就是定位硬盘文件输出日志,参数*(pargv)[0]为caffe.exe的路径名,但是输出目录为系统的临时目录。你可以在临时目录下找到类似这样随机名字的文件:
caffe.exe.ComputeName.UserName.log.INFO.Data-194350.18356
这个文件可以直接用文本编辑器打开查看,内容和在控制台输出的信息一模一样。
3.用Matlab将错误样本可视化
下面我们来写段Matlab代码,用来读取上面的日志文件,以及将MNIST数据库可视化。
clear;clc;close all;
fid = fopen('caffe.exe.txt'); % 替换为日志文件名
tline = fgetl(fid);
C = []; % 定义空矩阵用来存放结果
while ischar(tline)
if ~isempty(strfind(tline, 'batch:')) % 查找字符串
indexline = fgetl(fid);
if ~isempty(strfind(indexline, 'batch:'))
tline = indexline;
elseif isempty(strfind(indexline, 'index:'))
tline = indexline;
else
% 在tline中解析batch
idx1 = strfind(tline, 'batch:');
batch = str2num(tline(idx1 + 6 : length(tline)));
% 在indexline中解析index,label,output
idx2 = strfind(indexline, 'index:');
idx3 = strfind(indexline, 'label:');
idx4 = strfind(indexline, 'output:');
index = str2num(indexline(idx2 + 6 : idx3 - 2));
label = str2num(indexline(idx3 + 6 : idx4 - 2));
output = str2num(indexline(idx4 + 7 : length(indexline)));
% 添加到数组中
C = [C; batch, index, label, output];
end
else
tline = fgetl(fid);
end
end
fclose(fid);
% 可视化部分
image_file_name = 't10k-images.idx3-ubyte';
fid = fopen(image_file_name);
images_data = fread(fid, 'uint8');
fclose(fid);
images_data = images_data(17:end);
image_buffer = zeros(28, 28);
for k = 1:1:size(C,1)
figure(size(C,1));
index = C(k,1) * 100 + C(k,2);
image_buffer = reshape(images_data((index) * 28 * 28 + 1 : (index + 1) * 28 * 28), 28, 28);
subplot(10, 10, k);
imshow(uint8(image_buffer)'); % 转置
title(sprintf('%d->%d', C(k,3), C(k,4))); % label -> output
end
4.可视化结果
结果如下所示,其中有些图片网络未能正确识别,还有些对于人眼来说都是模棱两可的,有点太难为机器了。。。