keras 自定义ImageDataGenerator用于多标签分类

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/w5688414/article/details/84593705

感想

keras提供了flow_from_directory用于单个标签分类,但是对图片的多标签分类没有支持,这就需要我们自己动手实现ImageDataGenerator,我这里把我实现的用于多标签分类的自定义DataGenerator分享出来,读者可以根据自己的情况来进行修改。

数据集我用的是经过整理了之后的NUS-WIDE数据集,下载地址为:https://download.csdn.net/download/w5688414/10816132

我的数据集标注保存在一个txt文件里,下面展示该文件的部分内容:

actor\0013_1106962433.jpg* street
actor\0018_470960261.jpg* water
actor\0031_2406153318.jpg* dog garden
actor\0033_213660760.jpg* plane
actor\0034_569936943.jpg* dancing
actor\0041_2456602544.jpg* road sky street
actor\0049_2163720456.jpg* street
actor\0064_2343195092.jpg* buildings
actor\0081_2159186964.jpg* sky
actor\0211_2233188435.jpg* beach sand
actor\0213_1850366878.jpg* sun sunset
actor\0219_453334665.jpg* fox
actor\0224_954526140.jpg* street
actor\0229_433478352.jpg* sky sun sunset
actor\0231_637866194.jpg* fox tattoo
actor\0258_1053942091.jpg* beach
actor\0273_2727864150.jpg* street
actor\0279_2321264821.jpg* statue temple
actor\0284_2060799990.jpg* running
actor\0333_82685319.jpg* street
actor\0378_344147378.jpg* statue
actor\0393_173349342.jpg* flowers
actor\0435_522150101.jpg* cars tower
actor\0438_2504620853.jpg* street
actor\0448_2291046386.jpg* sky
actor\0463_2483322510.jpg* clouds sky
actor\0485_292906123.jpg* road vehicle
actor\0491_335496963.jpg* police road street toy train
actor\0495_870673543.jpg* running
actor\0530_2568827539.jpg* book

可以看到,*左边为图片的路径,右边为该图片对应的标签(多个标签以空格分隔)。然后我给每个标签编了一个号,保存在word_id.txt中:

0 dog
1 clouds
2 tree
3 garden
4 dancing
5 toy
6 fox
7 ocean
8 tower
9 police
10 lake
11 mountain
12 fish
13 town
14 reflection
15 water
16 rocks
17 animal
18 temple
19 bear
20 grass
21 sun
22 beach
23 sky
24 street
25 snow
26 vehicle
27 birds
28 plane
29 book
30 sand
31 road
32 statue
33 bridge
34 cars
35 cat
36 flowers
37 military
38 buildings
39 airport
40 window
41 train
42 computer
43 tattoo
44 sunset
45 person
46 running
47 house

创建word_id.txt的代码为create_word_id.py:

# create_word_id.py: collect every tag used in the dataset listing and write
# the vocabulary to word_id.txt, one "<id> <tag>" pair per line.
txt_path='datasets81_train.txt'
id_tag_path='word_id.txt'

# Each listing line has the form "<image path>* <tag> <tag> ...".
with open(txt_path,'r') as f:
    datasets=f.readlines()

word_dict=set()
for line in datasets:
    line=line.strip()
    if not line:
        # Tolerate blank lines instead of failing on the missing '*' part.
        continue
    # The text after '*' is the tag list. split() without arguments drops the
    # empty token caused by the space that follows '*', so no index skipping
    # is needed and a tag written directly after '*' is not lost.
    tag_list=line.split('*')[1].split()
    word_dict.update(tag_list)

# Sort before numbering: iteration order of a set of strings is not stable
# between interpreter runs, which would give every run a different tag->id map.
with open(id_tag_path,'w') as f:
    for i,tag in enumerate(sorted(word_dict)):
        f.write(str(i)+' '+tag+'\n')

