AI in Practice: YOLK: Keras Object Detection API

YOLK

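YOLK, short for You Only Look Keras, is a one-stop object detection API for Keras. With just a few lines of code you can set up one of the best-performing detection models, apply it to your own dataset, and train your own object detector.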

GitHub repository

https://github.com/KerasKorea/KerasObjectDetector

Installation (Linux)

  # Download the YOLK API
  $ git clone https://github.com/KerasKorea/KerasObjectDetector.git
  $ cd KerasObjectDetector

  # If setuptools is missing (e.g. inside a Docker container), install it first:
  # $ pip install setuptools

  # Install the required system libraries (prefix with sudo if you are not root)
  $ apt-get install libatlas-base-dev libxml2-dev libxslt-dev python-tk

  # Build and install the package (run from the ./KerasObjectDetector directory)
  $ python setup.py install
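After installation, a quick way to confirm that the Python dependencies used in the YOLOv3 examples are importable is a small check like the one below. It is a sketch, not part of YOLK: the module list (tensorflow, keras, PIL, keras_yolov3) is an assumption taken from the example imports, and `missing_modules` is a hypothetical helper name.

```python
import importlib.util


def missing_modules(names):
    """Return the module names that cannot be found by this Python interpreter."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed dependency list, based on the imports in the YOLOv3 example scripts.
    # 'keras_yolov3' is only found if the install (or your working directory) exposes it.
    required = ["tensorflow", "keras", "PIL", "keras_yolov3"]
    missing = missing_modules(required)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```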

Using Docker

  # Pull the YOLK Docker image
  $ docker pull kerasyolk/yolk

  # Run YOLK (maps Jupyter's port 8888 and container SSH port 22 to host port 8022)
  $ docker run --name=yolk -Pit -p 8888:8888 -p 8022:22 kerasyolk/yolk:latest

  # Start Jupyter Notebook (inside the container)
  $ jupyter-notebook

Supported detection models

  • RetinaNet
  • SSD
  • YOLOv3

Usage example: YOLOv3

  • YOLOv3_Prediction.py
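import tensorflow as tf
from keras import backend as K

from PIL import Image

from keras_yolov3.train import get_anchors, get_classes
from keras_yolov3.yolov3_class import YOLOv3
from keras_yolov3.utils_img import save_img

# Let TensorFlow allocate GPU memory on demand instead of all at once (TF 1.x API)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# Register the session with Keras; otherwise Keras creates its own session and ignores this config
K.set_session(sess)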

# COCO class names and YOLOv3 anchor boxes
classes_path = '../keras_yolov3/model_data/coco_classes.txt'
anchors_path = '../keras_yolov3/model_data/yolo_anchors.txt'
class_names = get_classes(classes_path)
num_classes = len(class_names)
anchors = get_anchors(anchors_path)

# Build the YOLOv3 model
yolov3 = YOLOv3(anchors, num_classes)

# Load the pretrained weights
model_path = '../keras_yolov3/model_data/yolo_weights.h5'
yolov3.model.load_weights(model_path)

# Image to run detection on
img_path = "000000008021.jpg"
image = Image.open(img_path)

# predict_detection takes a list of images; results are indexed per image
out_boxes, out_scores, out_classes = yolov3.predict_detection([image])

# Draw the detections of the first (and only) image and save the result
save_img("result.jpg", image, class_names, out_boxes[0], out_scores[0], out_classes[0])
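To work with the raw results instead of only saving an annotated image, the per-image outputs can be filtered and paired with class names. This is a minimal sketch assuming that `out_boxes[0]`, `out_scores[0]` and `out_classes[0]` are parallel sequences for one image, as the `save_img` call suggests; `summarize_detections` and the 0.5 threshold are hypothetical, and box coordinates are passed through unchanged because their order is not documented here.

```python
def summarize_detections(boxes, scores, classes, class_names, min_score=0.5):
    """Pair up one image's detections and keep those scoring at least min_score.

    Assumes boxes, scores and classes are parallel sequences for a single image.
    Box coordinates are returned as-is (their order is not assumed).
    """
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score >= min_score:
            results.append((class_names[int(cls)], float(score), tuple(box)))
    # Highest-confidence detections first
    results.sort(key=lambda item: item[1], reverse=True)
    return results


if __name__ == "__main__":
    # Dummy values for illustration only, not real model output
    names = ["person", "bicycle", "car"]
    demo = summarize_detections(
        boxes=[[10, 20, 110, 220], [5, 5, 50, 50]],
        scores=[0.91, 0.12],
        classes=[0, 2],
        class_names=names,
    )
    print(demo)  # → [('person', 0.91, (10, 20, 110, 220))]
```

With the prediction script above, it would be called as `summarize_detections(out_boxes[0], out_scores[0], out_classes[0], class_names)`.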

  • YOLOv3_Training.py
from keras import backend as K
from keras.optimizers import Adam

from keras_yolov3.train import get_anchors, get_classes, data_generator_wrapper
from keras_yolov3.yolov3_class import YOLOv3

import tensorflow as tf

# Allocate GPU memory on demand (TF 1.x API) and make Keras use this session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

# Class names and anchor boxes
classes_path = '../keras_yolov3/model_data/coco_classes.txt'
anchors_path = '../keras_yolov3/model_data/yolo_anchors.txt'
class_names = get_classes(classes_path)
num_classes = len(class_names)
anchors = get_anchors(anchors_path)
num_anchors = len(anchors)

yolov3 = YOLOv3(anchors, num_classes)

# Initialize from pretrained weights; by_name/skip_mismatch skip layers whose shapes differ
# (e.g. the output layers when your dataset has a different number of classes)
model_path = '../keras_yolov3/model_data/yolo_weights.h5'
yolov3.model.load_weights(model_path, by_name=True, skip_mismatch=True)

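# Annotation file: one training image per line
annotation_path = '../keras_yolov3/model_data/train.txt'

with open(annotation_path) as f:
    lines = f.readlines()

num_train = len(lines)
batch_size = 32

# The model outputs its own loss from the 'yolo_loss' layer, so the Keras loss simply returns y_pred
yolov3.model.compile(optimizer=Adam(lr=1e-3), loss={'yolo_loss': lambda y_true, y_pred: y_pred})

yolov3.model.fit_generator(
    data_generator_wrapper(lines, batch_size, yolov3.input_shape, anchors, num_classes),
    steps_per_epoch=max(1, num_train // batch_size),
    epochs=50,
    initial_epoch=0)

# Save the trained weights (the file name is only an example)
yolov3.model.save_weights('trained_weights.h5')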
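The training script reads plain-text files from model_data. The sketch below shows how such files are typically laid out, assuming YOLK's keras_yolov3 module follows the conventions of the keras-yolo3 project, whose helper names (get_classes, get_anchors, data_generator_wrapper) it shares: one class name per line, anchors as comma-separated width,height pairs, and annotation lines of the form `path x_min,y_min,x_max,y_max,class_id ...`. All function names here are hypothetical, and the train/validation split is an addition that the original script does not perform.

```python
import random


def parse_classes(text):
    """One class name per line (assumed format of coco_classes.txt)."""
    return [line.strip() for line in text.splitlines() if line.strip()]


def parse_anchors(text):
    """Comma-separated numbers read as (width, height) pairs (assumed format of yolo_anchors.txt)."""
    values = [float(v) for v in text.replace("\n", ",").split(",") if v.strip()]
    if len(values) % 2:
        raise ValueError("anchor file must contain an even number of values")
    return [(values[i], values[i + 1]) for i in range(0, len(values), 2)]


def parse_annotation_line(line):
    """'path x_min,y_min,x_max,y_max,class_id ...' -> (path, list of 5-tuples).

    Assumed format of train.txt; image paths containing spaces are not supported.
    """
    parts = line.split()
    path = parts[0]
    boxes = [tuple(int(float(v)) for v in box.split(",")) for box in parts[1:]]
    return path, boxes


def split_lines(lines, val_fraction=0.1, seed=0):
    """Shuffle annotation lines reproducibly and hold out a validation subset."""
    lines = list(lines)
    random.Random(seed).shuffle(lines)
    num_val = int(len(lines) * val_fraction)
    return lines[num_val:], lines[:num_val]


if __name__ == "__main__":
    print(parse_classes("person\nbicycle\ncar\n"))
    print(parse_anchors("10,13, 16,30, 33,23"))
    print(parse_annotation_line("images/001.jpg 10,20,110,220,0 5,5,50,50,2"))
```

A validation subset produced this way could be passed to `fit_generator` through its `validation_data` and `validation_steps` arguments.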


Reposted from blog.csdn.net/zengNLP/article/details/104720088