Implementação do Deeplabv3+ com o código-fonte da biblioteca tensorflow/models (quatro) - prever imagens individuais e em lote


Depois que o modelo Deeplabv3+ é treinado, ele pode ser usado para prever suas próprias imagens. O processo de treinamento do modelo pode ser visto no meu artigo anterior.

1. Preveja uma única imagem

Prever imagens locais
Primeiro salve as imagens e depois consulte a demonstração oficial do Deeplabv3+, deeplab_demo.ipynb, para escrever seu próprio programa. Não vou colocar o código aqui, basta olhar para as várias imagens de previsão.
Prever imagens online
Requer o URL da imagem; basta ver o código a seguir.

2. Preveja várias fotos

import os
from io import BytesIO
from six.moves import urllib

from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
import datetime


class DeeplabModel(object):
    """Load a frozen DeepLab inference graph and run segmentation with it.

    The graph is read from ``<model_path>/frozen_inference_graph.pb`` and
    imported into a dedicated ``tf.Graph``; a session bound to that graph is
    kept on the instance for repeated calls to :meth:`run`.

    NOTE(review): this uses the TensorFlow 1.x API (``tf.GraphDef``,
    ``tf.Session``) exactly as the original did.
    """

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    # Length, in pixels, that the longer image side is scaled to before inference.
    INPUT_SIZE = 513
    # Base name (without the ".pb" extension) of the exported graph file.
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, model_path):
        """Create the graph and session from an exported frozen model.

        Args:
            model_path: Directory containing ``frozen_inference_graph.pb``.
                A trailing path separator is accepted but not required.

        Raises:
            IOError: If the graph file cannot be opened.
        """
        self.graph = tf.Graph()

        # os.path.join works with or without a trailing separator, unlike the
        # plain string concatenation that silently required one.
        graph_file = os.path.join(model_path, self.FROZEN_GRAPH_NAME + '.pb')
        with open(graph_file, 'rb') as f:
            graph_def = tf.GraphDef.FromString(f.read())

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """Run inference on a single image.

        Args:
            image: A PIL.Image object, the raw input image.

        Returns:
            resized_image: RGB image resized from the original input so that
                its longer side equals ``INPUT_SIZE``.
            seg_map: Segmentation map, the first element of the batch
                returned by the output tensor.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        # Clamp to 1 pixel so a very elongated image never produces a
        # zero-sized target, which PIL cannot resize to.
        target_size = (max(1, int(resize_ratio * width)),
                       max(1, int(resize_ratio * height)))
        # Image.ANTIALIAS was an alias of Image.LANCZOS and was removed in
        # Pillow 10; LANCZOS selects the same filter on old and new versions.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """Build the label colormap used in the PASCAL VOC segmentation benchmark.

    Returns:
        A (256, 3) integer array; row ``i`` holds the RGB color of label ``i``.
    """
    palette = np.zeros((256, 3), dtype=int)
    remaining = np.arange(256, dtype=int)
    # Walk from the most significant output bit (7) down to the least (0).
    # Each pass takes the three lowest bits of the label index, one per
    # channel, and then drops those three bits.
    for bit in range(7, -1, -1):
        for channel in (0, 1, 2):
            channel_bit = (remaining >> channel) & 1
            palette[:, channel] |= channel_bit << bit
        remaining >>= 3
    return palette


def label_to_color_image(label):
    """Map a 2-D label array to colors using the dataset colormap.

    Args:
        label: 2-D array of label values used as indices into the colormap.

    Returns:
        The colormap rows selected by ``label`` (one color per element).

    Raises:
        ValueError: If ``label`` is not 2-D, or if its largest value is not a
            valid index into the colormap.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    palette = create_pascal_label_colormap()
    largest_label = np.max(label)
    if largest_label >= len(palette):
        raise ValueError('label value too large.')

    return palette[label]


