Semantic segmentation data generator / DataLoader (PyTorch edition)

The basic structure of the data set

You can refer to the official PyTorch documentation on data loading. There are three main classes: Dataset, Sampler and DataLoader.

  • Dataset:
    the abstract class representing a dataset; all other datasets should inherit from it. Subclasses should override __len__ (returning the size of the dataset) and __getitem__ (supporting integer indices from 0 to len(self) - 1).

  • Sampler:
    the base class for all samplers; each subclass must provide an __iter__ method that yields indices over the dataset's elements, and a __len__ method returning the number of indices produced.

  • DataLoader:
    combines a dataset and a sampler, and provides single-process or multi-process iteration over the dataset.

Simple dataset class:

# Default locations of the training images and their pixel-wise label masks.
train_images_path = "./data/train_images"
train_labels_path = "./data/train_labels"

class RSDataset(Dataset):
    """Remote-sensing segmentation dataset yielding (image, mask) tensor pairs.

    Images are read from ``input_root``; the mask for each image is read
    from ``labels_root`` under the same file name.
    """

    def __init__(self, input_root, labels_root=None, mode="train", debug=False):
        """
        Args:
            input_root: directory containing the input images.
            labels_root: directory containing the label masks. Defaults to
                ``input_root`` with "train_images" replaced by "train_labels",
                preserving the original hard-coded behaviour.
            mode: "train" applies the random joint augmentations; any other
                value (e.g. "valid") skips them.
            debug: when True, keep only the first 500 images for quick runs.
        """
        super().__init__()
        self.input_root = input_root
        # BUG FIX: the labels path used to be derived inside __getitem__ via a
        # hard-coded str.replace on the image path; expose it as a parameter
        # (the visible caller already passes the labels folder positionally).
        if labels_root is None:
            labels_root = input_root.replace("train_images", "train_labels")
        self.labels_root = labels_root
        self.mode = mode

        # Single listing; truncate afterwards instead of listing twice.
        self.input_ids = sorted(os.listdir(self.input_root))
        if debug:
            self.input_ids = self.input_ids[:500]

        # Mask pipeline: project-defined helpers turn the 2-D mask into a tensor.
        self.mask_transform = transforms.Compose([
            transforms.Lambda(to_monochrome),
            transforms.Lambda(to_tensor),
        ])

        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Joint image/mask spatial augmentations, applied in "train" mode only.
        self.transform = DualCompose([
            RandomFlip(),
            RandomRotate90(),
            Rotate(),
            Shift(),
        ])

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Load image ``idx`` and its mask, returning transformed tensors."""
        image_name = os.path.join(self.input_root, self.input_ids[idx])
        mask_name = os.path.join(self.labels_root, self.input_ids[idx])
        image = np.array(cv2.imread(image_name), dtype=np.float32)
        # Masks are stored as 0/255 images; rescale to 0/1.
        mask = np.array(cv2.imread(mask_name)) / 255

        if self.mode == "train":
            image, mask = self.transform(image, mask)
        # cv2.imread returns 3 identical channels for a grayscale mask;
        # keep only the first (the original dead `mask1` buffer is dropped).
        return self.image_transform(image), self.mask_transform(mask[:, :, 0])


### Split the data into training and validation sets
def build_loader(input_img_folder="./data/train_images",
                 batch_size=16,
                 num_workers=4):
    """Split the images in ``input_img_folder`` into train/validation subsets
    and return a DataLoader for each.

    Args:
        input_img_folder: directory containing the training images.
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per DataLoader.

    Returns:
        ``(train_loader, valid_loader)``; 15% of the images are held out for
        validation, using a fixed shuffle seed so the split is reproducible.
    """
    # Index the images; sorting is unnecessary just to count them.
    num_train = len(os.listdir(input_img_folder))
    indices = list(range(num_train))
    seed(128381)  # fixed seed -> same train/valid split on every run
    indices = sample(indices, len(indices))  # shuffled copy of the indices
    split = int(np.floor(0.15 * num_train))

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # BUG FIX: the original passed the labels folder as a second positional
    # argument while also passing ``mode=`` as a keyword, which raises
    # TypeError ("multiple values for argument 'mode'"). It also ignored
    # ``input_img_folder`` in favour of hard-coded paths. The dataset derives
    # the labels path from the image path itself.
    train_dataset = RSDataset(input_img_folder, mode="train")
    val_dataset = RSDataset(input_img_folder, mode="valid")

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=True,
    )

    valid_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=True,
    )

    return train_loader, valid_loader
Published 33 original articles · won praise 3 · Views 5547

Guess you like

Origin blog.csdn.net/weixin_42990464/article/details/104197638