While working through the "model blending" section of 《深度学习之PyTorch实战计算机视觉》, I ran into a question: how should the enumerate() in
for batch, data in enumerate(dataloader[phase], 1):
be understood? Let's start with the documentation for enumerate():
enumerate(iterable, start=0)
"""Return an enumerate object.
iterable must be a sequence, an iterator, or some other object which supports iteration.
The __next__() method of the iterator returned by enumerate() returns a tuple containing a count (from start which defaults to 0) and the values obtained from iterating over iterable."""
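Note the phrase "some other object which supports iteration": the argument does not need to be a sequence. A quick illustration with a generator (my own toy example, not from the book):
def stream():
    # A generator is an iterator, so enumerate() accepts it directly
    yield "a"
    yield "b"
    yield "c"

for idx, item in enumerate(stream(), start=1):
    print(idx, item)
# 1 a
# 2 b
# 3 c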
So enumerate() returns two values per iteration: an index and the corresponding item. It takes two parameters:
- the first is an iterable object
- the second is the starting count, an int (optional, defaulting to 0)
seasons = ['Spring', 'Summer', 'Fall', 'Winter']

for idx, data in enumerate(seasons):
    print("idx: {}, data: {} ".format(idx, data), end="")
# Res: idx: 0, data: Spring idx: 1, data: Summer idx: 2, data: Fall idx: 3, data: Winter
print("")

for idx, data in enumerate(seasons, 0):
    print("idx: {}, data: {} ".format(idx, data), end="")
# Res: idx: 0, data: Spring idx: 1, data: Summer idx: 2, data: Fall idx: 3, data: Winter
print("")

for idx, data in enumerate(seasons, start=1):
    print("idx: {}, data: {} ".format(idx, data), end="")
# Res: idx: 1, data: Spring idx: 2, data: Summer idx: 3, data: Fall idx: 4, data: Winter
print("")

for idx, data in enumerate(seasons, start=5):
    print("idx: {}, data: {} ".format(idx, data), end="")
# Res: idx: 5, data: Spring idx: 6, data: Summer idx: 7, data: Fall idx: 8, data: Winter
Combining the results above, we can observe a pattern:
- by default, the start parameter of enumerate() is 0, i.e. the first count it returns starts from 0; this parameter only shifts the returned index and never affects the data itself (a sketch of this behavior follows below)
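To make that rule concrete, here is a minimal reimplementation of enumerate() as a generator (an illustrative sketch, not CPython's actual implementation):
def my_enumerate(iterable, start=0):
    # Count alongside the items; `start` only shifts the count,
    # the items themselves pass through untouched.
    count = start
    for item in iterable:
        yield count, item
        count += 1

for idx, data in my_enumerate(['Spring', 'Summer'], start=1):
    print(idx, data)
# 1 Spring
# 2 Summer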
This also explains why the PyTorch training code typically uses start=1 rather than 0.
Let's look at the relevant part of the training code.
for batch, data in enumerate(dataloader[phase], 1):
    x, y = data
    if torch.cuda.is_available():
        x, y = Variable(x.cuda()), Variable(y.cuda())
    else:
        x, y = Variable(x), Variable(y)
    # Forward pass through both models, then blend their outputs
    y_pred_1 = model_1(x)
    y_pred_2 = model_2(x)
    blending_y_pred = y_pred_1 * weight_1 + y_pred_2 * weight_2
    pred_1 = torch.max(y_pred_1.data, 1)[1]
    pred_2 = torch.max(y_pred_2.data, 1)[1]
    blending_pred = torch.max(blending_y_pred.data, 1)[1]
    # Zero the gradients
    optimizer_1.zero_grad()
    optimizer_2.zero_grad()
    # Compute the losses
    loss_1 = loss_Fn_1(y_pred_1, y)
    loss_2 = loss_Fn_2(y_pred_2, y)
    # Check whether we are training or validating:
    # in training, run backpropagation and update the weights;
    # in validation, skip backpropagation and leave the weights unchanged.
    if phase == "train":
        # Backpropagation
        loss_1.backward()
        loss_2.backward()
        # Weight update
        optimizer_1.step()
        optimizer_2.step()
    running_loss_1 += loss_1.item()
    running_corrects_1 += torch.sum(pred_1 == y.data)
    running_loss_2 += loss_2.item()
    running_corrects_2 += torch.sum(pred_2 == y.data)
    blending_running_corrects += torch.sum(blending_pred == y.data)
    # 16 is the batch size; `batch` counts batches starting from 1
    if batch % 500 == 0 and phase == "train":
        print("Batch {}:\n"
              "--------------------------------------------------------------------\n"
              "Model_1 Train Loss:{:.4f}, Model_1 Train Acc:{:.4f}%\n"
              "Model_2 Train Loss:{:.4f}, Model_2 Train Acc:{:.4f}%\n"
              "--------------------------------------------------------------------\n"
              "Blending_Model Acc:{:.4f}%".format(batch,
                                                  running_loss_1 / batch,
                                                  100 * running_corrects_1 / (16 * batch),
                                                  running_loss_2 / batch,
                                                  100 * running_corrects_2 / (16 * batch),
                                                  100 * blending_running_corrects / (16 * batch)))

epoch_loss_1 = running_loss_1 * 16 / len(image_datasets[phase])
epoch_acc_1 = 100 * running_corrects_1 / len(image_datasets[phase])
epoch_loss_2 = running_loss_2 * 16 / len(image_datasets[phase])
epoch_acc_2 = 100 * running_corrects_2 / len(image_datasets[phase])
epoch_blending_acc = 100 * blending_running_corrects / len(image_datasets[phase])
print("Model_1 Loss:{:.4f}, Model_1 Acc:{:.4f}%\n"
      "Model_2 Loss:{:.4f}, Model_2 Acc:{:.4f}%\n"
      "Blending_Model Acc:{:.4f}%".format(epoch_loss_1,
                                          epoch_acc_1,
                                          epoch_loss_2,
                                          epoch_acc_2,
                                          epoch_blending_acc))
time_end = time.time()
print("Total Time is:{}".format(time_end - time_start))
Reading the code, we see that for batch, data in enumerate(dataloader[phase], 1): assigns the returned count to batch, which is then used when printing the training statistics, i.e. in if batch % 500 == 0 and phase == "train":. Passing start=1 makes batch count from 1, so batch >= 1 always holds. This matters: if the count started from 0, the very first iteration would already satisfy 0 % 500 == 0, and running_loss_1 / batch would raise a ZeroDivisionError.
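A toy reproduction of the point, with hypothetical loss values and the modulus shrunk from 500 to 2 so it triggers quickly:
losses = [0.7, 0.6, 0.5, 0.4]   # stand-ins for per-batch loss values

running_loss = 0.0
for batch, loss in enumerate(losses, start=1):   # start=1, as in the training code
    running_loss += loss
    if batch % 2 == 0:
        print("batch {}: avg loss {:.4f}".format(batch, running_loss / batch))
# batch 2: avg loss 0.6500
# batch 4: avg loss 0.5500

With start=0, the first iteration would hit 0 % 2 == 0 and running_loss / batch would divide by zero.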