import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
# 继承Dataset类要重写__getitem__()和__len__()
class CatDog(data.Dataset):
    """Map-style dataset of cat/dog images stored flat in one directory.

    Subclassing ``torch.utils.data.Dataset`` requires overriding
    ``__getitem__`` and ``__len__``. The label is derived from the file
    name: 1 if the file's base name contains ``'dog'``, otherwise 0.

    Args:
        root: Directory whose entries are all treated as image files.
    """

    def __init__(self, root):
        # Sort so the index -> file mapping is deterministic; os.listdir
        # returns entries in arbitrary order. The listing itself is a
        # temporary, so it is not stored on self.
        self.imgs = [
            os.path.join(root, name) for name in sorted(os.listdir(root))
        ]

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 if the file's base name contains 'dog', else 0."""
        # Only the base name is inspected so that a parent directory named
        # e.g. "dogs" does not cause every file inside it to be labelled 1.
        return 1 if 'dog' in os.path.basename(img_path) else 0

    def __getitem__(self, index):
        """Load image number ``index`` and return ``(tensor, label)``.

        The tensor wraps the pixel array exactly as decoded by PIL; its
        shape and dtype depend on the image mode (presumably (H, W, C)
        uint8 for RGB JPEGs — not verified here).
        """
        img_path = self.imgs[index]
        label = self._label_from_path(img_path)
        # The context manager closes the file handle. np.array (not
        # np.asarray) copies the pixels into a writable array, forcing
        # PIL's lazy load before the file is closed; torch.from_numpy
        # warns on non-writable arrays.
        with Image.open(img_path) as pil_img:
            array = np.array(pil_img)
        return t.from_numpy(array), label

    def __len__(self):
        """Return the number of files found under ``root``."""
        return len(self.imgs)