最后自己定义了一个Generator(保存为DataGenerator.py,后面的train.py会从该模块导入):

import os
from PIL import Image
import numpy as np

BATCHSIZE=10
root_path='/home/eric/data/NUS-WIDE/image'

class data_generator:
    """In-memory mini-batch generator for multi-label image classification.

    The listing file holds one sample per line in the form
    ``<relative image path>* <tag> <tag> ...``. Images are resolved against
    the module-level ``root_path``, tags are mapped to class indices through
    the module-level ``word_id`` dict, and the batch size comes from the
    module-level ``BATCHSIZE``.
    """

    def __init__(self,file_path,_max_example,image_size,classes):
        """Read the listing and eagerly load images and labels into memory.

        Parameters
        ----------
        file_path : path of the listing file.
        _max_example : maximum number of listing lines to load.
        image_size : sequence ``(width, height[, channels])``; images are
            resized to ``(width, height)``.
        classes : length of the multi-hot label vector.
        """
        self.load_data(file_path=file_path)
        self.index=0
        self.batch_size=BATCHSIZE
        self.image_size=image_size
        self.classes=classes
        self.load_images_labels(_max_example)
        # Report what was actually loaded: the listing may contain fewer
        # samples than requested, and callers derive steps-per-epoch from it.
        self.num_of_examples=len(self.images)

    def load_data(self,file_path):
        """Read all lines of the listing file into ``self.datasets``."""
        with open(file_path,'r') as f:
            self.datasets=f.readlines()

    def load_images_labels(self,_max_example):
        """Load up to ``_max_example`` samples into ``self.images`` / ``self.labels``.

        Raises
        ------
        KeyError
            If a tag of a sample is missing from ``word_id``.
        """
        images=[]
        labels=[]
        # When three channels are requested, force RGB so grayscale, palette
        # or RGBA files cannot produce arrays of a different shape.
        want_rgb=len(self.image_size)>2 and self.image_size[2]==3
        for line in self.datasets[:_max_example]:
            line=line.strip()
            if not line:
                continue  # skip blank lines rather than crash on the split
            data_arr=line.split('*')
            # Listing paths use Windows separators; normalise them.
            image_path=os.path.join(root_path,data_arr[0]).replace("\\", "/")
            # The context manager closes the file handle once pixels are read.
            with Image.open(image_path) as source:
                img=source.convert('RGB') if want_rgb else source
                # LANCZOS is the filter formerly exposed as Image.ANTIALIAS,
                # a name that newer Pillow releases no longer provide.
                img=img.resize((self.image_size[0], self.image_size[1]),Image.LANCZOS)
                images.append(np.array(img))
            # Multi-hot label: 1 at the index of every tag of this image.
            label=np.zeros((self.classes))
            for tag in data_arr[1].split():
                label[int(word_id[tag])]=1
            labels.append(label)
        self.images=images
        self.labels=labels

    def get_mini_batch(self,augment=False):
        """Yield ``(batch_images, batch_labels)`` numpy arrays forever.

        Samples are served in listing order and wrap around to the first
        sample when the end is reached, so every batch has ``batch_size``
        items.

        Parameters
        ----------
        augment : when true, each image is mirrored left-right with
            probability 0.5; labels are left unchanged. Defaults to False,
            which yields the stored images untouched.
        """
        while True:
            batch_images=[]
            batch_labels=[]
            for _ in range(self.batch_size):
                if self.index>=len(self.images):
                    self.index=0
                image=self.images[self.index]
                if augment and np.random.rand()<0.5:
                    # Arrays built from PIL images are (height, width, ...),
                    # so reversing axis 1 is a horizontal flip.
                    image=image[:, ::-1]
                batch_images.append(image)
                batch_labels.append(self.labels[self.index])
                self.index+=1
            yield np.array(batch_images),np.array(batch_labels)

