The basic structure of the data set
You can refer to the official web documentation. There are three main categories: Dataset, Sampler, and DataLoader.
-
Dataset:
the abstract base class representing a dataset; all other datasets should inherit from it. All subclasses should override `__len__` (returning the dataset size) and `__getitem__` (supporting integer indices in the range 0 to `len(self) - 1`). -
Sampler:
the base class for all samplers; each subclass must provide an `__iter__` method, which yields the indices of dataset elements, and a `__len__` method, which returns the length of the returned iterator. -
DataLoader:
combines a dataset and a sampler, and provides a single-process or multi-process iterator over the dataset.
Simple dataset class:
# Default on-disk layout: images and their masks share filenames across
# sibling folders (the label path is derived by string replacement below).
train_images_path = "./data/train_images"
train_labels_path = "./data/train_labels"
class RSDataset(Dataset):
    """Remote-sensing segmentation dataset of (image, mask) pairs.

    Images are read from ``input_root``; each mask is the same filename
    under the sibling ``train_labels`` directory (derived by replacing
    ``train_images`` with ``train_labels`` in the image path).

    Args:
        input_root: Directory containing the input images.
        mode: ``"train"`` applies random geometric augmentation to both
            image and mask; any other value skips augmentation.
        debug: If True, keep only the first 500 images for quick runs.
    """

    def __init__(self, input_root, mode="train", debug=False):
        super().__init__()
        self.input_root = input_root
        self.mode = mode
        # List the directory once; truncate only in debug mode
        # (the original duplicated the listdir call in both branches).
        ids = sorted(os.listdir(self.input_root))
        self.input_ids = ids[:500] if debug else ids
        self.mask_transform = transforms.Compose([
            transforms.Lambda(to_monochrome),
            transforms.Lambda(to_tensor),
        ])
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # Joint geometric augmentations applied to image and mask together
        # (used only when mode == "train").
        self.transform = DualCompose([
            RandomFlip(),
            RandomRotate90(),
            Rotate(),
            Shift(),
        ])

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        image_name = os.path.join(self.input_root, self.input_ids[idx])
        image = np.array(cv2.imread(image_name), dtype=np.float32)
        # Mask lives at the same relative path under train_labels;
        # scale pixel values to [0, 1].
        mask = np.array(cv2.imread(image_name.replace("train_images", "train_labels"))) / 255
        if self.mode == "train":
            # Apply the same random geometric transform to image and mask.
            image, mask = self.transform(image, mask)
        # cv2 loads the mask with 3 (identical) channels; keep one plane.
        # (The original pre-allocated an np.zeros mask here that was always
        # overwritten — dead code, now removed.)
        mask1 = mask[:, :, 0]
        return self.image_transform(image), self.mask_transform(mask1)
### Split into training and validation sets
def build_loader(input_img_folder="./data/train_images",
                 batch_size=16,
                 num_workers=4):
    """Build train/validation DataLoaders with a reproducible 85/15 split.

    Args:
        input_img_folder: Directory containing the training images.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per DataLoader.

    Returns:
        (train_loader, valid_loader) tuple of DataLoaders drawing disjoint
        index subsets from the same image folder.
    """
    # Deterministically shuffle all image indices (fixed seed so the
    # train/valid split is reproducible across runs).  sorted() is not
    # needed just to count files.
    num_train = len(os.listdir(input_img_folder))
    indices = list(range(num_train))
    seed(128381)
    indices = sample(indices, len(indices))
    # First 15% of the shuffled indices form the validation set.
    split = int(np.floor(0.15 * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    # BUG FIX: RSDataset's signature is (input_root, mode, debug).  The
    # original passed the labels folder as the second positional argument
    # (i.e. as `mode`) and then *also* passed mode= by keyword, raising
    # "TypeError: got multiple values for argument 'mode'".  The labels
    # path is derived inside the dataset, so only the image root is needed.
    # Also honor the input_img_folder parameter instead of a hard-coded path.
    train_dataset = RSDataset(input_img_folder, mode="train")
    val_dataset = RSDataset(input_img_folder, mode="valid")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=True,
    )
    return train_loader, valid_loader