论文的绘图(一):使用 Python 的 Matplotlib 库

目录

 

一、数据处理和应用说明

二、源代码及详细注解

三、效果图


一、数据处理和应用说明

我的实验的训练结果的三个文件的格式如下所示:

在处理时,我把三个模型训练的损失值分别提取,放在对应的列表中,使用 matplotlib 库在同一个画布中画出三个模型的三条 Loss-epoch 曲线的对比图。

应用说明:大家可以根据自己文件的数据格式或已整理的实验数据列表作处理,对第二节的源代码改进和使用


二、源代码及详细注解

import csv
import matplotlib.pyplot as plt
# import numpy as np
# 将数据提取,放入列表,用于绘图

# 模型1的损失函数值构成的列表
static_loss_list = []
# 模型2的损失函数值构成的列表
t_loss_list = []
# 模型3的损失函数值构成的列表
ts_loss_list = []

with open('C:/Users/zax/Desktop/curve_plotting/static_train.txt', 'r', encoding='UTF-8') as f:
    reader = csv.reader(f)
    lines = [row for row in reader]
    for item in lines:
        if item[0].startswith("Loss"):
            data = item[0].split(":")[1].strip()[:9]
            data = float("%.2f" % float(data))
            static_loss_list.append(data)
    f.close()

with open('C:/Users/zax/Desktop/curve_plotting/t_train.txt', 'r', encoding='UTF-8') as f:
    reader = csv.reader(f)
    lines = [row for row in reader]
    for item in lines:
        if item[0].startswith("Loss"):
            data = item[0].split(":")[1].strip()[:9]
            data = float("%.2f" % float(data))
            t_loss_list.append(data)
    f.close()

with open('C:/Users/zax/Desktop/curve_plotting/ts_train.txt', 'r', encoding='UTF-8') as f:
    reader = csv.reader(f)
    lines = [row for row in reader]
    for item in lines:
        if item[0].startswith("Loss"):
            data = item[0].split(":")[1].strip()[:9]
            data = float("%.2f" % float(data))
            ts_loss_list.append(data)
    f.close()
# print(ts_loss_list)


# 准备数据:图横纵坐标的数据和图在画布中的位置

# 方法一:使用numpy.arange()生成序列,但是当我们使用浮点参数时,可能会导致精度损失,这可能会导致不可预测的输出

# 我的实验迭代了500次,所以x(横轴)的取值范围为(0,500)
x = range(0, 500) # 在for in range中,默认步长为1,但range(0, 500)并非如此,0的下一个是无限趋于0的数

# 方法二:为了避免由于浮点精度而造成的任何精度损失,numpy.linspace()提供了一个单独的序列生成器,如果知道共需在二维坐标轴中生成几个点,这个方法是首选
# x = np.linspace(0, 500, 5)

# y轴是三个列表的损失数据集,每个列表元素500个
y1 = static_loss_list
y2 = t_loss_list
y3 = ts_loss_list

# 1、创建白色的画布(底板),并设置大小为9英寸×7英寸
plt.figure(figsize=(9, 7), facecolor='white')
# 如果在同一个画布下,画多个图(即子图),要采用subplot函数,其参数的含义:
# 将画布分为3行3列(从左到右,从上到下编号1-9),此次作的图显示在第6号位置
# plt.subplot(3, 3, 6)  # 画布上的一个子图,用这条命令可以创建若干并列子图

# 2、描绘线条的格式,包括颜色、粗细、样式等
plt.plot(x, y1, color="green", linewidth=1, linestyle="--", label="modelA")
plt.plot(x, y2, color="yellow", linewidth=1, linestyle="--", label="modelB")
plt.plot(x, y3, color="red", linewidth=1, linestyle="-", label="modelC")

# 以下三行代码无顺序限制

# legend 里边可以不写参数,默认按照三个实例对象(以上自动创建)的颜色和对应的名称进行描述
# 如果改变数组中元素的顺序,就会出现线条和名称不对应的灾难
plt.legend(['modelA', 'modelB', 'modelC'])
# 贴上x轴和y轴代表的含义
plt.xlabel('Training epoch')
plt.ylabel('Loss')

# 展示图片,也可以不展示
plt.show()

# 3、存储画好的图片到指定位置,dpi用于调整图片的分辨率
plt.savefig("C:/Users/zax/Desktop/curve_plotting/Loss-Epoch.jpg", dpi="72")

三、效果图

猜你喜欢

转载自blog.csdn.net/qq_40506723/article/details/127138529