DeepLab v3+ demo: testing image segmentation


# -*- coding: utf-8 -*-

# DeepLab Demo
# NOTE: this demo uses the TensorFlow 1.x API (tf.GraphDef / tf.Session).

import os
import tarfile
import tempfile

from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from six.moves import urllib

import tensorflow as tf


class DeepLabModel(object):
    """Loads a frozen DeepLab model and runs inference on images."""

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """Loads the pretrained model from a tarball."""
        self.graph = tf.Graph()

        graph_def = None
        # Extract the frozen inference graph from the tar archive.
        tar_file = tarfile.open(tarball_path)
        for tar_info in tar_file.getmembers():
            if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                file_handle = tar_file.extractfile(tar_info)
                graph_def = tf.GraphDef.FromString(file_handle.read())
                break

        tar_file.close()

        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """Runs inference on a single image.

        Args:
            image: A PIL.Image object (the original image), not a file path.

        Returns:
            resized_image: RGB image resized from the original input image.
            seg_map: Segmentation map of `resized_image`.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # Image.ANTIALIAS was removed in Pillow 10; use Image.LANCZOS there.
        resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """Creates the label colormap used in the PASCAL VOC segmentation benchmark.

    Returns:
        A colormap for visualizing segmentation results.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)

    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3

    return colormap


def label_to_color_image(label):
    """Adds the color defined by the dataset colormap to the label.

    Args:
        label: A 2D array with integer type, storing the segmentation label.

    Returns:
        result: An array of shape (height, width, 3) in which each element of
        `label` is replaced by its RGB color in the PASCAL colormap.

    Raises:
        ValueError: If label is not of rank 2, or its values exceed the
        colormap's maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    colormap = create_pascal_label_colormap()

    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')

    return colormap[label]


def vis_segmentation(image, seg_map, imagefile):
    """Visualizes the input image, the segmentation map, and the overlay."""
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    # Save the figure as rendered so far (input image + segmentation map).
    plt.savefig('./' + imagefile + '.png')

    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    plt.grid(False)
    plt.show()


## PASCAL VOC class labels and color map.
LABEL_NAMES = np.asarray(['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
                          'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
                          'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)


## Pretrained model checkpoints provided by TensorFlow.
MODEL_NAME = 'xception_coco_voctrainval'
# ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']

_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {'mobilenetv2_coco_voctrainaug': 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
               'mobilenetv2_coco_voctrainval': 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
               'xception_coco_voctrainaug': 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
               'xception_coco_voctrainval': 'deeplabv3_pascal_trainval_2018_01_04.tar.gz', }

_TARBALL_NAME = 'deeplab_model.tar.gz'

# model_dir = tempfile.mkdtemp()
model_dir = './'
# tf.gfile.MakeDirs(model_dir)

download_path = os.path.join(model_dir, _TARBALL_NAME)
# Uncomment the following lines to download MODEL_NAME from the TensorFlow model zoo:
# print('downloading model, this might take a while...')
# urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)

print('loading DeepLab model...')
MODEL = DeepLabModel(download_path)
print('model loaded successfully!')


##
def run_visualization(imagefile):
    """Runs DeepLab semantic segmentation on an image file and visualizes the result."""
    print('running deeplab on image %s...' % imagefile)
    resized_im, seg_map = MODEL.run(Image.open(imagefile))
    vis_segmentation(resized_im, seg_map, imagefile)


# Run the demo on every image in the directory.
images_dir = './pictures'
images = sorted(os.listdir(images_dir))
print(images)

for imgfile in images:
    run_visualization(os.path.join(images_dir, imgfile))

print('Done.')

The script uses the local deeplab_model.tar.gz; you can also modify the code to use one of the checkpoints pretrained on the standard datasets (see the MODEL_NAME / _MODEL_URLS / urllib.request.urlretrieve section of the code).
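For example, to fetch one of those checkpoints instead of relying on the local tarball, it should be enough to enable the download lines that are already in the script, roughly as follows (a minimal sketch reusing the names defined there; the MODEL_NAME value is just one of the keys of _MODEL_URLS):

MODEL_NAME = 'xception_coco_voctrainval'   # or any other key of _MODEL_URLS

download_path = os.path.join(model_dir, _TARBALL_NAME)
print('downloading model, this might take a while...')
urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME],
                           download_path)
print('download completed! loading DeepLab model...')

MODEL = DeepLabModel(download_path)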

1. Modify the model save path.

2. Modify the image path.

3. Run the script (steps 1 and 2 are shown concretely in the sketch below).
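Concretely, steps 1 and 2 only require editing two variables in the script, and step 3 is just running it with Python (the file name deeplab_demo.py is a placeholder for whatever you named the script):

model_dir = './'            # step 1: directory holding deeplab_model.tar.gz
images_dir = './pictures'   # step 2: directory containing the test images

# step 3: run the demo
#   python deeplab_demo.py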

Reference: https://www.aiuai.cn/aifarm252.html

Reposted from www.cnblogs.com/ywheunji/p/10541818.html