def vis_segmentation(image, seg_map):
    """Visualize a segmentation result in a three-panel figure.

    The panels show, left to right: the input image, the colored
    segmentation map, and the map overlaid on the input image.

    Args:
        image: The original (input) image.
        seg_map: The segmentation result for ``image``.
    """
    plt.figure(figsize=(15, 5))
    layout = gridspec.GridSpec(1, 3, width_ratios=[10, 10, 10])

    # Panel 1: the input image on its own.
    plt.subplot(layout[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    # Panel 2: the label map rendered with the dataset colormap.
    plt.subplot(layout[1])
    colored_map = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(colored_map)
    plt.axis('off')
    plt.title('segmentation map')

    # Panel 3: the colored map drawn semi-transparently over the input image.
    plt.subplot(layout[2])
    plt.imshow(image)
    plt.imshow(colored_map, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    plt.show()


# Class names as a string array, one entry per class.
# NOTE(review): presumably position i names label id i of the PASCAL VOC
# model - confirm against the labels the model was trained with.
LABEL_NAMES = np.asarray((
    'background',
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tv',
))


def load_image(image_path):
    """Open the image stored at ``image_path`` on the local file system.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The object returned by ``Image.open`` for that path.
    """
    return Image.open(image_path)


def save_seg_image(seg_map, save_logdir):
    """Color a segmentation map and write it to disk as an image.

    Args:
        seg_map: 2-D label array to colorize.
        save_logdir: Destination passed to the image's ``save`` method
            (callers in this file pass a full file path, not a directory).
    """
    rgb_pixels = label_to_color_image(seg_map).astype(np.uint8)
    Image.fromarray(rgb_pixels).save(save_logdir)


# Directory that contains the exported frozen_inference_graph.pb file.
export_model_path = '/home/hy/software/models/research/deeplab/datasets/pascal_voc_seg/exp/train_on_train_set/export/'
# Shared model instance used by run_visualization; the graph is loaded as soon
# as this statement executes.
MODEL = DeeplabModel(export_model_path)


def run_visualization(image_path, saved_path, mode=False):
    """Run the DeepLab model and visualize the result.

    With ``mode=False`` every file in the ``image_path`` directory is
    segmented in sorted name order and its colored map is saved under the
    same file name in ``saved_path``. With ``mode=True`` a single image is
    downloaded from the ``image_path`` URL and segmented without saving.
    In both cases the elapsed time is printed and the last segmented image
    is shown with ``vis_segmentation``.

    Args:
        image_path: Directory holding the images, or the URL of one image
            when ``mode`` is true.
        saved_path: Existing directory where segmentation images are stored
            (used only when ``mode`` is false).
        mode: Choose where the image comes from; the default ``False`` means
            a local directory, ``True`` means an online URL.

    Returns:
        None. Problems (bad URL, missing directories, unreadable files,
        nothing to segment) are reported with ``print``.
    """
    start = datetime.datetime.now()
    # Stay None until an image has been segmented, so the final visualization
    # step can tell "nothing was processed" apart from a real result.
    resized_im = seg_map = None

    if mode:
        try:
            f = urllib.request.urlopen(image_path)
            try:
                jpeg_str = f.read()
            finally:
                # Always release the HTTP response, even if reading fails.
                f.close()
            original_im = Image.open(BytesIO(jpeg_str))
        except IOError:
            print('Cannot retrieve image.Please check url:' + image_path)
            return
        resized_im, seg_map = MODEL.run(original_im)
    else:
        if not os.path.exists(image_path) or not os.path.exists(saved_path):
            print('Error:cannot find image path!')
            return
        # os.listdir returns names in arbitrary order; sorting makes the
        # processing order, and which image is shown last, deterministic.
        image_names = sorted(os.listdir(image_path))
        for i, image_name in enumerate(image_names):
            print('predict the %dth image' % i)
            try:
                original_im = load_image(os.path.join(image_path, image_name))
            except IOError:
                # Sub-directories and non-image files should not abort the
                # whole batch; report them and continue, as the URL branch does.
                print('Cannot open image, skipped:' + image_name)
                continue
            resized_im, seg_map = MODEL.run(original_im)
            save_seg_image(seg_map, os.path.join(saved_path, image_name))

    end = datetime.datetime.now()
    print(end - start)

    if seg_map is None:
        # Empty directory, or every file was skipped: nothing to display.
        print('Error:no image was segmented, nothing to visualize.')
        return
    vis_segmentation(resized_im, seg_map)


# Example URL for the online mode: pass it as image_path together with mode=True.
# IMAGE_URL = 'https://ss3.bdstatic.com/70cFv8Sh_Q1YnxGkpoWK1HF6hhy/it/u=3731733193,393708434&fm=26&gp=0.jpg'
# Batch mode (mode defaults to False): segment the files found in img_path and
# write the colored maps into save_path. Both directories must already exist,
# otherwise run_visualization prints an error and returns.
img_path = '/home/hy/template/pictures'
save_path = '/home/hy/template/seg_map'
run_visualization(img_path, save_path)

O código acima, porém, apenas lê as imagens em loop, processando uma de cada vez. Se houver uma forma mais eficiente, deixe-me um recado!

3. Imagem do efeito

Insira a descrição da imagem aquiInsira a descrição da imagem aqui

Acho que você gosta

Origin blog.csdn.net/qq_43265072/article/details/105770287
Recomendado
Clasificación