3. PyTorch中Tensorboard的使用(训练过程可视化)

  • 安装Tensorboard
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple
  • SummaryWriter

首先定义一个SummaryWriter类的实例

writer = SummaryWriter("logs")

使用 C t r l + Ctrl + 鼠标左键 点击SummaryWriter,可以看到该类详细信息如下:

class SummaryWriter(object):
    """Writes entries directly to event files in the log_dir to be
    consumed by TensorBoard.

    The `SummaryWriter` class provides a high-level API to create an event file
    in a given directory and add summaries and events to it. The class updates the
    file contents asynchronously. This allows a training program to call methods
    to add data to the file directly from the training loop, without slowing down
    training.
    """
    .
    .
    .
    ...

当然,这些并不重要,我们会使用就可以了

  • writer.add_scalar()方法
    同样使用刚才的方法查看详细内容
    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        """Add scalar data to summary.

        Args:
            tag (string): Data identifier
            scalar_value (float or string/blobname): Value to save
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              with seconds after epoch of event

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter()
            x = range(100)
            for i in x:
                writer.add_scalar('y=2x', i * 2, i)
            writer.close()

        Expected result:

        .. image:: _static/img/tensorboard/add_scalar.png
           :scale: 50 %

        """
        

第一个参数可以简单理解为保存图的名称,第二个参数是可以理解为Y轴数据,第三个参数可以理解为X轴数据。

简单写一个demo

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")


for i in range(100):
    # 不改事件名字就会合并到一个里面叠加
    writer.add_scalar("y=2x", 2 * i, i)

writer.close()

在命令行中输入以下命令来打开Tensorboard,–logdir为上面定义的路径(参见SummaryWriter类的定义),–port为端口号,默认为6006端口。

tensorboard --logdir=logs --port=6007

得到效果图:
image.png

  • 加载图片demo
writer.add_image()

具体详见该函数完整定义,传入图片必须为 t o r c h . T e n s o r , n u m p y . a r r a y , s t r i n g / b l o b n a m e torch.Tensor, numpy.array, string/blobname 中的一种。
如果使用PIL中的Image库加载图片,不符合要求,于是使用numpy进行图片的格式转换,转换的图片是HWC形式的数组,所以需要告诉add_image这个函数我们传入的形式是HWC(默认为CHW)。
完整demo如下:

from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np

writer = SummaryWriter("logs")
image_path = "dataset/train/ants_image/0013035.jpg"
img = Image.open(image_path)
# print(type(img))
img_array = np.array(img)
print(type(img_array))
print(img_array.shape)
# 得到(1,1,512)即(高度,宽度,通道)(H,W,C)不符合add_image要求
# 从PIL到numpy,需要在add_image()中指定每一个数字/维表示的含义

writer.add_image("test", img_array, 1, dataformats='HWC') # 转换的图片是HWC形式的数组,所以需要告诉add_image这个函数我们传入的形式是HWC(默认为CHW)。
writer.add_scalar()

writer.close()

(等待后续修改)

发布了13 篇原创文章 · 获赞 0 · 访问量 90

猜你喜欢

转载自blog.csdn.net/qq_35283167/article/details/104639798