TensorFlow Introductory Tutorial (47): Human Pose Detection (4)

#
# Author: Wei Fang
# Blog: https://blog.csdn.net/rookie_wei
# WeChat: 1007895847
# When adding me on WeChat, please mention you came from CSDN
# Everyone is welcome to learn together
#

------ Wei Fang, 2019-06-27

1. Overview

In the previous article we got the training code running; in this one we start analyzing the source code for real. Read the code side by side with the paper, or it will be hard to follow.

2. Computing the Network's Input and Output Sizes

Open train.py and start reading from the main function:

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training codes for Openpose using Tensorflow')
    parser.add_argument('--model', default='cmu', help='model name')
    parser.add_argument('--datapath', type=str, default='G:/tensorflow/post-estimation/datasets/annotations')
    parser.add_argument('--imgpath', type=str, default='G:/tensorflow/post-estimation/datasets/')
    parser.add_argument('--batchsize', type=int, default=16)
    parser.add_argument('--gpus', type=int, default=1)
    parser.add_argument('--max-epoch', type=int, default=600)
    parser.add_argument('--lr', type=str, default='0.001')
    parser.add_argument('--tag', type=str, default='test')
    parser.add_argument('--checkpoint', type=str, default='')

    parser.add_argument('--input-width', type=int, default=432)
    parser.add_argument('--input-height', type=int, default=368)
    parser.add_argument('--quant-delay', type=int, default=-1)
    args = parser.parse_args()

    modelpath = logpath = './models/train/'

    if args.gpus <= 0:
        raise Exception('gpus <= 0')

First come the default argument settings, which were covered in the previous article, so we won't repeat them. Read on.

# define input placeholder
# set the image width and height; used later for data augmentation
set_network_input_wh(args.input_width, args.input_height)
scale = 4

if args.model in ['cmu', 'vgg'] or 'mobilenet' in args.model:
    # CMU uses VGG19 layers that apply 3 max_pool ops with stride 2,
    # shrinking the feature map by a factor of 8
    scale = 8

# set the scale; used later for data augmentation
set_network_scale(scale)
output_w, output_h = args.input_width // scale, args.input_height // scale

The set_network_input_wh and set_network_scale functions pass the input width/height and the scale from our arguments into the pose_augment module, as covered in the previous article. If the model is the cmu network, scale = 8: cmu uses the first 10 layers of VGG19, which contain 3 max-pool operations with stride 2, and each max-pool halves the feature map, so together they shrink it by a factor of 8. The output width and height are therefore:

output_w, output_h = args.input_width // scale, args.input_height // scale
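As a quick sanity check, here is the arithmetic with the default arguments (a throwaway sketch of my own, not repo code):

input_width, input_height, scale = 432, 368, 8
output_w, output_h = input_width // scale, input_height // scale
print(output_w, output_h)   # 54 46 -- matches the 54x46 maps in the placeholders below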

3. Defining the Placeholders

Read on:

logger.info('define model+')
with tf.device(tf.DeviceSpec(device_type="CPU")):
    # define the placeholders
    # input image, shape=(16, 368, 432, 3)
    input_node = tf.placeholder(tf.float32, shape=(args.batchsize, args.input_height, args.input_width, 3), name='image')
    # vector map, shape=(16, 46, 54, 38)
    vectmap_node = tf.placeholder(tf.float32, shape=(args.batchsize, output_h, output_w, 38), name='vectmap')
    # heatmap, shape=(16, 46, 54, 19)
    heatmap_node = tf.placeholder(tf.float32, shape=(args.batchsize, output_h, output_w, 19), name='heatmap')

The code above defines three placeholders: the input image input_node, the vector map (the paper's limb vector field; my loose translation) vectmap_node, and the heatmap (the paper's keypoint confidence map) heatmap_node. The depths of 38 and 19 follow the paper: the heatmap carries one channel per keypoint plus a background channel (18 + 1 = 19), and the vector map carries an x and a y component per limb (19 limbs × 2 = 38). The input image has depth 3.

4. Data Augmentation

Continuing:

# prepare data
# initialize the data pipeline
# args.datapath: the annotations directory
# batchsize: batch size
# imgpath: the dataset directory
# parses person_keypoints_train2017.json and feeds the training data into a queue
df = get_dataflow_batch(args.datapath, True, args.batchsize, img_path=args.imgpath)
Let's look at the get_dataflow_batch function:
def get_dataflow_batch(path, is_train, batchsize, img_path=None):
    logger.info('dataflow img_path=%s' % img_path)
    ds = get_dataflow(path, is_train, img_path=img_path)
    ds = BatchData(ds, batchsize)
    # if is_train:
    #     ds = PrefetchData(ds, 10, 2)
    # else:
    #     ds = PrefetchData(ds, 50, 2)

    return ds
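BatchData simply groups batchsize consecutive datapoints and stacks each component into a single array. A minimal tensorpack sketch of that behavior (toy data of my own; assumes a tensorpack version where dataflows are iterable):

import numpy as np
from tensorpack.dataflow import DataFromList, BatchData

ds = DataFromList([[np.zeros((2, 3))] for _ in range(8)], shuffle=False)
ds = BatchData(ds, 4)            # each datapoint now holds 4 stacked samples
ds.reset_state()
print(next(iter(ds))[0].shape)   # (4, 2, 3)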

First, the get_dataflow function:

def get_dataflow(path, is_train, img_path=None):
    # the JSON to read is person_keypoints_train2017.json; the images live in train2017
    ds = CocoPose(path, img_path, is_train)       # read data from lmdb
    if is_train:
        ds = MapData(ds, read_image_url)
        # data augmentation
        ds = MapDataComponent(ds, pose_random_scale)
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        # augs = [
        #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        #         imgaug.GaussianBlur(max_size=3)
        #     ]), 0.7)
        # ]
        # ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 1)
    else:
        ds = MultiThreadMapData(ds, nr_thread=16, map_func=read_image_url, buffer_size=1000)
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)
        # second argument: size of the queue to hold prefetched datapoints
        ds = PrefetchData(ds, 100, multiprocessing.cpu_count() // 4)

    return ds
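Two tensorpack wrappers do all the work here, and they're worth telling apart: MapData applies a function to the whole datapoint (the full list), while MapDataComponent applies it to a single component, component 0 by default. A minimal sketch of the difference (toy data of my own):

from tensorpack.dataflow import DataFromList, MapData, MapDataComponent

ds = DataFromList([[1, 'a'], [2, 'b']], shuffle=False)
ds = MapDataComponent(ds, lambda x: x * 10)           # touches component 0 only
ds = MapData(ds, lambda dp: [dp[0], dp[1].upper()])   # sees the whole datapoint
ds.reset_state()
for dp in ds:
    print(dp)   # [10, 'A'] then [20, 'B']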

Let's start with CocoPose's __init__ function:

def __init__(self, path, img_path=None, is_train=True, decode_img=True, only_idx=-1):
    self.is_train = is_train
    self.decode_img = decode_img
    self.only_idx = only_idx

    # for training, read the train2017 data; otherwise read val2017
    if is_train:
        whole_path = os.path.join(path, 'person_keypoints_train2017.json')
    else:
        whole_path = os.path.join(path, 'person_keypoints_val2017.json')
    # full image path
    self.img_path = (img_path if img_path is not None else '') + ('train2017/' if is_train else 'val2017/')
    # COCO API
    self.coco = COCO(whole_path)

    logger.info('%s dataset %d' % (path, self.size()))

In the code above, training parses person_keypoints_train2017.json, while evaluation parses person_keypoints_val2017.json; then a COCO instance is created. I haven't studied the COCO class: it lives in the pycocotools package we installed, and you can treat it simply as a tool for parsing the COCO dataset. It isn't the focus of this analysis.
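If you want to poke at the COCO class yourself, the basic usage looks roughly like this (a sketch; the annotation path is an example, adjust it to your setup):

from pycocotools.coco import COCO

coco = COCO('annotations/person_keypoints_val2017.json')
img_ids = coco.getImgIds()                               # all image ids in the set
anns = coco.loadAnns(coco.getAnnIds(imgIds=img_ids[0]))
# each ann dict carries 'keypoints', 'num_keypoints', 'image_id', etc.
print(len(anns))   # number of person annotations on the first image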

Back in get_dataflow: space is limited, so let's focus on the essentials. read_image_url reads the image file each parsed record points to, and pose_random_scale, pose_rotation, pose_flip, pose_resize_shortestedge_random and pose_crop_random all perform data augmentation: random scaling, random rotation, random flipping, random cropping, and so on. The repo implements these itself rather than using stock augmentation functions because transforming the image also moves the labeled keypoint coordinates, so the keypoints have to be recomputed after every augmentation (a quick illustration below). Augmentation itself isn't our focus here, though.
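For instance, a horizontal flip must mirror every keypoint's x coordinate (and swap left/right joint labels, which I omit here). A minimal illustration of the coordinate part (my own sketch, not the repo's pose_flip):

def flip_keypoints(keypoints, img_width):
    # keypoints: list of (x, y); (-1000, -1000) marks a missing point
    return [(img_width - 1 - x, y) if x >= 0 else (x, y) for x, y in keypoints]

print(flip_keypoints([(10, 20), (-1000, -1000)], 432))   # [(421, 20), (-1000, -1000)]

Back to the pipeline: the function that really matters to us is pose_to_img,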

def pose_to_img(meta_l):
    global _network_w, _network_h, _scale
    return [
        # the input image
        meta_l[0].img.astype(np.float16),
        # generate the heatmap
        meta_l[0].get_heatmap(target_size=(_network_w // _scale, _network_h // _scale)),
        # generate the vector map
        meta_l[0].get_vectormap(target_size=(_network_w // _scale, _network_h // _scale))
    ]

First, get_heatmap. Its target_size is the input width and height divided by _scale; for the cmu network from the paper, _scale = 8.

# generate the heatmap
@jit
def get_heatmap(self, target_size):
    heatmap = np.zeros((CocoMetadata.__coco_parts, self.height, self.width), dtype=np.float32)
    # iterate over every person in the image
    for joints in self.joint_list:
        # iterate over one person's keypoints
        for idx, point in enumerate(joints):
            # a negative coordinate means this keypoint is absent
            if point[0] < 0 or point[1] < 0:
                continue

            CocoMetadata.put_heatmap(heatmap, idx, point, self.sigma)

    heatmap = heatmap.transpose((1, 2, 0))

    # background channel: 1 minus the strongest keypoint response at each pixel
    heatmap[:, :, -1] = np.clip(1 - np.amax(heatmap, axis=2), 0.0, 1.0)
    # resize the heatmap down to the network's output size
    if target_size:
        heatmap = cv2.resize(heatmap, target_size, interpolation=cv2.INTER_AREA)

    return heatmap.astype(np.float16)
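One detail worth pausing on is the background channel: wherever no keypoint responds strongly, the background should be close to 1. A standalone numpy sketch of that one line (toy data of my own):

import numpy as np

hm = np.zeros((46, 54, 19), dtype=np.float32)
hm[23, 27, 0] = 0.9     # pretend one keypoint fires strongly at one pixel
hm[:, :, -1] = np.clip(1 - np.amax(hm[:, :, :-1], axis=2), 0.0, 1.0)
print(hm[23, 27, -1])   # 0.1 -- background suppressed where a keypoint is
print(hm[0, 0, -1])     # 1.0 -- pure background everywhere else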

In the code above, self.joint_list holds the keypoint coordinates of everyone in the image; there may be several people. The function that fills self.joint_list is CocoMetadata's __init__:

# idx: index
# img_url: image url
# img_meta: info about the current image: width, height, url, id, etc.
# anns: the image's annotations: segmentation, number of keypoints, keypoints, image_id, etc.
# sigma=8.0
def __init__(self, idx, img_url, img_meta, annotations, sigma):
    self.idx = idx
    self.img_url = img_url
    self.img = None
    self.sigma = sigma

    self.height = int(img_meta['height'])
    self.width = int(img_meta['width'])

    # collect the joint coordinates parsed from 'keypoints';
    # one image may contain several people
    joint_list = []
    for ann in annotations:
        # skip annotations without keypoints
        if ann.get('num_keypoints', 0) == 0:
            continue

        # the keypoints come as flattened (x, y, v) triplets
        kp = np.array(ann['keypoints'])
        # x coordinates: every 3rd value starting at index 0; 17 in total
        xs = kp[0::3]
        # y coordinates: every 3rd value starting at index 1
        ys = kp[1::3]
        # visibility flags: every 3rd value starting at index 2
        # v has 3 states: 0 = not labeled, 1 = labeled but occluded, 2 = labeled and visible
        vs = kp[2::3]
        # v >= 1 means (x, y) is a valid keypoint; missing keypoints become (-1000, -1000)
        joint_list.append([(x, y) if v >= 1 else (-1000, -1000) for x, y, v in zip(xs, ys, vs)])

    # the final joint coordinates, possibly for several people
    self.joint_list = []

    # the pair (6, 7) synthesizes the neck point as the midpoint of the left
    # shoulder (COCO keypoint 6) and right shoulder (COCO keypoint 7);
    # the dataset has no annotation for the neck.
    # COCO's keypoint order doesn't match the order our code expects,
    # so this table is effectively a remapping
    transform = list(zip(
        [1, 6, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4],
        [1, 7, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4]
    ))
    # regenerate the keypoints in the order and layout we need
    for prev_joint in joint_list:
        new_joint = []
        for idx1, idx2 in transform:
            j1 = prev_joint[idx1-1]
            j2 = prev_joint[idx2-1]
            if j1[0] <= 0 or j1[1] <= 0 or j2[0] <= 0 or j2[1] <= 0:
                new_joint.append((-1000, -1000))
            else:
                new_joint.append(((j1[0] + j2[0]) / 2, (j1[1] + j2[1]) / 2))

        # the 19th point; the dataset has no annotation for it either
        new_joint.append((-1000, -1000))
        self.joint_list.append(new_joint)
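Concretely, pairs like (1, 1) keep a COCO keypoint as-is, while the one pair with different indices, (6, 7), synthesizes the neck as the midpoint of the two shoulders. A tiny worked example with made-up coordinates:

left_shoulder = (100, 50)    # COCO keypoint 6
right_shoulder = (140, 54)   # COCO keypoint 7
neck = ((left_shoulder[0] + right_shoulder[0]) / 2,
        (left_shoulder[1] + right_shoulder[1]) / 2)
print(neck)   # (120.0, 52.0) -- the synthesized 2nd point of the new ordering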

Back in get_heatmap, the next question is what put_heatmap does:

@staticmethod
@jit(nopython=True)
def put_heatmap(heatmap, plane_idx, center, sigma):
    # keypoint coordinates
    center_x, center_y = center
    # height and width of the confidence map
    _, height, width = heatmap.shape[:3]

    # th = 4.6052 ≈ ln(100): the Gaussian is cut off where it would drop below 0.01
    th = 4.6052
    # the corresponding cutoff radius, in units of sigma
    delta = math.sqrt(th * 2)

    # a window around (center_x, center_y): (x0, y0) is the top-left corner,
    # (x1, y1) the bottom-right corner, clipped to the image
    x0 = int(max(0, center_x - delta * sigma))
    y0 = int(max(0, center_y - delta * sigma))

    x1 = int(min(width, center_x + delta * sigma))
    y1 = int(min(height, center_y + delta * sigma))

    for y in range(y0, y1):
        for x in range(x0, x1):
            # Gaussian kernel centered on the keypoint
            d = (x - center_x) ** 2 + (y - center_y) ** 2
            exp = d / 2.0 / sigma / sigma
            if exp > th:
                continue
            # take the max so overlapping people's Gaussians don't sum past 1
            heatmap[plane_idx][y][x] = max(heatmap[plane_idx][y][x], math.exp(-exp))
            heatmap[plane_idx][y][x] = min(heatmap[plane_idx][y][x], 1.0)

This function implements point 5 of my paper-analysis post; if it's unclear, see the link below:

https://blog.csdn.net/rookie_wei/article/details/90705880
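To make the constants concrete, a quick check (my own sketch, using the default sigma of 8.0):

import math

th = 4.6052                          # ≈ math.log(100)
sigma = 8.0
radius = math.sqrt(th * 2) * sigma   # half-width of the window, in pixels

print(math.exp(-th))   # ≈ 0.01: the Gaussian value at the cutoff
print(radius)          # ≈ 24.28: beyond this distance the heatmap stays 0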

And that's how the heatmaps are generated from the dataset. Next, the vector maps: back in pose_to_img, let's see what get_vectormap does:

@jit
def get_vectormap(self, target_size):
    vectormap = np.zeros((CocoMetadata.__coco_parts*2, self.height, self.width), dtype=np.float32)
    countmap = np.zeros((CocoMetadata.__coco_parts, self.height, self.width), dtype=np.int16)

    # iterate over every person in the image
    for joints in self.joint_list:
        for plane_idx, (j_idx1, j_idx2) in enumerate(CocoMetadata.__coco_vecs):
            # __coco_vecs is 1-based; subtract 1 to index our joint list
            j_idx1 -= 1
            j_idx2 -= 1

            # start joint of this limb
            center_from = joints[j_idx1]
            # end joint of this limb
            center_to = joints[j_idx2]

            # skip limbs with a missing endpoint (marked (-1000, -1000))
            if center_from[0] < -100 or center_from[1] < -100 or center_to[0] < -100 or center_to[1] < -100:
                continue

            CocoMetadata.put_vectormap(vectormap, countmap, plane_idx, center_from, center_to)

    vectormap = vectormap.transpose((1, 2, 0))
    nonzeros = np.nonzero(countmap)
    for p, y, x in zip(nonzeros[0], nonzeros[1], nonzeros[2]):
        if countmap[p][y][x] <= 0:
            continue
        # divide by the number of people whose limbs cover this pixel,
        # so overlapping vectors don't simply pile up
        vectormap[y][x][p*2+0] /= countmap[p][y][x]
        vectormap[y][x][p*2+1] /= countmap[p][y][x]

    if target_size:
        vectormap = cv2.resize(vectormap, target_size, interpolation=cv2.INTER_AREA)

    return vectormap.astype(np.float16)
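The countmap bookkeeping exists so that a pixel covered by several people's limbs ends up with an averaged vector rather than one that stacks up, which is what the paper specifies. (Strictly speaking, put_vectormap below writes with '=' rather than accumulating with '+=', so the division only approximates that average.) The intended math, as a toy example with my own numbers:

import numpy as np

vec_a = np.array([1.0, 0.0])     # unit vector of person A's limb at some pixel
vec_b = np.array([0.0, 1.0])     # unit vector of person B's limb at the same pixel
count = 2                        # countmap value at that pixel
print((vec_a + vec_b) / count)   # [0.5 0.5] -- the averaged PAF vector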

It likewise walks every person's keypoints in self.joint_list; the interesting part is put_vectormap:

@staticmethod
@jit(nopython=True)
def put_vectormap(vectormap, countmap, plane_idx, center_from, center_to, threshold=8):
    # height and width of the vector map
    _, height, width = vectormap.shape[:3]

    # (vec_x, vec_y) is this limb's vector, pointing from center_from to center_to
    vec_x = center_to[0] - center_from[0]
    vec_y = center_to[1] - center_from[1]

    # top-left corner of the limb's bounding box, expanded by `threshold` pixels
    min_x = max(0, int(min(center_from[0], center_to[0]) - threshold))
    min_y = max(0, int(min(center_from[1], center_to[1]) - threshold))

    # bottom-right corner, likewise expanded and clipped to the image
    max_x = min(width, int(max(center_from[0], center_to[0]) + threshold))
    max_y = min(height, int(max(center_from[1], center_to[1]) + threshold))

    # length of the limb vector
    norm = math.sqrt(vec_x ** 2 + vec_y ** 2)

    # a zero-length limb contributes nothing
    if norm == 0:
        return

    # normalize to a unit vector: x component...
    vec_x /= norm
    # ...and y component
    vec_y /= norm

    for y in range(min_y, max_y):
        for x in range(min_x, max_x):
            # (bec_x, bec_y) is the vector from center_from to the pixel (x, y)
            bec_x = x - center_from[0]
            bec_y = y - center_from[1]

            # |bec x vec| = perpendicular distance from pixel (x, y) to the limb's line
            dist = abs(bec_x * vec_y - bec_y * vec_x)

            # pixels farther than `threshold` from the limb stay zero
            if dist > threshold:
                continue

            countmap[plane_idx][y][x] += 1

            # store the x component of the limb's unit vector (cosine of its angle)
            vectormap[plane_idx*2+0][y][x] = vec_x
            # store the y component (sine of its angle)
            vectormap[plane_idx*2+1][y][x] = vec_y

This function implements point 6 of the paper-analysis post; if anything is unclear, go back to the paper.
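To see the geometry in one place, here is a compact vectorized sketch of the same idea (my own illustration, not repo code): every pixel inside the limb's expanded bounding box and within threshold of the limb's line stores the limb's unit vector.

import numpy as np

def paf_for_limb(h, w, p_from, p_to, threshold=8):
    # Minimal PAF sketch: pixels near the limb segment store its unit vector.
    paf = np.zeros((h, w, 2), dtype=np.float32)
    v = np.array(p_to, np.float32) - np.array(p_from, np.float32)
    norm = np.linalg.norm(v)
    if norm == 0:
        return paf
    v /= norm                                # unit vector along the limb
    ys, xs = np.mgrid[0:h, 0:w]
    bx, by = xs - p_from[0], ys - p_from[1]  # vector from p_from to each pixel
    dist = np.abs(bx * v[1] - by * v[0])     # |cross product| = perpendicular distance
    # bounding box of the segment, expanded by threshold (as in put_vectormap)
    in_box = ((xs >= min(p_from[0], p_to[0]) - threshold) & (xs <= max(p_from[0], p_to[0]) + threshold) &
              (ys >= min(p_from[1], p_to[1]) - threshold) & (ys <= max(p_from[1], p_to[1]) + threshold))
    paf[(dist <= threshold) & in_box] = v
    return paf

paf = paf_for_limb(46, 54, (10, 10), (40, 30))
print(paf[10, 12])   # a pixel near the segment -> approximately [0.83, 0.55]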

5. Visualization

Next, let's visualize what the code above produces, i.e. the results of points 5 and 6 from the paper-analysis post, using the main block at the bottom of pose_dataset.py:

if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    from pose_augment import set_network_input_wh, set_network_scale
    # set_network_input_wh(368, 368)
    set_network_input_wh(480, 320)
    set_network_scale(8)

    # df = get_dataflow('/data/public/rw/coco/annotations', True, '/data/public/rw/coco/')
    df = _get_dataflow_onlyread('/data/public/rw/coco/annotations', True, '/data/public/rw/coco/')
    # df = get_dataflow('/root/coco/annotations', False, img_path='http://gpu-twg.kakaocdn.net/braincloud/COCO/')

    from tensorpack.dataflow.common import TestDataSpeed
    TestDataSpeed(df).start()
    sys.exit(0)

    with tf.Session() as sess:
        df.reset_state()
        t1 = time.time()
        for idx, dp in enumerate(df.get_data()):
            if idx == 0:
                for d in dp:
                    logger.info('%d dp shape={}'.format(d.shape))
            print(time.time() - t1)
            t1 = time.time()
            CocoPose.display_image(dp[0], dp[1].astype(np.float32), dp[2].astype(np.float32))
            print(dp[1].shape, dp[2].shape)
            pass

    logger.info('done')

First, change

df = _get_dataflow_onlyread('/data/public/rw/coco/annotations', True, '/data/public/rw/coco/')

to

df = _get_dataflow_onlyread('G:/tensorflow/post-estimation/datasets/annotations', True, 'G:/tensorflow/post-estimation/datasets/')

These two paths are where my dataset lives; if that's unclear, see the previous post:

https://blog.csdn.net/rookie_wei/article/details/93658329

Then delete

sys.exit(0)

so that the loop below it actually runs.

Now let's just run it and see whether it passes. Execute the following commands:

set PYTHONPATH=G:\tensorflow\post-estimation\tf-pose-estimation-master
python tf_pose\pose_dataset.py

Here G:\tensorflow\post-estimation\tf-pose-estimation-master is my project's root directory; change it to match your own setup. The output:

from ._conv import register_converters as _register_converters
loading annotations into memory...
Done (t=9.33s)
creating index...
index created!
[2019-07-03 22:06:49,618] [pose_dataset] [INFO] G:/tensorflow/post-estimation/datasets/annotations dataset 118287
2019-07-03 22:06:49,618 INFO G:/tensorflow/post-estimation/datasets/annotations dataset 118287
  0%|                                        | 0/5000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "tf_pose\pose_dataset.py", line 558, in <module>
    TestDataSpeed(df).start()
  File "D:\Anaconda3\lib\site-packages\tensorpack\dataflow\common.py", line 56, in start
    for idx, dp in enumerate(itr):
  File "D:\Anaconda3\lib\site-packages\tensorpack\dataflow\common.py", line 292, in __iter__
    ret = self.func(copy(dp))  # shallow copy the list
  File "G:\tensorflow\post-estimation\tf-pose-estimation-master\tf_pose\pose_augment.py", line 273, in pose_to_img
    meta_l[0].get_heatmap(target_size=(_network_w // _scale, _network_h // _scale)),
cv2.error: OpenCV(4.1.0) C:\projects\opencv-python\opencv\modules\imgproc\src\resize.cpp:3555: error: (-215:Assertion failed) func != 0 && cn <= 4 in function 'cv::hal::resize'

Sure enough, our luck is as bad as ever: it crashes. Let's fix it. The traceback says the problem is in OpenCV's resize, inside get_heatmap called from pose_to_img, and that everything started from TestDataSpeed(df).start(). So let's first look at the source of TestDataSpeed(df).start():

class TestDataSpeed(ProxyDataFlow):
    """ Test the speed of some DataFlow """
    def __init__(self, ds, size=5000, warmup=0):
        """
        Args:
            ds (DataFlow): the DataFlow to test.
            size (int): number of datapoints to fetch.
            warmup (int): warmup iterations
        """
        super(TestDataSpeed, self).__init__(ds)
        self.test_size = int(size)
        self.warmup = int(warmup)

    def __iter__(self):
        """ Will run testing at the beginning, then produce data normally. """
        self.start()
        for dp in self.ds:
            yield dp

    def start(self):
        """
        Start testing with a progress bar.
        """
        self.ds.reset_state()
        itr = self.ds.__iter__()
        if self.warmup:
            for _ in tqdm.trange(self.warmup, **get_tqdm_kwargs()):
                next(itr)
        # add smoothing for speed benchmark
        with get_tqdm(total=self.test_size,
                      leave=True, smoothing=0.2) as pbar:
            for idx, dp in enumerate(itr):
                pbar.update()
                if idx == self.test_size - 1:
                    break

Nothing suspicious there. So how does pose_to_img get called? TestDataSpeed(df).start() takes a df argument; where does df come from? Look up a few lines:

df = _get_dataflow_onlyread('G:/tensorflow/post-estimation/datasets/annotations', True, 'G:/tensorflow/post-estimation/datasets/')

OK, so what does this _get_dataflow_onlyread function do?

def _get_dataflow_onlyread(path, is_train, img_path=None):
    ds = CocoPose(path, img_path, is_train)  # read data from lmdb
    ds = MapData(ds, read_image_url)
    ds = MapData(ds, pose_to_img)
    # ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 4)
    return ds

So pose_to_img is hooked in here. This problem was genuinely hard to track down, because I wasn't familiar with how MapData is used. Then I noticed the get_dataflow function sitting right above _get_dataflow_onlyread:

def get_dataflow(path, is_train, img_path=None):
    # the JSON to read is person_keypoints_train2017.json; the images live in train2017
    ds = CocoPose(path, img_path, is_train)       # read data from lmdb
    if is_train:
        ds = MapData(ds, read_image_url)
        # data augmentation
        ds = MapDataComponent(ds, pose_random_scale)
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        # augs = [
        #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        #         imgaug.GaussianBlur(max_size=3)
        #     ]), 0.7)
        # ]
        # ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 1)
    else:
        ds = MultiThreadMapData(ds, nr_thread=16, map_func=read_image_url, buffer_size=1000)
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)
        # second argument: size of the queue to hold prefetched datapoints
        ds = PrefetchData(ds, 100, multiprocessing.cpu_count() // 4)

    return ds

Between the two MapData calls it also runs several MapDataComponent steps that augment the image data. Let's imitate that usage and change

def _get_dataflow_onlyread(path, is_train, img_path=None):
    ds = CocoPose(path, img_path, is_train)  # read data from lmdb
    ds = MapData(ds, read_image_url)
    ds = MapData(ds, pose_to_img)
    # ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 4)
    return ds

to:

def _get_dataflow_onlyread(path, is_train, img_path=None):
    print('CocoPose-------------')
    ds = CocoPose(path, img_path, is_train)  # read data from lmdb
    print('CocoPose======')
    ds = MapData(ds, read_image_url)
    ds = MapDataComponent(ds, pose_random_scale)
    ds = MapDataComponent(ds, pose_rotation)
    ds = MapDataComponent(ds, pose_flip)
    ds = MapDataComponent(ds, pose_resize_shortestedge_random)
    ds = MapDataComponent(ds, pose_crop_random)
    print('MapData-------------')
    ds = MapData(ds, pose_to_img)
    print('MapData======')
    # ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 4)
    return ds

Run it again:

No errors this time. Now go have a cup of tea while the progress bar crawls to 100%...

When the progress bar finally reaches 100%, it prints a stream of numbers like this:

0.09376311302185059
(46, 54, 19) (46, 54, 38)
0.26555824279785156
(46, 54, 19) (46, 54, 38)
0.2968161106109619
(46, 54, 19) (46, 54, 38)
0.1718287467956543
(46, 54, 19) (46, 54, 38)

Those numbers aren't very interesting by themselves; they're the per-sample time and the heatmap/vector-map shapes printed in the for loop. Let's see what CocoPose.display_image does:

@staticmethod
def display_image(inp, heatmap, vectmap, as_numpy=False):
    global mplset
    # if as_numpy and not mplset:
    #     import matplotlib as mpl
    #     mpl.use('Agg')
    mplset = True
    import matplotlib.pyplot as plt

    fig = plt.figure()
    a = fig.add_subplot(2, 2, 1)
    a.set_title('Image')
    plt.imshow(CocoPose.get_bgimg(inp))

    a = fig.add_subplot(2, 2, 2)
    a.set_title('Heatmap')
    plt.imshow(CocoPose.get_bgimg(inp, target_size=(heatmap.shape[1], heatmap.shape[0])), alpha=0.5)
    tmp = np.amax(heatmap, axis=2)
    plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5)
    plt.colorbar()

    tmp2 = vectmap.transpose((2, 0, 1))
    tmp2_odd = np.amax(np.absolute(tmp2[::2, :, :]), axis=0)
    tmp2_even = np.amax(np.absolute(tmp2[1::2, :, :]), axis=0)

    a = fig.add_subplot(2, 2, 3)
    a.set_title('Vectormap-x')
    plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
    plt.imshow(tmp2_odd, cmap=plt.cm.gray, alpha=0.5)
    plt.colorbar()

    a = fig.add_subplot(2, 2, 4)
    a.set_title('Vectormap-y')
    plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
    plt.imshow(tmp2_even, cmap=plt.cm.gray, alpha=0.5)
    plt.colorbar()

    if not as_numpy:
        plt.show()
    else:
        fig.canvas.draw()
        data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        fig.clear()
        plt.close()
        return data

The code above tries to display the figure with matplotlib, but nothing shows up. This happens often enough that I don't feel like fighting the backend; the simplest workaround is to save the figure to a local file and open that instead. Change

if not as_numpy:
    plt.show()

to

if not as_numpy:
    plt.savefig('h.png')
    plt.show()

Run it again. (Note that savefig now comes before show: saving after the figure window has been closed can produce a blank image.) The saved h.png shows four panels: the original Image, the Heatmap overlaid on it, and the Vectormap-x and Vectormap-y magnitudes overlaid on it.
