Plotting loss and accuracy curves with Caffe's built-in tools on Ubuntu + Python 3

Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/betty13006159467/article/details/88665054

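I have recently been fine-tuning a VGG network and wanted to check how well the fine-tuning worked, so I needed to plot the loss and accuracy curves.
Caffe ships with tools for plotting loss and accuracy curves: parse_log.sh, extract_seconds.py and plot_training_log.py.example in the caffe/tools/extra directory.
Because these tools were written for Python 2.7, a few small problems come up under Python 3, so I am recording how I solved them here.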
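To give a feel for what these tools do: parse_log.sh pulls iteration numbers, training loss, learning rate and test accuracy/loss out of the raw Caffe log into column files, which plot_training_log.py.example then reads (its column layout is defined in create_field_index()). Below is a minimal Python sketch of the same idea using regular expressions. It is not the actual parse_log.sh logic; the regex patterns, the helper name extract_curves and the sample log lines are my own assumptions modeled on typical Caffe log output, and may need adjusting for your Caffe version.

```python
import re

# Assumed patterns for typical Caffe solver log lines (not taken from parse_log.sh).
ITER_LOSS = re.compile(r"Iteration (\d+).*?, loss = ([-+\d.eE]+)")
TEST_OUTPUT = re.compile(r"Test net output #\d+: (\S+) = ([-+\d.eE]+)")
TESTING = re.compile(r"Iteration (\d+), Testing net")


def extract_curves(lines):
    """Return ([(iter, train_loss)], [(iter, name, value)]) from raw log lines."""
    train, test = [], []
    test_iter = None  # iteration of the most recent "Testing net" line
    for line in lines:
        m = TESTING.search(line)
        if m:
            test_iter = int(m.group(1))
            continue
        m = ITER_LOSS.search(line)
        if m:
            train.append((int(m.group(1)), float(m.group(2))))
            continue
        m = TEST_OUTPUT.search(line)
        if m and test_iter is not None:
            test.append((test_iter, m.group(1), float(m.group(2))))
    return train, test


# Illustrative log lines (hypothetical values, typical Caffe format).
sample = [
    "I0319 10:00:00.000000  1234 solver.cpp:330] Iteration 0, Testing net (#0)",
    "I0319 10:00:05.000000  1234 solver.cpp:397]     Test net output #0: accuracy = 0.1",
    "I0319 10:00:05.000000  1234 solver.cpp:397]     Test net output #1: loss = 2.3 (* 1 = 2.3 loss)",
    "I0319 10:00:06.000000  1234 solver.cpp:218] Iteration 0 (0 iter/s, 6s/20 iters), loss = 2.31",
]
train, test = extract_curves(sample)
print(train)  # [(0, 2.31)]
print(test)   # [(0, 'accuracy', 0.1), (0, 'loss', 2.3)]
```

Tracking the last "Testing net" iteration is what lets each test output be attached to an iteration number, since the "Test net output" lines themselves do not carry one.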
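1. Record the fine-tuning log
To present the training data properly, you first need to record a log of the fine-tuning run. I simply wrote a script for this, finetune_vgg.sh: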

#!/usr/bin/env sh
# Abort on the first failing command.
set -e

# Run from the caffe root; stdout and stderr (2>&1) both go into the log file.
sudo ./build/tools/caffe train \
	--solver models/finetune_vgg16/solver.prototxt \
	--weights models/VGG/VGG_ILSVRC_16_layers.caffemodel > models/finetune_vgg16/log/vgg16_finetune.log 2>&1

Remember to change the paths to your own. Here I created a finetune_vgg16 folder under caffe/models and, inside it, a log folder to hold the generated log file.
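The `> ... 2>&1` at the end of the script is what makes the log complete: Caffe's glog messages go to stderr, and `2>&1` sends stderr into the same file as stdout. If you would rather launch training from Python, a minimal sketch with the standard subprocess module could look like this. The helper names caffe_train_cmd and run_and_log are hypothetical, and the paths are the same example paths as in finetune_vgg.sh.

```python
import subprocess


def caffe_train_cmd(solver, weights, caffe_bin="./build/tools/caffe"):
    """Build the same `caffe train` invocation used in finetune_vgg.sh."""
    return [caffe_bin, "train", "--solver", solver, "--weights", weights]


def run_and_log(cmd, log_path):
    """Run cmd and write its stdout and stderr to log_path, like `> log 2>&1`."""
    with open(log_path, "w") as log:
        # stderr=subprocess.STDOUT sends stderr to the same file handle as stdout;
        # check=True raises CalledProcessError if the command exits non-zero.
        return subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT, check=True)
```

Run from the caffe root, e.g. `run_and_log(caffe_train_cmd("models/finetune_vgg16/solver.prototxt", "models/VGG/VGG_ILSVRC_16_layers.caffemodel"), "models/finetune_vgg16/log/vgg16_finetune.log")`. Passing the command as a list avoids shell quoting issues with paths.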
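2. Run finetune_vgg.sh from the caffe directory. When fine-tuning finishes, a log file named vgg16_finetune.log appears in the log folder.
3. After fine-tuning, I copied vgg16_finetune.log into the caffe/tools/extra folder and ran the following command from inside tools/extra: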

./plot_training_log.py.example 0 VGG16_finetune.png vgg16_finetune.log

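But it throws an error!
That is because this built-in Caffe tool targets Python 2.7, while my Python is 3.6, and, as we all know, a lot of syntax differs between Python 2 and Python 3!!!
A real trap!
So the next step is to modify plot_training_log.py.example.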
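(1) The print problem
In Python 3, print is a function and requires parentheses, so every print statement in the file has to be rewritten with parentheses.
(2) Change xrange to range; note that there are three places to change.
(3) The dict.keys() problem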

In Python 3, dict.keys() returns a dict_keys object (a view), which cannot be indexed, so markers.keys()[idx] has to become list(markers.keys())[idx]. The change is in the random_marker() function.
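The three changes can be checked in a few lines of Python 3. The dictionary below is only a small stand-in for matplotlib's MarkerStyle.markers, so the snippet runs without matplotlib installed.

```python
import builtins
import random

# (1) print is a function in Python 3, so its arguments go in parentheses.
print('    %d: %s' % (0, 'Test accuracy  vs. Iters'))

# (2) xrange no longer exists; range already yields its values lazily.
assert not hasattr(builtins, 'xrange')
squares = [i * i for i in range(3)]

# (3) dict.keys() returns a dict_keys view, which does not support indexing.
markers = {'o': 'circle', 's': 'square', '^': 'triangle_up'}  # stand-in dict
try:
    markers.keys()[0]
except TypeError as err:
    print('TypeError:', err)

# Converting the view to a list makes indexing work again, as in random_marker().
idx = random.randint(0, len(markers) - 1)
marker = list(markers.keys())[idx]
print(marker)
```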
The modified plot_training_log.py.example looks like this:

#!/usr/bin/env python
import inspect
import os
import random
import sys
import matplotlib.cm as cmx
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import matplotlib.legend as lgd
import matplotlib.markers as mks

def get_log_parsing_script():
    dirname = os.path.dirname(os.path.abspath(inspect.getfile(
        inspect.currentframe())))
    return dirname + '/parse_log.sh'

def get_log_file_suffix():
    return '.log'

def get_chart_type_description_separator():
    return '  vs. '

def is_x_axis_field(field):
    x_axis_fields = ['Iters', 'Seconds']
    return field in x_axis_fields

def create_field_index():
    train_key = 'Train'
    test_key = 'Test'
    field_index = {train_key:{'Iters':0, 'Seconds':1, train_key + ' loss':2,
                              train_key + ' learning rate':3},
                   test_key:{'Iters':0, 'Seconds':1, test_key + ' accuracy':2,
                             test_key + ' loss':3}}
    fields = set()
    for data_file_type in field_index.keys():
        fields = fields.union(set(field_index[data_file_type].keys()))
    fields = list(fields)
    fields.sort()
    return field_index, fields

def get_supported_chart_types():
    field_index, fields = create_field_index()
    num_fields = len(fields)
    supported_chart_types = []
    for i in range(num_fields):
        if not is_x_axis_field(fields[i]):
            for j in range(num_fields):
                if i != j and is_x_axis_field(fields[j]):
                    supported_chart_types.append('%s%s%s' % (
                        fields[i], get_chart_type_description_separator(),
                        fields[j]))
    return supported_chart_types

def get_chart_type_description(chart_type):
    supported_chart_types = get_supported_chart_types()
    chart_type_description = supported_chart_types[chart_type]
    return chart_type_description

def get_data_file_type(chart_type):
    description = get_chart_type_description(chart_type)
    data_file_type = description.split()[0]
    return data_file_type

def get_data_file(chart_type, path_to_log):
    return (os.path.basename(path_to_log) + '.' +
            get_data_file_type(chart_type).lower())

def get_field_descriptions(chart_type):
    description = get_chart_type_description(chart_type).split(
        get_chart_type_description_separator())
    y_axis_field = description[0]
    x_axis_field = description[1]
    return x_axis_field, y_axis_field

def get_field_indices(x_axis_field, y_axis_field):
    data_file_type = get_data_file_type(chart_type)
    fields = create_field_index()[0][data_file_type]
    return fields[x_axis_field], fields[y_axis_field]

def load_data(data_file, field_idx0, field_idx1):
    data = [[], []]
    with open(data_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip blank lines and the '#' header lines written by parse_log.sh.
            if line and line[0] != '#':
                fields = line.split()
                data[0].append(float(fields[field_idx0].strip()))
                data[1].append(float(fields[field_idx1].strip()))
    return data

def random_marker():
    markers = mks.MarkerStyle.markers
    num = len(markers.keys())
    idx = random.randint(0, num - 1)
    # Python 2 original (fails on Python 3): return markers.keys()[idx]
    return list(markers.keys())[idx]

def get_data_label(path_to_log):
    label = path_to_log[path_to_log.rfind('/')+1 : path_to_log.rfind(
        get_log_file_suffix())]
    return label

def get_legend_loc(chart_type):
    x_axis, y_axis = get_field_descriptions(chart_type)
    loc = 'lower right'
    if y_axis.find('accuracy') != -1:
        pass
    if y_axis.find('loss') != -1 or y_axis.find('learning rate') != -1:
        loc = 'upper right'
    return loc

def plot_chart(chart_type, path_to_png, path_to_log_list):
    for path_to_log in path_to_log_list:
        os.system('%s %s' % (get_log_parsing_script(), path_to_log))
        data_file = get_data_file(chart_type, path_to_log)
        x_axis_field, y_axis_field = get_field_descriptions(chart_type)
        x, y = get_field_indices(x_axis_field, y_axis_field)
        data = load_data(data_file, x, y)
        ## TODO: more systematic color cycle for lines
        color = [random.random(), random.random(), random.random()]
        label = get_data_label(path_to_log)
        linewidth = 0.75
        ## If there are too many data points, do not use markers.
##        use_marker = False
        use_marker = True
        if not use_marker:
            plt.plot(data[0], data[1], label = label, color = color,
                     linewidth = linewidth)
        else:
            marker = random_marker()
            plt.plot(data[0], data[1], label = label, color = color,
                     marker = marker, linewidth = linewidth)
    legend_loc = get_legend_loc(chart_type)
    plt.legend(loc = legend_loc, ncol = 1) # adjust ncol to fit the space
    plt.title(get_chart_type_description(chart_type))
    plt.xlabel(x_axis_field)
    plt.ylabel(y_axis_field)
    plt.savefig(path_to_png)
    plt.show()

def print_help():
    print ("""This script mainly serves as the basis of your customizations.
Customization is a must.
You can copy, paste, edit them in whatever way you want.
Be warned that the fields in the training log may change in the future.
You had better check the data files and change the mapping from field name to
 field index in create_field_index before designing your own plots.
Usage:
    ./plot_training_log.py chart_type[0-%s] /where/to/save.png /path/to/first.log ...
Notes:
    1. Supporting multiple logs.
    2. Log file name must end with the lower-cased "%s".
Supported chart types:""" % (len(get_supported_chart_types()) - 1,
                             get_log_file_suffix()))
    supported_chart_types = get_supported_chart_types()
    num = len(supported_chart_types)
    for i in range(num):
        print ('    %d: %s' % (i, supported_chart_types[i]))
    sys.exit()

def is_valid_chart_type(chart_type):
    return chart_type >= 0 and chart_type < len(get_supported_chart_types())

if __name__ == '__main__':
    if len(sys.argv) < 4:
        print_help()
    else:
        chart_type = int(sys.argv[1])
        if not is_valid_chart_type(chart_type):
            print ('%s is not a valid chart type.' % chart_type)
            print_help()
        path_to_png = sys.argv[2]
        if not path_to_png.endswith('.png'):
            print('Path must end with .png: %s' % path_to_png)
            sys.exit()
        path_to_logs = sys.argv[3:]
        for path_to_log in path_to_logs:
            if not os.path.exists(path_to_log):
                print ('Path does not exist: %s' % path_to_log)
                sys.exit()
            if not path_to_log.endswith(get_log_file_suffix()):
                print ('Log file must end in %s.' % get_log_file_suffix())
                print_help()
        ## plot_chart accepts multiple path_to_logs
        plot_chart(chart_type, path_to_png, path_to_logs)

That solved my problem: after these changes, plotting the loss and accuracy curves again worked!
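For reference, the first argument of the plotting command (0 above) is an index into the list built by get_supported_chart_types(): the field names from create_field_index() are sorted alphabetically, and each non-axis field is paired with Iters and Seconds. The standalone sketch below recomputes that list the same way, so you can see which number draws which curve; running the script without arguments also prints this list through print_help().

```python
def supported_chart_types():
    """Recompute the chart-type list the same way the modified script does."""
    # Same field names as create_field_index(), sorted like `fields.sort()`.
    fields = sorted({'Iters', 'Seconds', 'Train loss', 'Train learning rate',
                     'Test accuracy', 'Test loss'})
    x_axis_fields = ['Iters', 'Seconds']
    separator = '  vs. '  # same as get_chart_type_description_separator()
    return [y + separator + x
            for y in fields if y not in x_axis_fields
            for x in fields if x in x_axis_fields]


for i, description in enumerate(supported_chart_types()):
    print('%d: %s' % (i, description))
```

So 0 is Test accuracy vs. Iters, and, for example, `./plot_training_log.py.example 6 train_loss.png vgg16_finetune.log` would plot Train loss vs. Iters.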
