python装饰器应用之keras数据生成器

这里只是装饰器的一个简单应用。与平常唯一不同的地方是：我把装饰器写在了类的外部，而且装饰器内部还调用了类中的方法——只要在 wrapper 的参数列表里把 self 作为第一个参数加上即可。还是直接看代码吧，下面是我用于 OCR 训练的数据生成器。

def _iter_batches(dataset, index_all, batch_size):
    """Yield batches drawn from ``dataset`` forever, reshuffling after each pass.

    ``index_all`` is consumed in order; once fewer than ``batch_size`` indices
    remain, the list is shuffled in place and iteration restarts from the
    beginning (the trailing partial batch is dropped).
    """
    n = len(index_all)
    # Same value for every batch; reads the module-level image width.
    input_length = max_img_weigth // 8
    i = 0
    while True:
        # `>` rather than `>=`: a batch that ends exactly at n is still full.
        if i + batch_size > n:
            np.random.shuffle(index_all)
            i = 0
        batch_idx = index_all[i:i + batch_size]
        i += batch_size
        batch_x, batch_y, batch_label_length = [], [], []
        for idx in batch_idx:
            x, y = dataset.get_img_data(idx)
            batch_x.append(x)
            batch_y.append(y)
            # Look up by sample id (not by loop position) so the length stays
            # paired with its image after shuffling and for index lists that
            # do not start at 0 (e.g. a validation slice).
            batch_label_length.append(dataset.label_length[idx])
        # The all-ones target is presumably a placeholder for a loss computed
        # inside the model (CTC-style) — confirm against the model definition.
        yield [np.array(batch_x),
               np.array(batch_y),
               np.ones(batch_size) * input_length,
               np.array(batch_label_length)], np.ones(batch_size)


def generate(func):
    """Turn a method returning ``(index_all, batch_size)`` into an endless batch generator.

    The decorated method is called with the caller's arguments and must return
    the list of sample indices to draw from and the batch size. The resulting
    method returns a generator that never stops and yields
    ``([images, labels, input_lengths, label_lengths], ones)`` tuples, where
    ``images``/``labels`` come from ``self.get_img_data(index)`` and
    ``label_lengths`` from ``self.label_length[index]``. After each full pass
    the index list is shuffled in place.

    Raises:
        ValueError: when the decorated method is called, if ``batch_size`` is
            not in ``[1, len(index_all)]`` (no full batch could ever be formed).
    """
    import functools

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        index_all, batch_size = func(self, *args, **kwargs)
        n = len(index_all)
        if batch_size <= 0 or batch_size > n:
            raise ValueError(
                "batch_size must be in [1, {}], got {}".format(n, batch_size))
        return _iter_batches(self, index_all, batch_size)

    return wrapper


class ChineseDataset(object):
    """OCR dataset described by a MATLAB ``.mat`` annotation file.

    Annotations are read from the module-level ``label_mat`` path and images
    from the module-level ``img_dir``. The first ``TRAIN_RATIO`` fraction of
    samples (in file order) forms the training split; the rest is validation.
    """

    # Fraction of samples, by position, assigned to the training split.
    TRAIN_RATIO = 0.8

    def __init__(self):
        mat_annotation = loadmat(label_mat)
        self.img_dir = img_dir
        # Per-sample image file names, joined onto img_dir when loading.
        self.filenames = mat_annotation['img']
        # Per-sample labels, converted with one_hot() in get_img_data.
        self.labels = mat_annotation['label']
        # loadmat returns MATLAB vectors as 2-D arrays by default; [0] takes
        # the first row (assumes label_length is stored as a 1xN row — confirm).
        self.label_length = mat_annotation['label_length'][0]

    def _split_point(self):
        """Return the index of the first validation sample."""
        return int(len(self.filenames) * self.TRAIN_RATIO)

    def get_train_num(self):
        """Return the number of training samples."""
        return self._split_point()

    def get_valid_num(self):
        """Return the number of validation samples."""
        return len(self.filenames) - self._split_point()

    def get_img_data(self, index):
        """Load sample ``index`` and return ``(image, one_hot_label)``.

        The image is resized to ``(max_img_weigth, max_img_height)`` (cv2 takes
        width, height) and scaled by 1/255.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        path = os.path.join(self.img_dir, self.filenames[index])
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None,
            # which would otherwise fail later inside cv2.resize with an
            # unhelpful assertion message.
            raise OSError("could not read image: {}".format(path))
        img = cv2.resize(img, (max_img_weigth, max_img_height)) / 255.
        label = one_hot(self.labels[index])
        return img, label

    @generate
    def gen_train(self, batch_size):
        """Endless batch generator over the training split."""
        return list(range(self._split_point())), batch_size

    @generate
    def gen_valid(self, batch_size):
        """Endless batch generator over the validation split."""
        return list(range(self._split_point(), len(self.filenames))), batch_size
发布了63 篇原创文章 · 获赞 55 · 访问量 10万+

猜你喜欢

转载自 https://blog.csdn.net/qq_36810544/article/details/104288294