Using TensorBoard to visualize the training process in PyTorch

>>>This article follows the official PyTorch tutorial and refines it further<<<

  • ---- The official PyTorch tutorial on using TensorBoard is linked below. If you are interested, you can refer to it as well.

[Image: link to the official PyTorch TensorBoard tutorial]

  • ---- It is not always obvious which TensorBoard features are worth using. Personally, I mainly use four of them:

> ① Save the network structure diagram: the GRAPHS tab of TensorBoard shows a structure diagram of the model, in which each module of the whole model can be seen clearly.

> ② Save the changes in training loss, validation-set accuracy (acc) and learning rate in the SCALARS tab of TensorBoard.

> ③ Check the distribution of the weight values of each layer in the HISTOGRAMS tab of TensorBoard.

> ④ Save prediction information for some images: the IMAGES tab of TensorBoard stores the prediction results at each step for a given set of images.

  • >>>The network model used in this article is ResNet. I won't go into its principles or how to build it, so let's move on to the main topic<<<

First, create a folder inside the project. The images saved in it are predicted during training, and the results are added to TensorBoard. You also need to put a label.txt file in that folder giving the label of each image, one "file_name class_name" pair per line, as shown in the following figure:

[Figure: the image folder and its label.txt file]
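As a rough illustration of the label.txt format, the sketch below writes a small file and parses it back. The file names and the class name "tulips" are made up for the demo, and read_label_file is a hypothetical helper, not part of the article's code. It also splits on any whitespace, which is slightly more lenient than splitting on single spaces.

```python
import os
import tempfile


def read_label_file(label_path):
    """Parse lines of the form '<file_name> <class_name>' into (file_name, class_name) pairs."""
    pairs = []
    with open(label_path, "r") as rd:
        for line in rd:
            parts = line.split()  # split on any whitespace, dropping empty fields
            if not parts:
                continue  # skip blank lines
            if len(parts) != 2:
                raise ValueError("expected 'file_name class_name', got: {!r}".format(line))
            pairs.append((parts[0], parts[1]))
    return pairs


# Hypothetical folder contents, written to a temporary directory for the demo.
demo_dir = tempfile.mkdtemp()
demo_label_path = os.path.join(demo_dir, "label.txt")
with open(demo_label_path, "w") as f:
    f.write("img1.jpg daisy\n\nimg2.jpg tulips\n")

label_pairs = read_label_file(demo_label_path)
print(label_pairs)
```

Each class name in label.txt has to match a class name the model was trained on, otherwise the entry is skipped when the figure is built.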

-> How to download the ResNet pretrained weights: write import torchvision.models.resnet, then Ctrl+click resnet in that statement to open the source file, from which you can download the weights you want. This experiment does not use pretrained weights, however, because with them acc and loss barely change: the accuracy already reaches 97% in the first epoch of training.

Step-by-step explanation of the TensorBoard statements:

  • Instantiate a SummaryWriter object. Its parameter is the folder in which the TensorBoard files are saved. Once the statement has run, the TensorBoard file is created and saved automatically.

    from torch.utils.tensorboard import SummaryWriter

    tb_writer = SummaryWriter(log_dir="runs/experiment")

  • After instantiating the model, also create an all-zero tensor. Why is it needed? To add the network structure diagram, an input has to be passed through the model in a forward pass, and the diagram is built from the path this tensor takes through the model. The values do not matter, so the tensor only needs to have the same shape as the input images.

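    # instantiate the model
    model = resnet34(num_classes=args.num_classes).to(device)
    # write the model into tensorboard
    # arguments: the shape of the model input tensor and the device used for training
    init_img = torch.zeros((1, 3, 224, 224), device=device)
    # build the network structure graph from the model and the zero tensor
    tb_writer.add_graph(model, init_img)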
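The point that only the shape of the dummy input matters can be checked without TensorBoard. The sketch below is an assumption-laden stand-in: TinyNet-style layers replace ResNet-34, and the five output classes are arbitrary. It passes a zero tensor and a random tensor of the same shape through the model, then shows that a tensor with the wrong channel count is rejected.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real model: 3-channel input, 5 output classes.
tiny_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 5),
)
tiny_model.eval()

zero_input = torch.zeros((1, 3, 224, 224))
random_input = torch.rand((1, 3, 224, 224))

with torch.no_grad():
    zero_out = tiny_model(zero_input)
    random_out = tiny_model(random_input)

# Both inputs follow the same forward path and give the same output shape.
print(tuple(zero_out.shape), tuple(random_out.shape))

# A tensor with a different channel count cannot go through the first conv layer.
try:
    with torch.no_grad():
        tiny_model(torch.zeros((1, 1, 224, 224)))
    wrong_shape_ok = True
except RuntimeError:
    wrong_shape_ok = False
print("1-channel input accepted:", wrong_shape_ok)
```

A dummy tensor that passes this kind of check is the one you would hand to tb_writer.add_graph together with the model.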
  • After each epoch, that is, after the validation code has run, save the current epoch's average training loss, validation-set acc and learning_rate.

  • Note on tb_writer.add_scalar: the first parameter is the tag; the second parameter is the value recorded during training, passed here as a floating-point number rather than a tensor; the third parameter is the current training step.

  • Note on tb_writer.add_figure: it takes the prediction results for the specified images, drawn as a figure, and saves them in TensorBoard. The first parameter is the title of the figure, the second parameter is the figure object, and the third parameter is the current training step.

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)
        # update learning rate
        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)

        # add loss, acc and lr into tensorboard
        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        tags = ["train_loss", "accuracy", "learning_rate"] 
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        # add figure into tensorboard
        fig = plot_class_preds(net=model,
                               images_dir="./plot_img",
                               transform=data_transform["val"],
                               num_plot=5,
                               device=device)
        if fig is not None:
            tb_writer.add_figure("predictions vs. actuals",
                                 figure=fig,
                                 global_step=epoch)

        # add conv1 weights into tensorboard
        tb_writer.add_histogram(tag="conv1",
                                values=model.conv1.weight,
                                global_step=epoch)
        tb_writer.add_histogram(tag="layer1/block0/conv1",
                                values=model.layer1[0].conv1.weight,
                                global_step=epoch)
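The article does not show train_one_epoch, so how mean_loss becomes a float is not visible. One possible way, sketched below under that assumption, is to accumulate per-batch loss tensors with .item() and average them. The helper name mean_loss_as_float and the two loss values are made up for the demo.

```python
import torch


def mean_loss_as_float(batch_losses):
    """Average an iterable of 0-dim loss tensors and return a plain Python float."""
    total = 0.0
    count = 0
    for loss in batch_losses:
        # .item() copies the scalar out of the tensor (and off the GPU, if any).
        total += loss.detach().item()
        count += 1
    return total / max(count, 1)


# Hypothetical per-batch losses standing in for one epoch of training.
demo_losses = [torch.tensor(0.5), torch.tensor(1.5)]
demo_mean_loss = mean_loss_as_float(demo_losses)
print(type(demo_mean_loss).__name__, demo_mean_loss)
```

The resulting float can then be passed as the second argument of tb_writer.add_scalar.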

  • fig is produced by the plot_class_preds() function. Its transform parameter is the image preprocessing used for the validation set, and its fourth parameter, num_plot, is the number of images to display:

import json
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image


def plot_class_preds(net,
                     images_dir: str,
                     transform,
                     num_plot: int = 5,
                     device="cpu"):
    if not os.path.exists(images_dir):
        print("not found {} path, ignore add figure.".format(images_dir))
        return None

    label_path = os.path.join(images_dir, "label.txt")
    if not os.path.exists(label_path):
        print("not found {} file, ignore add figure".format(label_path))
        return None

    # read class_indict
    json_label_path = './class_indices.json'
    assert os.path.exists(json_label_path), "not found {}".format(json_label_path)
    # use a context manager so the file is closed after reading
    with open(json_label_path, 'r') as json_file:
        # {"0": "daisy"}
        flower_class = json.load(json_file)
    # {"daisy": "0"}
    class_indices = dict((v, k) for k, v in flower_class.items())

    # reading label.txt file
    label_info = []
    with open(label_path, "r") as rd:
        for line in rd.readlines():
            line = line.strip()
            if len(line) > 0:
                split_info = [i for i in line.split(" ") if len(i) > 0]
                assert len(split_info) == 2, "label format error, expect file_name and class_name"
                image_name, class_name = split_info
                image_path = os.path.join(images_dir, image_name)
                # skip the entry if the file does not exist
                if not os.path.exists(image_path):
                    print("not found {}, skip.".format(image_path))
                    continue
                # skip the entry if its class is not one of the given classes
                if class_name not in class_indices.keys():
                    print("unrecognized category {}, skip".format(class_name))
                    continue
                label_info.append([image_path, class_name])

    if len(label_info) == 0:
        return None

    # get first num_plot info
    if len(label_info) > num_plot:
        label_info = label_info[:num_plot]

    num_imgs = len(label_info)
    images = []
    labels = []
    for img_path, class_name in label_info:
        # read img
        img = Image.open(img_path).convert("RGB")
        label_index = int(class_indices[class_name])

        # preprocessing
        img = transform(img)
        images.append(img)
        labels.append(label_index)

    # batching images
    images = torch.stack(images, dim=0).to(device)

    # inference
    with torch.no_grad():
        output = net(images)
        probs, preds = torch.max(torch.softmax(output, dim=1), dim=1)
        probs = probs.cpu().numpy()
        preds = preds.cpu().numpy()

    # width, height
    fig = plt.figure(figsize=(num_imgs * 2.5, 3), dpi=100)
    for i in range(num_imgs):
        # 1 row of subplots, num_imgs columns; draw the (i+1)-th subplot
        ax = fig.add_subplot(1, num_imgs, i+1, xticks=[], yticks=[])

        # CHW -> HWC
        npimg = images[i].cpu().numpy().transpose(1, 2, 0)

        # restore the image to its values before normalization
        # mean:[0.485, 0.456, 0.406], std:[0.229, 0.224, 0.225]
        npimg = (npimg * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
        plt.imshow(npimg.astype('uint8'))

        title = "{}, {:.2f}%\n(label: {})".format(
            flower_class[str(preds[i])],  # predict class
            probs[i] * 100,  # predict probability
            flower_class[str(labels[i])]  # true class
        )
        ax.set_title(title, color=("green" if preds[i] == labels[i] else "red"))

    return fig
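The step that undoes the normalization before plotting can be checked in isolation. The sketch below uses the mean and std values from the function above and round-trips a small made-up image through normalization and back. The denormalize helper is hypothetical, and unlike the function above it rounds and clips before casting to uint8, which avoids off-by-one values caused by floating-point truncation.

```python
import numpy as np

# Mean and std taken from the comment in plot_class_preds.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def denormalize(chw):
    """Turn a normalized CHW float array back into an HWC uint8 image."""
    hwc = chw.transpose(1, 2, 0)        # CHW -> HWC
    restored = (hwc * STD + MEAN) * 255  # undo (x / 255 - mean) / std
    return np.clip(np.rint(restored), 0, 255).astype("uint8")


# Made-up 4x4 RGB image with deterministic pixel values.
demo_image = (np.arange(48) * 5).reshape(4, 4, 3).astype("uint8")

# Apply the same normalization the validation transform is assumed to apply.
demo_normalized = ((demo_image / 255.0 - MEAN) / STD).transpose(2, 0, 1)

demo_restored = denormalize(demo_normalized)
print(np.array_equal(demo_restored, demo_image))
```

If the validation transform uses a different mean or std, the same values have to be used here, otherwise the plotted colors will be off.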

Origin blog.csdn.net/qq_42308217/article/details/113761732