Mnist-Datensatz als Beispiel
1. Trainieren Sie direkt am gesamten Datensatz
Datendownload und Vorverarbeitung
trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset_train = torchvision.datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
dataset_test = torchvision.datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
Dann können Sie es in den Dataloader einfügen,
trainDataLoader = torch.utils.data.DataLoader(dataset=trainData, batch_size=batch_size, shuffle=True) # 批量读取并打乱
testDataLoader = torch.utils.data.DataLoader(dataset=testData, batch_size=batch_size)
Lesen Sie Daten während des Trainings iterativ
for epoch in range(1, epochs + 1):
processBar = tqdm(trainDataLoader, unit='step')
model.train(True)
train_loss, train_correct = 0, 0
for step, (train_imgs, labels) in enumerate(processBar):
if torch.cuda.is_available(): # GPU可用
train_imgs = train_imgs.cuda()
labels = labels.cuda()
model.zero_grad() # 梯度清零
outputs = model(train_imgs) # 输入训练集
loss = criterion(outputs, labels) # 计算损失函数
predictions = torch.argmax(outputs, dim=1) # 得到预测值
correct = torch.sum(predictions == labels)
accuracy = correct / labels.shape[0] # 计算这一批次的正确率
loss.backward() # 反向传播
optimizer.step() # 更新优化器参数
processBar.set_description("[%d/%d] Loss: %.4f, Acc: %.4f" % # 可视化训练进度条设置
(epoch, epochs, loss.item(), accuracy.item()))
2. Nehmen Sie einige Daten aus dem Datensatz für das Training
Trainingsdaten mithilfe des Datensatzindex extrahieren
Daten mit der Seriennummer data_idx abrufen
dataset_train[data_idx][0] #取图像数据(image)
dataset_train[data_idx][1] #取对应的标签(label)
Um einige Daten als Trainingssatz abzutasten, können Sie daher den folgenden Code verwenden
sample_index = [i for i in range(500)] #假设取前500个训练数据
X_train = []
y_train = []
for i in sample_index:
X = dataset_train[i][0]
X_train.append(X)
y = dataset_train[i][1]
y_train.append(y)
sampled_train_data = [(X, y) for X, y in zip(X_train, y_train)] #包装为数据对
trainDataLoader = torch.utils.data.DataLoader(sampled_train_data, batch_size=16, shuffle=True)
Bringen Sie trainDataloader einfach in den Trainingsprozess ein.
Verweise
[1] Erste Schritte mit PyTorch – Implementierung der MNIST-Klassifizierung