# 利用 torch.utils.data.Dataset 自定义数据加载类
# (Defining a custom data-loading class with torch.utils.data.Dataset)

import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np

# 继承Dataset类要重写__getitem__()和__len__()
class CatDog(data.Dataset):
    """Dataset of cat/dog images stored as individual files in one directory.

    Subclasses ``torch.utils.data.Dataset``, overriding ``__getitem__`` and
    ``__len__``. Each sample's label is derived from its file name: 1 when
    the name contains ``"dog"``, otherwise 0.

    Args:
        root: Directory whose entries are the image files. Every entry is
            assumed to be a readable image (TODO confirm no stray files).
    """

    def __init__(self, root):
        # os.listdir returns entries in arbitrary order; sort so a given
        # index always maps to the same file across runs.
        self.imgs = [
            os.path.join(root, name) for name in sorted(os.listdir(root))
        ]

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 if the file name of *img_path* contains "dog", else 0."""
        return 1 if "dog" in os.path.basename(img_path) else 0

    def __getitem__(self, index):
        """Load image number *index* and return ``(tensor, label)``.

        The tensor is built from the decoded pixel array of the PIL image,
        so its shape and dtype depend on the image mode.
        """
        img_path = self.imgs[index]
        # Use a context manager so the file handle is released. np.array
        # (not np.asarray) forces the pixels to load inside the `with` and
        # yields a writable copy, which torch.from_numpy expects.
        with Image.open(img_path) as pil_img:
            pixels = np.array(pil_img)
        return t.from_numpy(pixels), self._label_from_path(img_path)

    def __len__(self):
        """Return the number of image files found under ``root``."""
        return len(self.imgs)

# 猜你喜欢 ("You may also like" -- leftover from the source web page)

# 转载自 (reprinted from): https://www.cnblogs.com/liujianing/p/12320539.html