pytorch每次迭代训练前都重新对数据集进行采样形成平衡数据集

对于不平衡数据集的训练通常有两种方法:

  • 一种是先用数据平衡的方法形成平衡数据集之后用于每一轮的训练,此时每轮训练的数据集是不变的,这一方法在pytorch的实现比较简单,即先构建好平衡数据集train_set,然后构建train_loder:
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

       这种方法只需要构建一次的train_loder

  • 还有一种方法稍微麻烦一些,就是在每轮的迭代训练前都重新对数据集进行随机采样形成平衡数据集train_set,此时每轮训练的数据集是变化的,需要在每轮的epoch中重新构建train_loder:
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

       训练有多少个epoch,就需要构建train_loder多少次

猜你喜欢

转载自blog.csdn.net/weixin_38314865/article/details/107686697