pytorch训练模型遇到RuntimeError: inconsistent tensor错误

问题描述:

训练resnet模型,使用torchvision的ImageFolder来创建训练和测试集的DataLoader,然后训练模型的时候,出现RuntimeError: inconsistent tensor错误.

解决措施:

在transforms.Compose()中,transforms.Resize(224)后面跟一个transforms.CenterCrop(224)操作就OK!

完整代码:

def train_loader(path, batch_size=32, num_workers=4, pin_memory=True):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return data.DataLoader(
        datasets.ImageFolder(path,
                            transforms.Compose([
                                transforms.Resize(224),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                normalize,
                                ])),
        batch_size = batch_size,
        shuffle = True,
        num_workers = num_workers,
        pin_memory = pin_memory)

猜你喜欢

转载自blog.csdn.net/tsq292978891/article/details/80152499