Dibujo en papel (1): uso de la biblioteca Matplotlib de Python

Tabla de contenido

 

1. Tratamiento de datos e instrucciones de aplicación

2. Código fuente y notas detalladas

3. Representación


1. Tratamiento de datos e instrucciones de aplicación

El formato de los tres archivos para los resultados de entrenamiento de mi experimento es el siguiente:

Durante el procesamiento, extraje los valores de pérdida de los tres entrenamientos de modelos y los puse en la lista correspondiente, y usé la biblioteca matplotlib para dibujar un gráfico de comparación de las tres curvas de época de pérdida de los tres modelos en el mismo lienzo.

Nota de aplicación: puede mejorar y utilizar el código fuente en la segunda sección de acuerdo con el formato de datos de sus propios archivos o la lista de datos experimentales que se han ordenado .


2. Código fuente y notas detalladas

import csv
import matplotlib.pyplot as plt
# import numpy as np
# 将数据提取,放入列表,用于绘图

# 模型1的损失函数值构成的列表
static_loss_list = []
# 模型2的损失函数值构成的列表
t_loss_list = []
# 模型3的损失函数值构成的列表
ts_loss_list = []

with open('C:/Users/zax/Desktop/curve_plotting/static_train.txt', 'r', encoding='UTF-8') as f:
    reader = csv.reader(f)
    lines = [row for row in reader]
    for item in lines:
        if item[0].startswith("Loss"):
            data = item[0].split(":")[1].strip()[:9]
            data = float("%.2f" % float(data))
            static_loss_list.append(data)
    f.close()

with open('C:/Users/zax/Desktop/curve_plotting/t_train.txt', 'r', encoding='UTF-8') as f:
    reader = csv.reader(f)
    lines = [row for row in reader]
    for item in lines:
        if item[0].startswith("Loss"):
            data = item[0].split(":")[1].strip()[:9]
            data = float("%.2f" % float(data))
            t_loss_list.append(data)
    f.close()

with open('C:/Users/zax/Desktop/curve_plotting/ts_train.txt', 'r', encoding='UTF-8') as f:
    reader = csv.reader(f)
    lines = [row for row in reader]
    for item in lines:
        if item[0].startswith("Loss"):
            data = item[0].split(":")[1].strip()[:9]
            data = float("%.2f" % float(data))
            ts_loss_list.append(data)
    f.close()
# print(ts_loss_list)


# 准备数据:图横纵坐标的数据和图在画布中的位置

# 方法一:使用numpy.arange()生成序列,但是当我们使用浮点参数时,可能会导致精度损失,这可能会导致不可预测的输出

# 我的实验迭代了500次,所以x(横轴)的取值范围为(0,500)
x = range(0, 500) # 在for in range中,默认步长为1,但range(0, 500)并非如此,0的下一个是无限趋于0的数

# 方法二:为了避免由于浮点精度而造成的任何精度损失,numpy.linspace()提供了一个单独的序列生成器,如果知道共需在二维坐标轴中生成几个点,这个方法是首选
# x = np.linspace(0, 500, 5)

# y轴是三个列表的损失数据集,每个列表元素500个
y1 = static_loss_list
y2 = t_loss_list
y3 = ts_loss_list

# 1、创建白色的画布(底板),并设置大小为9英寸×7英寸
plt.figure(figsize=(9, 7), facecolor='white')
# 如果在同一个画布下,画多个图(即子图),要采用subplot函数,其参数的含义:
# 将画布分为3行3列(从左到右,从上到下编号1-9),此次作的图显示在第6号位置
# plt.subplot(3, 3, 6)  # 画布上的一个子图,用这条命令可以创建若干并列子图

# 2、描绘线条的格式,包括颜色、粗细、样式等
plt.plot(x, y1, color="green", linewidth=1, linestyle="--", label="modelA")
plt.plot(x, y2, color="yellow", linewidth=1, linestyle="--", label="modelB")
plt.plot(x, y3, color="red", linewidth=1, linestyle="-", label="modelC")

# 以下三行代码无顺序限制

# legend 里边可以不写参数,默认按照三个实例对象(以上自动创建)的颜色和对应的名称进行描述
# 如果改变数组中元素的顺序,就会出现线条和名称不对应的灾难
plt.legend(['modelA', 'modelB', 'modelC'])
# 贴上x轴和y轴代表的含义
plt.xlabel('Training epoch')
plt.ylabel('Loss')

# 展示图片,也可以不展示
plt.show()

# 3、存储画好的图片到指定位置,dpi用于调整图片的分辨率
plt.savefig("C:/Users/zax/Desktop/curve_plotting/Loss-Epoch.jpg", dpi="72")

3. Representación

Supongo que te gusta

Origin blog.csdn.net/qq_40506723/article/details/127138529
Recomendado
Clasificación