问题描述:
训练resnet模型,使用torchvision的ImageFolder来创建训练和测试集的DataLoader,然后训练模型的时候,出现RuntimeError: inconsistent tensor错误.
解决措施:
在transforms.Compose()中,transforms.Resize(224)后面跟一个transforms.CenterCrop(224)操作就OK!
完整代码:
def train_loader(path, batch_size=32, num_workers=4, pin_memory=True):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
return data.DataLoader(
datasets.ImageFolder(path,
transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size = batch_size,
shuffle = True,
num_workers = num_workers,
pin_memory = pin_memory)