[PyTorch][Classification][Model Fusion] Why for batch, data in enumerate(dataloader[phase], 1) starts from 1

While working through the "model fusion" part of the book 《深度学习之PyTorch实战计算机视觉》 (Deep Learning with PyTorch for Computer Vision), I ran into a question.

How should the enumerate() function in for batch, data in enumerate(dataloader[phase], 1): be understood?

Looking at the documentation for enumerate():

enumerate(iterable, start=0)
"""Return an enumerate object. 
iterable must be a sequence, an iterator, or some other object which supports iteration. 
The __next__() method of the iterator returned by enumerate() returns a tuple containing a count (from start which defaults to 0) and the values obtained from iterating over iterable."""
  • enumerate() returns two values on each iteration: a count (the index) and the corresponding item from the iterable
  • enumerate() takes two arguments:
    1. the iterable to be iterated over
    2. the starting value of the count, an int (optional, defaults to 0)
seasons = ['Spring', 'Summer', 'Fall', 'Winter']

for idx, data in list(enumerate(seasons)):
    print("idx: {}, data: {}    ".format(idx, data), end="")
    # Res: idx: 0, data: Spring    idx: 1, data: Summer    idx: 2, data: Fall    idx: 3, data: Winter
print("")

for idx, data in list(enumerate(seasons, 0)):
    print("idx: {}, data: {}    ".format(idx, data), end="")
    # Res: idx: 0, data: Spring    idx: 1, data: Summer    idx: 2, data: Fall    idx: 3, data: Winter
print("")

for idx, data in list(enumerate(seasons, start=1)):
    print("idx: {}, data: {}    ".format(idx, data), end="")
    # Res: idx: 1, data: Spring    idx: 2, data: Summer    idx: 3, data: Fall    idx: 4, data: Winter
print("")

for idx, data in list(enumerate(seasons, start=5)):
    print("idx: {}, data: {}    ".format(idx, data), end="")
    # Res: idx: 5, data: Spring    idx: 6, data: Summer    idx: 7, data: Fall    idx: 8, data: Winter

Putting the results above together, we can observe a pattern:

  • By default, enumerate()'s start parameter is 0, so the first count returned is 0. The start value only shifts the returned count; it has no effect on the data itself.

This explains why, in the PyTorch training loop below, start is usually set to 1 rather than left at the default 0.
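Note that enumerate() works on any iterable, not just lists — including a torch DataLoader, which yields one (inputs, labels) batch per iteration. A minimal sketch with a made-up generator standing in for a real DataLoader:

```python
def fake_dataloader():
    # stands in for a torch DataLoader: yields one (inputs, labels) batch at a time
    yield ([1, 2], [0, 1])
    yield ([3, 4], [1, 0])

# with start=1, the batch counter runs 1, 2, ... over the yielded batches
for batch, data in enumerate(fake_dataloader(), 1):
    x, y = data
    print(batch, x, y)
    # → 1 [1, 2] [0, 1]
    # → 2 [3, 4] [1, 0]
```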

Let's look at the relevant part of the training code.

for batch, data in enumerate(dataloader[phase], 1):
    x, y = data
    # Variable has been a no-op wrapper since PyTorch 0.4; kept here to match the book
    if torch.cuda.is_available():
        x, y = Variable(x.cuda()), Variable(y.cuda())
    else:
        x, y = Variable(x), Variable(y)

    # forward pass through both models, then blend their outputs
    y_pred_1 = model_1(x)
    y_pred_2 = model_2(x)
    blending_y_pred = y_pred_1 * weight_1 + y_pred_2 * weight_2

    pred_1 = torch.max(y_pred_1.data, 1)[1]
    pred_2 = torch.max(y_pred_2.data, 1)[1]
    blending_pred = torch.max(blending_y_pred.data, 1)[1]

    # zero the gradients of both optimizers
    optimizer_1.zero_grad()
    optimizer_2.zero_grad()

    # compute each model's loss
    loss_1 = loss_Fn_1(y_pred_1, y)
    loss_2 = loss_Fn_2(y_pred_2, y)

    """
    先判断是在训练还是在验证: 
        如果在训练则开始进行计算反向传播, 并更新梯度
        如果在验证则开始不进行计算反向传播, 不更新梯度
    """
    if phase == "train":
        # backpropagation
        loss_1.backward()
        loss_2.backward()

        # weight update
        optimizer_1.step()
        optimizer_2.step()

    running_loss_1 += loss_1.item()
    running_corrects_1 += torch.sum(pred_1 == y.data)
    running_loss_2 += loss_2.item()
    running_corrects_2 += torch.sum(pred_2 == y.data)
    blending_running_corrects += torch.sum(blending_pred == y.data)

    # log intermediate stats every 500 training batches (16 is the batch size)
    if batch % 500 == 0 and phase == "train":
        print("Batch {}:\n "
              "--------------------------------------------------------------------\n"
              "Model_1 Train Loss:{:.4f}, "
              "Model_1 Train Acc:{:.4f}\n"
              "Model_2 Train Loss:{:.4f}, "
              "Model_2 Train Acc:{:.4f}\n "
              "--------------------------------------------------------------------\n"
              "Blending_Model Acc:{:.4f}%".format(batch,
                                                 running_loss_1 / batch,
                                                 100 * running_corrects_1 / (16 * batch),
                                                 running_loss_2 / batch,
                                                 100 * running_corrects_2 / (16 * batch),
                                                 100 * blending_running_corrects / (16 * batch)
                                                 ))

# epoch-level stats: running_loss_* sums per-batch mean losses, so scale by the
# batch size (16) and divide by the dataset size
epoch_loss_1 = running_loss_1 * 16 / len(image_datasets[phase])
epoch_acc_1 = 100 * running_corrects_1 / len(image_datasets[phase])
epoch_loss_2 = running_loss_2 * 16 / len(image_datasets[phase])
epoch_acc_2 = 100 * running_corrects_2 / len(image_datasets[phase])
epoch_blending_acc = 100 * blending_running_corrects / len(image_datasets[phase])

print("Model_1 Loss:{:.4f}, Model_1 Acc:{:.4f}%\n "
      "Model_2 Loss:{:.4f}, Model_2 Acc:{:.4f}%\n "
      "Blending_Model Acc:{:.4f}%".format(epoch_loss_1,
                                          epoch_acc_1,
                                          epoch_loss_2,
                                          epoch_acc_2,
                                          epoch_blending_acc
                                          ))

time_end = time.time()
print("Total Time is:{}".format(time_end - time_start))

Reading the code, we can see that for batch, data in enumerate(dataloader[phase], 1): assigns the returned count to batch, and batch is used later when printing the training statistics.

In if batch % 500 == 0 and phase == "train":, we use enumerate(dataloader[phase], start=1) so that batch starts at 1, i.e. batch ≥ 1 always holds. With the default start=0, the very first iteration would already satisfy 0 % 500 == 0, and the running averages such as running_loss_1 / batch would divide by zero.
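To see that failure mode concretely, here is a toy version of the logging condition (500 shrunk to 2, made-up loss values) under both start values:

```python
losses = [0.9, 0.7, 0.5, 0.4]  # stand-in for per-batch loss values

# start=1: the condition batch % 2 == 0 first fires at batch 2, a safe divisor
running = 0.0
for batch, loss in enumerate(losses, 1):
    running += loss
    if batch % 2 == 0:
        print(batch, round(running / batch, 4))  # average loss so far

# start=0 (the default): batch is 0 on the first iteration, so 0 % 2 == 0
# is True and running / batch raises ZeroDivisionError immediately
running = 0.0
try:
    for batch, loss in enumerate(losses):
        running += loss
        if batch % 2 == 0:
            print(round(running / batch, 4))
except ZeroDivisionError:
    print("ZeroDivisionError on the first batch")
```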


Reposted from blog.csdn.net/weixin_44878336/article/details/117666416