# Build the tag -> id lookup used when encoding labels.
# word_id.txt holds one "<id> <tag>" pair per line; ids are kept as the raw
# strings read from the file.
id_tag_path='word_id.txt'
with open(id_tag_path,'r') as f:
    words=f.readlines()

word_id={}
for line in words:
    fields=line.strip().split(' ')
    tag_id,tag_name=fields[0],fields[1]
    word_id[tag_name]=tag_id


if __name__ == "__main__":
    # Smoke test: load the first 100 samples and print the shapes of the
    # image batch and the label batch produced by one generator step.
    classes=81
    width,height=224,224
    IMAGE_SIZE=(width,height,3)
    txt_path='datasets81_clean.txt'
    train_gen=data_generator(txt_path,100,IMAGE_SIZE,classes)
    batches=train_gen.get_mini_batch()
    x,y=next(batches)
    print(x.shape,y.shape,sep='\n')
        

我们看看train.py的调用:

# train.py: train the ResNet-style model on the multi-label dataset.
import os  # used below for paths; imported explicitly instead of relying on a wildcard import

from keras.optimizers import *
from keras.callbacks import *
from keras.models import *
from DataGenerator import data_generator
from resnet50 import ResNet50
from measure import *

train_txt_path='datasets81_train.txt'
test_txt_path='datasets81_test.txt'

width,height=224,224
IMAGE_SIZE=(width,height,3)
classes=81
model_name='resnet50'

# Each generator loads up to 100 samples into memory.
train_gen=data_generator(train_txt_path,100,IMAGE_SIZE,classes)
val_gen=data_generator(test_txt_path,100,IMAGE_SIZE,classes)

model = ResNet50.resnet(IMAGE_SIZE,classes=classes)
model.summary()


# Checkpoints are written to trained_model/<model_name>/.
save_path=os.path.join('trained_model',model_name)
if(not os.path.exists(save_path)):
    os.makedirs(save_path)
tensorboard = TensorBoard(log_dir='./logs/{}'.format(model_name), batch_size=train_gen.batch_size)
# NOTE(review): the 'val_acc' key depends on the installed Keras version
# (newer releases appear to report 'val_accuracy') - confirm before running.
model_names = (os.path.join(save_path,model_name+'.{epoch:02d}-{val_acc:.4f}.hdf5'))
model_checkpoint = ModelCheckpoint(model_names,
                                    monitor='val_acc',
                                    verbose=1,
                                    save_best_only=True,
                                    save_weights_only=False)
reduce_learning_rate = ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                                         patience=5, verbose=1)
callbacks = [model_checkpoint,reduce_learning_rate,tensorboard]


# Sigmoid outputs with binary cross-entropy: every class is an independent
# yes/no decision, which is what multi-label classification needs.
model.compile(optimizer = 'adam',
           loss='binary_crossentropy',
           metrics=['accuracy',fmeasure,recall,precision])


steps=train_gen.num_of_examples//train_gen.batch_size
epochs=50
# Both generators are started with their default behaviour. The original
# call passed augment=True, a keyword the generator shown in this article
# does not define, which raises TypeError before training starts.
model.fit_generator(generator=train_gen.get_mini_batch(),steps_per_epoch=steps,
       epochs=epochs,
       callbacks=callbacks,
       validation_data=val_gen.get_mini_batch(),
       validation_steps=val_gen.num_of_examples // val_gen.batch_size,
       verbose=1)

其中precision、recall、fmeasure的代码如下(measure.py)。这段代码是从他人的分享中截取出来的,具体出处已经无从查证:

import keras.backend as K

