When training a classifier on data that is organized by class, for example dogs versus cats, we often use PyTorch's ImageFolder:
CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
For basic usage, see: using torchvision.datasets.ImageFolder in PyTorch.
What I want to do here is override its methods, so that I keep the features it already provides while adding behavior of my own.
First, let's look at its source code:
    IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']


    class ImageFolder(DatasetFolder):
        """A generic data loader where the images are arranged in this way: ::

            root/dog/xxx.png
            root/dog/xxy.png
            root/dog/xxz.png

            root/cat/123.png
            root/cat/nsdf3.png
            root/cat/asd932_.png

        Args:
            root (string): Root directory path.
            transform (callable, optional): A function/transform that takes in an PIL image
                and returns a transformed version. E.g, ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
                target and transforms it.
            loader (callable, optional): A function to load an image given its path.

        Attributes:
            classes (list): List of the class names.
            class_to_idx (dict): Dict with items (class_name, class_index).
            imgs (list): List of (image path, class_index) tuples
        """

        def __init__(self, root, transform=None, target_transform=None,
                     loader=default_loader):
            super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                                              transform=transform,
                                              target_transform=target_transform)
            self.imgs = self.samples
The ImageFolder code is very simple; it mostly inherits from DatasetFolder:
    import os
    import sys

    import torch.utils.data as data


    def has_file_allowed_extension(filename, extensions):
        """Checks whether a file has one of the allowed extensions.

        Args:
            filename (string): path to a file
            extensions (iterable of strings): accepted image file extensions

        Returns:
            bool: True if the filename ends with one of the given extensions
        """
        filename_lower = filename.lower()
        # any() returns a single bool: True if at least one extension matches
        return any(filename_lower.endswith(ext) for ext in extensions)


    def make_dataset(dir, class_to_idx, extensions):
        """Returns a list of the form [(image path, class index of that image), ...]"""
        images = []
        dir = os.path.expanduser(dir)
        for target in sorted(class_to_idx.keys()):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue

            # os.walk descends through every level of the class folder and yields
            # (current folder path, subfolder names, file names) for each one
            for root, _, fnames in sorted(os.walk(d)):
                for fname in sorted(fnames):
                    # Keep the file only if its extension is supported
                    if has_file_allowed_extension(fname, extensions):
                        path = os.path.join(root, fname)
                        item = (path, class_to_idx[target])
                        images.append(item)

        return images


    class DatasetFolder(data.Dataset):
        """A generic data loader where the samples are arranged in this way: ::

            root/class_x/xxx.ext
            root/class_x/xxy.ext
            root/class_x/xxz.ext

            root/class_y/123.ext
            root/class_y/nsdf3.ext
            root/class_y/asd932_.ext

        Args:
            root (string): Root directory path.
            loader (callable): A function to load a sample given its path.
            extensions (list[string]): A list of allowed extensions, i.e. the accepted
                image file types.
            transform (callable, optional): A function/transform that takes in a sample
                and returns a transformed version. E.g, ``transforms.RandomCrop`` for images.
            target_transform (callable, optional): A function/transform that takes in the
                target (the sample's label) and transforms it.

        Attributes:
            classes (list): List of the class names.
            class_to_idx (dict): Dict with items (class_name, class_index),
                such as {'cat': 0, 'dog': 1}.
            samples (list): List of (sample path, class_index) tuples.
            targets (list): The class_index value for each image in the dataset.
        """

        def __init__(self, root, loader, extensions, transform=None, target_transform=None):
            # Class names and class indices, such as ['cat', 'dog'] and {'cat': 0, 'dog': 1}
            classes, class_to_idx = self._find_classes(root)
            # [(image path, class index of that image), ...], i.e. every image gets a label
            samples = make_dataset(root, class_to_idx, extensions)
            if len(samples) == 0:
                raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                                   "Supported extensions are: " + ",".join(extensions)))

            self.root = root
            self.loader = loader
            self.extensions = extensions

            self.classes = classes
            self.class_to_idx = class_to_idx
            self.samples = samples
            self.targets = [s[1] for s in samples]  # class indices of all images

            self.transform = transform
            self.target_transform = target_transform

        def _find_classes(self, dir):
            """Finds the class folders in a dataset.

            Args:
                dir (string): Root directory path.

            Returns:
                tuple: (classes, class_to_idx), where classes are the directory names,
                    such as ['cat', 'dog'], and class_to_idx is a dictionary of the form
                    {class name: class index}, such as {'cat': 0, 'dog': 1}.

            Ensures:
                No class is a subdirectory of another.
            """
            if sys.version_info >= (3, 5):
                # Faster and available in Python 3.5 and above:
                # names of all first-level subdirectories of the root directory
                classes = [d.name for d in os.scandir(dir) if d.is_dir()]
            else:
                # Same result as above, using the API available in older versions
                classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
            classes.sort()  # sort the class names
            # Map each class name to its position in the sorted list, e.g. {'cat': 0, 'dog': 1}
            class_to_idx = {classes[i]: i for i in range(len(classes))}
            return classes, class_to_idx

        def __getitem__(self, index):
            """
            Args:
                index (int): Index

            Returns:
                tuple: (sample, target) where target is class_index of the target class.
            """
            path, target = self.samples[index]
            sample = self.loader(path)  # load the image
            if self.transform is not None:
                sample = self.transform(sample)
            if self.target_transform is not None:
                target = self.target_transform(target)

            return sample, target

        def __len__(self):
            return len(self.samples)

        def __repr__(self):
            fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
            fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
            fmt_str += '    Root Location: {}\n'.format(self.root)
            tmp = '    Transforms (if any): '
            fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            tmp = '    Target Transforms (if any): '
            fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            return fmt_str
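To see what the class discovery and file scan produce, the logic can be tried on a throwaway directory without installing torchvision. This is only a minimal sketch using the standard library: `scan_folder` is a hypothetical helper written for this post, not a torchvision function, and it merely mirrors what `_find_classes` and `make_dataset` do in the source quoted above.

```python
import os
import tempfile

# Assumption: the same extension list as in the quoted source
# (note that 'webp' has no leading dot there).
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']


def scan_folder(root, extensions=IMG_EXTENSIONS):
    """Hypothetical helper that mirrors _find_classes followed by make_dataset."""
    # First-level subdirectories are the classes, sorted by name
    classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for current, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                # Case-insensitive extension check, as in has_file_allowed_extension
                if any(fname.lower().endswith(ext) for ext in extensions):
                    samples.append((os.path.join(current, fname), class_to_idx[name]))
    return classes, class_to_idx, samples


# Build a temporary tree: root/dog/{a.png, notes.txt} and root/cat/{b.JPG}
root = tempfile.mkdtemp()
for cls, fname in [("dog", "a.png"), ("dog", "notes.txt"), ("cat", "b.JPG")]:
    os.makedirs(os.path.join(root, cls), exist_ok=True)
    open(os.path.join(root, cls, fname), "w").close()

classes, class_to_idx, samples = scan_folder(root)
print(classes)       # ['cat', 'dog']
print(class_to_idx)  # {'cat': 0, 'dog': 1}
print([(os.path.basename(p), i) for p, i in samples])  # [('b.JPG', 0), ('a.png', 1)]
```

The class index comes purely from the alphabetical order of the folder names, and files such as notes.txt are skipped because their extension is not in the list.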
With that understood, we can override ImageFolder. The code:
    import random

    from torchvision.datasets import ImageFolder


    class CustomImageFolder(ImageFolder):
        """Returns two images and their class indices; the second image is chosen at random."""

        def __init__(self, root, transform=None):
            super(CustomImageFolder, self).__init__(root, transform)
            self.indices = range(len(self))  # all valid sample indices: 0 .. len(self) - 1

        def __getitem__(self, index1):
            # Draw a random index from [0, len(self)) to pick the second image
            index2 = random.choice(self.indices)

            # self.imgs is the same list as self.samples,
            # i.e. [(image path, class index of that image), ...]
            path1 = self.imgs[index1][0]
            label1 = self.imgs[index1][1]
            path2 = self.imgs[index2][0]
            label2 = self.imgs[index2][1]

            img1 = self.loader(path1)
            img2 = self.loader(path2)
            if self.transform is not None:
                img1 = self.transform(img1)
                img2 = self.transform(img2)

            return img1, img2, label1, label2
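The pairing behavior can be checked without any images on disk. The following is a rough sketch under stated assumptions: `PairIndexDemo` is a hypothetical stand-in invented for this post, holding a hand-written list in place of `self.imgs` and returning paths instead of loaded images. Only the index logic is the same as in CustomImageFolder.

```python
import random


class PairIndexDemo:
    """Hypothetical stand-in for CustomImageFolder: same pairing logic, no image loading."""

    def __init__(self, imgs):
        self.imgs = imgs                 # [(path, class index), ...]
        self.indices = range(len(imgs))  # all valid sample indices

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index1):
        # Random partner; it may be index1 itself or a sample of the same class
        index2 = random.choice(self.indices)
        path1, label1 = self.imgs[index1]
        path2, label2 = self.imgs[index2]
        return path1, path2, label1, label2


# Made-up sample list for illustration only
imgs = [("cat/1.png", 0), ("cat/2.png", 0), ("dog/1.png", 1), ("dog/2.png", 1)]
demo = PairIndexDemo(imgs)

random.seed(0)
first = demo[0]
random.seed(0)
second = demo[0]
print(first == second)     # True: the same seed gives the same partner
print(first[0], first[2])  # cat/1.png 0
```

Two things follow from this design. The second image is drawn with the `random` module, so the same index can return a different pair on each call unless the seed is fixed. Also, the override applies `self.transform` to both images but never calls `target_transform`, so the labels come back as the raw class indices.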