Overriding PyTorch's torchvision.datasets.ImageFolder

When training a classifier, for example one that distinguishes dogs from cats, we often load the data with PyTorch's ImageFolder:

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

For how to use it, see the post "Using PyTorch's torchvision.ImageFolder".

What I want to achieve here is to override its methods, so that I can keep using its existing features while implementing my own functionality.

First, let's analyze its source code:

# File-name suffixes (lower case, including the leading dot) that are treated
# as images. Every entry needs the dot: a bare 'webp' would also match names
# such as 'notwebp' that merely end in those letters.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp']

class ImageFolder(DatasetFolder):
    """Image dataset read from a directory tree with one sub-folder per class: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Path of the root directory.
        transform (callable, optional): Callable that receives a PIL image and
            returns a transformed version of it, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): Callable that receives the
            target and returns a transformed version of it.
        loader (callable, optional): Callable that loads an image from its path.

    Attributes:
        classes (list): The class names.
        class_to_idx (dict): Items of the form (class_name, class_index).
        imgs (list): (image path, class_index) tuples; the same object as
            ``samples``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        # Delegate to the generic folder dataset, restricting it to the
        # known image extensions.
        parent = super(ImageFolder, self)
        parent.__init__(
            root,
            loader,
            IMG_EXTENSIONS,
            transform=transform,
            target_transform=target_transform,
        )
        # Expose the samples under the image-specific name as well.
        self.imgs = self.samples

The ImageFolder code is very simple; it mainly inherits from DatasetFolder:

def has_file_allowed_extension(filename, extensions):
    """Check whether a file name ends with one of the allowed extensions.

    The file name is lower-cased before the comparison, so the check is
    case-insensitive as long as the extensions are given in lower case.

    Args:
        filename (string): Path to a file.
        extensions (iterable of strings): Accepted file extensions
            (lower case), e.g. the supported image file types.

    Returns:
        bool: True if the file name ends with one of the given extensions.
    """
    # str.endswith accepts a tuple of suffixes, so one call covers them all;
    # an empty tuple yields False, matching any() over an empty iterable.
    return filename.lower().endswith(tuple(extensions))


def make_dataset(dir, class_to_idx, extensions):
    """Collect every accepted file under the class folders of ``dir``.

    Args:
        dir (string): Root directory; ``~`` is expanded.
        class_to_idx (dict): Mapping of class (sub-folder) name to class index.
        extensions (iterable of strings): Accepted file extensions.

    Returns:
        list: Items of the form (file path, class index), i.e.
        ``[(path, class_to_idx[class_name]), ...]``.
    """
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)
        # Skip class names that have no matching directory under the root.
        if not os.path.isdir(d):
            continue

        # os.walk descends through all nested folders and yields, for each
        # one, its path, its sub-folder names and its file names.
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                # Keep only the files whose extension is accepted.
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images

class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function that loads a sample given its path.
        extensions (list[string]): A list of accepted file extensions.
        transform (callable, optional): A function/transform that takes in a
            sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target (the sample's label) and transforms it.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index),
            such as ``{'cat': 0, 'dog': 1}``.
        samples (list): List of (sample path, class_index) tuples.
        targets (list): The class_index value of each sample in the dataset.
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        # Class names and their indices, such as ['cat', 'dog'] and
        # {'cat': 0, 'dog': 1}.
        classes, class_to_idx = self._find_classes(root)
        # [(sample path, class index), ...]: every sample paired with its label.
        samples = make_dataset(root, class_to_idx, extensions)
        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        # The class index of every sample, in sample order.
        self.targets = [s[1] for s in samples]

        self.transform = transform
        self.target_transform = target_transform

    def _find_classes(self, dir):
        """Find the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are the names of the
            first-level sub-directories of ``dir``, sorted, such as
            ['cat', 'dog'], and class_to_idx is a dictionary mapping each
            class name to its index, such as {'cat': 0, 'dog': 1}.

        Ensures:
            No class is a subdirectory of another.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above.
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            # Same result as above for interpreters without os.scandir.
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        # Sort so that the index assigned to each class is deterministic.
        classes.sort()
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)  # load the sample (e.g. the image) from disk
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)

    def __repr__(self):
        """Return a multi-line summary: size, root and the two transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        # Indent continuation lines of a multi-line repr under the label.
        fmt_str += '{0}{1}\n'.format(tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

With that understood, here is the code that overrides ImageFolder:

class CustomImageFolder(ImageFolder):
    """ImageFolder variant whose items are pairs of images with their labels.

    Each item consists of the image at the requested index plus a second
    image picked at random from the whole dataset, together with the class
    index of each.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform applied to both
            loaded images.
    """

    def __init__(self, root, transform=None):
        super(CustomImageFolder, self).__init__(root, transform)
        # All valid sample indices; the random partner is drawn from these.
        self.indices = range(len(self))

    def __getitem__(self, index1):
        """
        Args:
            index1 (int): Index of the first image.

        Returns:
            tuple: (img1, img2, label1, label2) where img2/label2 belong to a
            randomly chosen sample.
        """
        # Pick the second image at random from the whole dataset.
        index2 = random.choice(self.indices)

        # self.imgs is the same list as self.samples:
        # [(image path, class index), ...]
        path1, label1 = self.imgs[index1]
        path2, label2 = self.imgs[index2]

        img1 = self.loader(path1)
        img2 = self.loader(path2)
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, label1, label2

 

Guess you like

Origin www.cnblogs.com/wanghui-garcia/p/11514368.html