def precision(y_true, y_pred):
    """Return precision over the batch: true positives / predicted positives.

    Both tensors are clipped to [0, 1] and rounded, so a prediction counts as
    positive once it rounds to 1. ``K.epsilon()`` in the denominator avoids a
    division by zero when nothing is predicted positive.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    flagged = K.round(K.clip(y_pred, 0, 1))
    return K.sum(hits) / (K.sum(flagged) + K.epsilon())

def recall(y_true, y_pred):
    """Return recall over the batch: true positives / actual positives.

    Both tensors are clipped to [0, 1] and rounded before counting.
    ``K.epsilon()`` in the denominator avoids a division by zero when the
    batch contains no positive label.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    relevant = K.round(K.clip(y_true, 0, 1))
    return K.sum(hits) / (K.sum(relevant) + K.epsilon())

def fbeta_score(y_true, y_pred, beta=1):
    """Return the F-beta score, the weighted harmonic mean of precision and recall.

    Parameters
    ----------
    y_true : tensor of ground-truth labels.
    y_pred : tensor of predictions.
    beta : non-negative weight; ``beta=1`` gives the F1 score.

    Raises
    ------
    ValueError
        If ``beta`` is negative.
    """
    if beta < 0:
        raise ValueError('The lowest choosable beta is zero (only precision).')

    # No Python-level "if there are no positives, return 0" branch here:
    # comparing a symbolic tensor with == and branching on the result is not
    # a reliable data-dependent test. It is also unnecessary - with no true
    # positives p * r is 0 and the epsilon keeps the denominator positive, so
    # the expression below already evaluates to 0.
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    bb = beta ** 2
    return (1 + bb) * (p * r) / (bb * p + r + K.epsilon())

def fmeasure(y_true, y_pred):
    """Return the F1 score, i.e. ``fbeta_score`` evaluated with ``beta=1``."""
    f1 = fbeta_score(y_true, y_pred, beta=1)
    return f1

我这里用到了resnet50的代码(resnet50.py),这里也分享出来:

import os

from keras.layers import (
    Conv2D, BatchNormalization,
    MaxPooling2D, ZeroPadding2D, AveragePooling2D,
    add, Dense, Flatten,Input
)
from keras.layers.advanced_activations import PReLU
from keras.models import Model, load_model
# from utils import load_mnist


