pytorch易忘

1、自定义数据集

class Dataset(torch.utils.data.Dataset):
def __init__(self):
super(Dataset, self).__init__()

def __len__(self):
return len()

def __getitem__(self, item):
return item

data_loader = torch.utils.data.DataLoader(dataset, batch_size, num_workers, shuffle)

2、数据、模型需要放在cuda上

3、损失函数

torch.nn

4、优化器

torch.optim

5、

6、



猜你喜欢

转载自www.cnblogs.com/liujianing/p/12660564.html