【天池比赛服装关键点检测fashionAI_landmark_detect——关于导入图片解决内存不够问题】

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/m0_37922734/article/details/80355078

由于图像的数据量还不小,总共有31631张图片,并且每张都是512*512*3大小的图片,无法一次性导入内存中,因此采用类的方法,实现每次只导入内存一个batch的图片。
需要设计两个类,第一个类是ImageData:
class ImageData:
    """
    """

    def __init__(self, img_ids, img_dir,
                 source_shape=None,
                 target_shape=None,
                 padding='edge',
                 normalize=False):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.source_shape = source_shape[:2]
        self.target_shape = target_shape[:2]
        self.padding = padding
        self.normalize = normalize

    def __len__(self):
        return len(self.img_ids)
#这中间的一些调用方法就不展示了
.....
.....   
     def __getitem__(self, indices):
        if isinstance(indices, np.ndarray) or isinstance(indices, slice):
            imgs = [self._read_image(img_id) for img_id in self.img_ids[indices]]
            imgs = np.array(imgs)
            return imgs
        else:
            return self._read_image(self.img_ids[indices])

第二个类是BatchSequence类
class BatchSequence(Sequence):
    def __init__(self, X, y=None, num_samples=None, batch_size=64, shuffle=True, global_indices=None):
        assert isinstance(num_samples, int), 'num_samples should be given'
        self.X = X
        self.y = y
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.shuffle = shuffle
        if global_indices is None:
            self.global_indices = np.arange(num_samples)
        else:
            self.global_indices = global_indices
        if shuffle:
            np.random.shuffle(self.global_indices)

    def __len__(self):
        return int(np.ceil(self.num_samples / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.global_indices)

    def __getitem__(self, batch_id):
        batch_data_index = \
            self.global_indices[batch_id * self.batch_size: (batch_id + 1) * self.batch_size]
        if isinstance(self.X, tuple) or isinstance(self.X, list):
            batch_x = [x[batch_data_index] for x in self.X]
        else:
            batch_x = self.X[batch_data_index]

        if self.y is None:
            return batch_x

        if isinstance(self.y, tuple) or isinstance(self.y, list):
            batch_y = [y[batch_data_index] for y in self.y]
        else:
            batch_y = self.y[batch_data_index]

        return batch_x, batch_y
 
 


第一个类实例化之后,通过切片操作获取图片此时就会把图片读入内存。例如:
 
  
image = ImageData(train_img_ids, TRAIN_DATA_DIR, ...... )
image[0] 此时就会从磁盘中读取一张图片到内存中 image[:10] 此时就会从磁盘中读取10张图片到内存中

第二个类实现的是每次指定一个batch大小的图片
 
 
train_data = BatchSequence(X, y, num_samples=num_train_samples, batch_size=batch_size,......)
train_data[0] 此时就会从磁盘中读取一个batch大小的图片到内存中

python和Java一样都是垃圾自动回收,不需要我们显示地销毁对象。执行del obj时会调用对象的__del__方法,这样对象的引用计数会减1,当对象的引用计数为0时,对象就会被销毁,内存就会被回收。

猜你喜欢

转载自blog.csdn.net/m0_37922734/article/details/80355078
今日推荐