class ResNet50():
    """Namespace for the ResNet-style model builder ``resnet``."""

    @staticmethod
    def resnet(input_shape,classes=100,weights="trained_model/resnet.hdf5"):
        """Build the ResNet-style Keras model and optionally load weights.

        The network is a stem (7x7 convolution + max pooling) followed by
        residual stages 2, 3 and 4, average pooling, flattening and a dense
        sigmoid layer. PReLU is used as the activation throughout.

        Parameters
        ----------
        input_shape : shape passed to ``keras.layers.Input`` (without the
            batch dimension).
        classes : number of units of the final dense layer.
        weights : path of a weights file; it is loaded only when the file
            exists, otherwise the model is returned as built.

        Returns
        ----------
        model : keras ``Model`` whose output is a ``classes``-wide dense
            layer with sigmoid activation.
        """
        # Layer names follow the pattern "<type><stage><block>_branch<name>",
        # e.g. "res2a_branch2a".
        def name_builder(type, stage, block, name):
            return "{}{}{}_branch{}".format(type, stage, block, name)

        # Residual block whose shortcut is the unmodified input, so the last
        # filter count F3 has to match the channels of input_tensor.
        def identity_block(input_tensor, kernel_size, filters, stage, block):
            F1, F2, F3 = filters

            def name_fn(type, name):
                return name_builder(type, stage, block, name)

            # 1x1 convolution.
            x = Conv2D(F1, (1, 1), name=name_fn('res', '2a'))(input_tensor)
            x = BatchNormalization(name=name_fn('bn', '2a'))(x)
            x = PReLU()(x)

            # kernel_size convolution, 'same' padding keeps the spatial size.
            x = Conv2D(F2, kernel_size, padding='same', name=name_fn('res', '2b'))(x)
            x = BatchNormalization(name=name_fn('bn', '2b'))(x)
            x = PReLU()(x)

            # 1x1 convolution; unlike conv_block, an activation is applied
            # here before the addition as well.
            x = Conv2D(F3, (1, 1), name=name_fn('res', '2c'))(x)
            x = BatchNormalization(name=name_fn('bn', '2c'))(x)
            x = PReLU()(x)

            x = add([x, input_tensor])
            x = PReLU()(x)

            return x

        # Residual block with a projection shortcut: a strided 1x1 convolution
        # plus batch normalization brings the input to F3 channels and to the
        # same spatial size as the main branch.
        def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
            def name_fn(type, name):
                return name_builder(type, stage, block, name)

            F1, F2, F3 = filters

            # Main branch; the stride is applied in the first 1x1 convolution.
            x = Conv2D(F1, (1, 1), strides=strides, name=name_fn("res", "2a"))(input_tensor)
            x = BatchNormalization(name=name_fn("bn", "2a"))(x)
            x = PReLU()(x)

            x = Conv2D(F2, kernel_size, padding='same', name=name_fn("res", "2b"))(x)
            x = BatchNormalization(name=name_fn("bn", "2b"))(x)
            x = PReLU()(x)

            x = Conv2D(F3, (1, 1), name=name_fn("res", "2c"))(x)
            x = BatchNormalization(name=name_fn("bn", "2c"))(x)

            # Shortcut branch.
            sc = Conv2D(F3, (1, 1), strides=strides, name=name_fn("res", "1"))(input_tensor)
            sc = BatchNormalization(name=name_fn("bn", "1"))(sc)

            x = add([x, sc])
            x = PReLU()(x)

            return x
        # Stem: pad, 7x7 convolution with stride 2, then 3x3 max pooling.
        input_tensor = Input(shape=input_shape)
        net = ZeroPadding2D((3, 3))(input_tensor)
        net = Conv2D(64, (7, 7), strides=(2, 2), name="conv1")(net)
        net = BatchNormalization(name="bn_conv1")(net)
        net = PReLU()(net)
        net = MaxPooling2D((3, 3), strides=(2, 2))(net)

        # Stage 2: three blocks; strides=(1, 1) keeps the spatial size.
        net = conv_block(net, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
        net = identity_block(net, 3, [64, 64, 256], stage=2, block='b')
        net = identity_block(net, 3, [64, 64, 256], stage=2, block='c')

        # Stage 3: four blocks.
        net = conv_block(net, 3, [128, 128, 512], stage=3, block='a')
        net = identity_block(net, 3, [128, 128, 512], stage=3, block='b')
        net = identity_block(net, 3, [128, 128, 512], stage=3, block='c')
        net = identity_block(net, 3, [128, 128, 512], stage=3, block='d')

        # Stage 4: six blocks. No further stage follows in this implementation.
        net = conv_block(net, 3, [256, 256, 1024], stage=4, block='a')
        net = identity_block(net, 3, [256, 256, 1024], stage=4, block='b')
        net = identity_block(net, 3, [256, 256, 1024], stage=4, block='c')
        net = identity_block(net, 3, [256, 256, 1024], stage=4, block='d')
        net = identity_block(net, 3, [256, 256, 1024], stage=4, block='e')
        net = identity_block(net, 3, [256, 256, 1024], stage=4, block='f')
        net = AveragePooling2D((2, 2))(net)

        # Classifier head: one sigmoid unit per class.
        net = Flatten()(net)
        net = Dense(classes, activation="sigmoid")(net)
        model = Model(input_tensor, net, name='model')
        # Load weights only when the file is present; otherwise just report it.
        if os.path.isfile(weights):
            model.load_weights(weights)
            print("Model loaded")
        else:
            print("No model is found")

        return model

# img_width=128
# img_height=128
# charset_size=6941
# model = ResNet50.resnet(input_shape=(img_width,img_height,3), classes=charset_size)
# model.summary()

然后就可以运行了。

猜你喜欢

转载自blog.csdn.net/w5688414/article/details/84593705