Ausführliche Erklärung und Beispielanalyse von WeightedRandomSampler

Anwendungsszenarien:

我们常常会遇到--数据集不平衡--问题,举个最简单的例子:猫狗分类,你有10000张狗的图片,却只有1000张猫的图片。这时如果直接利用整个数据集训练就很容易导致网络对‘狗’这个类别过拟合,在猫的识别任务上表现很差。

Lösung: Verwenden Sie den WeightedRandomSampler von Pytorch . Jede Epoche wird entsprechend der Anzahl jeder Kategorie gewichtet und Stichproben jeder Kategorie werden abgetastet.

Beispiel: Jede Epoche verwendet 1.600 Bilder, um das Netzwerk zu trainieren. Darunter ist die Wahrscheinlichkeit, dass Hundeproben erfasst werden, gering, aber die Anzahl ist groß. Katzenproben haben eine hohe Wahrscheinlichkeit, erfasst zu werden, aber die Anzahl ist gering. Das macht die Anzahl der zwei Kategorien in den 1.600 Trainingsbildern. Sie ist fast gleich und jede Epoche wird abgetastet. Wenn das Training ausreicht, werden grundsätzlich alle Bilder trainiert.


Offizielle Einführung in die Verwendung

先看下官网的API
Fügen Sie hier eine Bildbeschreibung ein
Gewicht : eine Liste, jede Zahl misst das Wahrscheinlichkeitsgewicht der Stichprobe, die sich am Index befindet, der abgetastet wird (die Summe muss nicht 1 sein) num_samples: eine ganze Zahl, die die
Zahl angibt, die Sie abtasten möchten, wie z. B. die obige Katze und der Hund Klassifizierung in jeder Epoche in den gesamten Daten. 1600 Bilder zentral abtasten. Ersetzung
: eine Bool-Variable, ob die Abtastung wiederholt werden soll. Wenn „True“ ist, kann ein Bild mehrmals abgetastet werden.
Generator : Nutzlos, mach dir keine Sorgen.


Anwendungsbeispiele:

猫狗分类,10000张狗图片,1000张猫图片。要实现:每个epoch只取1600张图片训练,且不重复采样

Schlüsselcode:

from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

# 自定义的dataset,前10000个是dog,然后是1000个cat,共11000个数据。
train_dataset   = DataGenerator(train_imgs, input_shape=input_shape, train=True)
# 每个类别的样本数量
num_dog   	    = 10000
num_cat         = 1000
# weight列表
train_weights   = []
train_weights.extend([1/num_dog]*num_dog) # 扩展10000个0.0001,前10000个样本的采样权重,狗多权重小
train_weights.extend([1/num_cat]*num_cat)  # 扩展1000个0.001,后1000个样本的采样权重,猫少权重大
print(train_weights) # 可以打印看看,总共11000个数,前10000个是0.0001,然后是1000个0.001
# 创建WeightedRandomSampler,1600为采样数
train_sampler   = WeightedRandomSampler(train_weights, 1600, replacement=False)
# 将sampler传给Dataloader,不再需要shuffle
gen             = DataLoader(train_dataset, batch_size=128, sampler=train_sampler, num_workers=8)

Erläuterung: Die Gewichtsliste enthält die Stichprobenwahrscheinlichkeit jeder Probe. Es gibt 11.000 Gewichte für 11.000 Proben. Im Beispiel addieren sich die Gewichte von Hunden zu 1 und die Gewichte von Katzen zu 1. Die Summe der Gewichte von Katzen und Hunde sind gleich, wodurch sichergestellt wird, dass die Anzahl der Bildproben der beiden Typen unter den 1600 Proben in jeder Epoche nahezu gleich ist. Dataloader verwendet zunächst einen Sampler, um jeweils 1.600 Bilder aus dem Datensatz abzutasten, und teilt sie dann für das Training in Stapelstapel auf.

Ich fürchte, das wird nicht jeder verstehen, also verzeihen Sie mir bitte den ganzen Unsinn. Wenn Sie Fragen haben, teilen Sie diese bitte im Kommentarbereich mit.

Ich denke du magst

Origin blog.csdn.net/ittongyuan/article/details/131088803
Empfohlen
Rangfolge