[机器学习与深度学习] - No.6 ImageNet数据集预处理方式

在之前工作中,遇到了一个问题,在Google和Github的帮助下解决了,总结一下防止以后再次遇到。

问题描述: 当我们使用Keras自带的VGG16,VGG19等模型在ImageNet上做图像识别的时候 ,Top-1和Top-5的精度都会偏低一些,相对于Keras公布出来的精度。例如,我自己测的VGG16的精度只有64%(Top-1),但是Keras上公布的精度为71.3%。后来在Google上调研了很久,发现是数据预处理的问题。

我使用的加载图片的方式(部分代码):

# loaded image path and label previously
image_list = []
from keras.preprocessing import image
for img_name in images_names:
    tmp_img = image.load_img(img_name, target_size=(224, 224))
    tmp_img = image.img_to_array(tmp_img)
    image_list.append(tmp_img)
image_arrays = np.array(image_list)
print(image_arrays.shape)

# prediction
x_validation=keras.applications.vgg16.preprocess_input(image_arrays)
vgg16_scores1 = vgg16_model.evaluate(x_validation,y_val_oc)
print(vgg16_scores1)

50000/50000 [==============================] - 158s 3ms/step
[1.5069317724227904, 0.64274]

这也是Keras官网中给的示例中的方法。可以看出,用这种方式读取图片,模型的精度只有64%。

后来我在Github 的这篇issue中找到了答案:如果想要获得Keras官网中的精度,需要将ImageNet数据集按照如下方式进行预处理

  • Resize the shorter side of each image to 256 (将图像的短边长度调整为256)
  • Resize the longer side to maintain the aspect ratio (调整长边尺寸以保持长宽比)
  • Central 224x224 crop (从图片中央截取 224x224的图片)

此种方式的数据预处理代码如下所示:

X_val = np.zeros((len(images_names), 224, 224, 3), dtype=np.float32)
# get_images
for i in range(len(images_names)):
    if i %2000 == 0:
        print("%d/%d" % (i, len(images_names)))
    
    # Load (as BGR)
    img = cv2.imread(images_names[i])
    
    # Resize
    height, width, _ = img.shape
    new_height = height * 256 // min(img.shape[:2])
    new_width = width * 256 // min(img.shape[:2])
    img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
    
    # Crop
    height, width, _ = img.shape
    startx = width//2 - (224//2)
    starty = height//2 - (224//2)
    img = img[starty:starty+224,startx:startx+224]
    assert img.shape[0] == 224 and img.shape[1] == 224, (img.shape, height, width)
    
    # Save (as RGB)
    X_val[i,:,:,:] = img[:,:,::-1]
print(X_val[:10])

使用这种方式的模型精度如下所示:

50000/50000 [==============================] - 170s 3ms/step
[1.1673228887557983, 0.71268]

上面提到的issue中的完整的代码链接在这里

发布了118 篇原创文章 · 获赞 140 · 访问量 25万+

猜你喜欢

转载自blog.csdn.net/tjuyanming/article/details/105298043
今日推荐