mmpose tutorial

MMPose is an open-source toolbox for pose analysis based on PyTorch, released by SenseTime as part of the OpenMMLab project. The main branch currently requires PyTorch 1.5 or above, and official Chinese tutorials are available.

Key features include:

  1. Support for various human pose analysis tasks

    MMPose supports the mainstream pose analysis tasks that currently receive wide attention in the academic community, including 2D multi-person pose estimation, 2D hand pose estimation, 2D face landmark detection, 133-keypoint whole-body pose estimation, 3D human mesh recovery, fashion landmark detection, animal keypoint detection, and more. For details, please refer to the feature demos.

  2. Higher Accuracy and Faster Speed

    MMPose reimplements a variety of state-of-the-art pose analysis models from the literature, covering both "top-down" and "bottom-up" algorithms. Compared with other mainstream codebases, MMPose achieves higher model accuracy and faster training. For details, please refer to the benchmark.

  3. Support for diverse datasets

    MMPose supports the preparation and construction of many mainstream datasets, such as COCO and MPII. For details, please refer to Dataset Preparation.

  4. Modular design

    MMPose decouples the pose analysis framework into separate module components. By combining different components, users can easily build custom pose estimation models; a config sketch illustrating this is given after this list.

  5. Extensive unit tests and documentation

    MMPose provides detailed documentation, API interface descriptions, and comprehensive unit tests for community reference.
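
To give a feel for the modular design, here is a hedged sketch of how a top-down pose model is assembled from components in a config file. The field names follow the common MMPose 0.x config style (roughly the ResNet-50 COCO config); treat it as an illustration rather than a complete, authoritative config.

# Sketch of a top-down model definition in the MMPose 0.x config style (illustrative, not a full config)
model = dict(
    type='TopDown',
    pretrained='torchvision://resnet50',
    backbone=dict(type='ResNet', depth=50),          # swap in another backbone, e.g. HRNet
    keypoint_head=dict(
        type='TopdownHeatmapSimpleHead',
        in_channels=2048,
        out_channels=17,                             # number of COCO keypoints
        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
    train_cfg=dict(),
    test_cfg=dict(flip_test=True, post_process='default', shift_heatmap=True))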

Install

conda create -n open-mmlab python=3.7 -y
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/index.html

Environment requirements: Linux, Python 3.6+, PyTorch 1.5+, CUDA 10.2+, GCC 5+, plus recent versions of mmcv, numpy, opencv-python (cv2), json_tricks, and xtcocotools. Optional dependencies:

  • mmdet (to run pose demos)
  • mmtrack (to run pose tracking demos)
  • pyrender  (to run 3d mesh demos)
  • smplx (to run 3d mesh demos)

Then install MMPose from source:
pip3 install -r requirements.txt
python3 setup.py develop
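
After installation, a quick sanity check (a minimal sketch) is to import the packages and print their versions:

import mmcv
import mmpose
print('mmcv version:', mmcv.__version__)
print('mmpose version:', mmpose.__version__)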

Data preparation

It is recommended to symlink the dataset root to $MMPOSE/data. If your directory structure is different, you may need to adjust the corresponding paths in the config files. Taking COCO as an example, the expected structure is shown below:

mmpose
├── mmpose
├── docs
├── tests
├── tools
├── configs
└── data
    └── coco
        ├── annotations
        │   ├── person_keypoints_train2017.json
        │   ├── person_keypoints_val2017.json
        │   └── person_keypoints_test-dev-2017.json
        ├── person_detection_results
        │   ├── COCO_val2017_detections_AP_H_56_person.json
        │   └── COCO_test-dev2017_detections_AP_H_609_person.json
        ├── train2017
        │   ├── 000000000009.jpg
        │   ├── 000000000025.jpg
        │   ├── 000000000030.jpg
        │   └── ...
        └── val2017
            ├── 000000000139.jpg
            ├── 000000000285.jpg
            ├── 000000000632.jpg
            └── ...
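
Once the data is in place, a small hedged check with xtcocotools (already a dependency) confirms the annotations can be loaded; the path assumes the layout above:

from xtcocotools.coco import COCO

# load the val2017 keypoint annotations and count images / person annotations
coco = COCO('data/coco/annotations/person_keypoints_val2017.json')
print(len(coco.imgs), 'images,', len(coco.anns), 'person annotations')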

Quick start

Inference using pretrained models

We provide test scripts to evaluate a pre-trained model on the COCO dataset, as well as high-level APIs that can be integrated into other projects.
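
As a hedged sketch of those high-level APIs (the config and checkpoint URL are the ones used in the demos below; the bounding box values are made up for illustration):

from mmpose.apis import init_pose_model, inference_top_down_pose_model, vis_pose_result

pose_model = init_pose_model(
    'configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/hrnet_w48_coco_256x192.py',
    'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
    device='cuda:0')

# one person bounding box in xywh format (hypothetical values)
person_results = [{'bbox': [280, 44, 218, 346]}]
img = 'tests/data/coco/000000196141.jpg'
pose_results, _ = inference_top_down_pose_model(
    pose_model, img, person_results, format='xywh')
vis_pose_result(pose_model, img, pose_results, out_file='vis_results/000000196141_vis.jpg')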

You can test a dataset with the following commands:

# single-GPU testing
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--fuse-conv-bn] \
    [--eval ${EVAL_METRICS}] [--gpu_collect] [--tmpdir ${TMPDIR}] [--cfg-options ${CFG_OPTIONS}] \
    [--launcher ${JOB_LAUNCHER}] [--local_rank ${LOCAL_RANK}]

# multi-GPU testing
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] \
    [--gpu-collect] [--tmpdir ${TMPDIR}] [--options ${OPTIONS}] [--average-clips ${AVG_TYPE}] \
    [--launcher ${JOB_LAUNCHER}] [--local_rank ${LOCAL_RANK}]

Here CONFIG_FILE is the configuration file, CHECKPOINT_FILE is the path to the trained model, and EVAL_METRICS is the evaluation metric. If you need to override parameters of the configuration file, use CFG_OPTIONS.

Suppose you have downloaded a trained model and put it in the checkpoints folder.

1. Test ResNet50 on the COCO dataset (without saving the test results to a file) and evaluate mAP:

./tools/dist_test.sh configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/res50_coco_256x192.py  checkpoints/SOME_CHECKPOINT.pth 1 --eval mAP

2. Test ResNet50 on the COCO dataset with 8 GPUs, downloading the model weights online, and evaluate mAP:

./tools/dist_test.sh configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/res50_coco_256x192.py  https://download.openmmlab.com/mmpose/top_down/resnet/res50_coco_256x192-ec54d7f3_20200709.pth 8 --eval mAP

In addition, a rich set of demo scripts is provided so that you can quickly run demos. Below is an example of multi-person pose estimation that uses ground-truth (human-annotated) person bounding boxes as input.

python demo/top_down_img_demo.py \
    ${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
    --img-root ${IMG_ROOT} --json-file ${JSON_FILE} --out-img-root ${OUTPUT_DIR} \
    [--show --device ${GPU_ID}] [--kpt-thr ${KPT_SCORE_THR}]

For example:

python demo/top_down_img_demo.py \
    configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/hrnet_w48_coco_256x192.py \
    https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth \
    --img-root tests/data/coco/ --json-file tests/data/coco/test_coco.json \
    --out-img-root vis_results

If you want to use detection results from MMDetection as input instead, a demo script is provided that first runs an mmdet model to detect people:

python demo/top_down_img_demo_with_mmdet.py \
    ${MMDET_CONFIG_FILE} ${MMDET_CHECKPOINT_FILE} \
    ${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
    --img-root ${IMG_ROOT} --img ${IMG_FILE} \
    --out-img-root ${OUTPUT_DIR} \
    [--show --device ${GPU_ID}] \
    [--bbox-thr ${BBOX_SCORE_THR} --kpt-thr ${KPT_SCORE_THR}]

For example:

python demo/top_down_img_demo_with_mmdet.py mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
    mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
    configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/hrnet_w48_coco_256x192.py \
    hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth \
    --img-root tests/data/coco/ \
    --img 000000196141.jpg \
    --out-img-root vis_results
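
For reference, here is a rough Python equivalent of this demo (a sketch assuming MMDetection 2.x is installed; the demo script itself handles more corner cases):

from mmdet.apis import init_detector, inference_detector
from mmpose.apis import init_pose_model, inference_top_down_pose_model, vis_pose_result

det_model = init_detector(
    'mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
    'mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
    device='cuda:0')
pose_model = init_pose_model(
    'configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/hrnet_w48_coco_256x192.py',
    'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
    device='cuda:0')

img = 'tests/data/coco/000000196141.jpg'
mmdet_results = inference_detector(det_model, img)
# for COCO-trained detectors, class index 0 is 'person'; each row is [x1, y1, x2, y2, score]
person_results = [{'bbox': bbox} for bbox in mmdet_results[0] if bbox[4] > 0.3]
pose_results, _ = inference_top_down_pose_model(
    pose_model, img, person_results, bbox_thr=0.3, format='xyxy')
vis_pose_result(pose_model, img, pose_results, out_file='vis_results/000000196141_det.jpg')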

Training a model

MMPose uses MMDistributedDataParallel and MMDataParallel to implement distributed and single-machine training, respectively.

All outputs are saved to work_dir, which can be specified in the configuration file.

By default, the model is evaluated after every epoch. You can change this behavior in the configuration file, e.g. evaluation = dict(interval=5) evaluates the model every 5 epochs.

According to the linear scaling rule, you need to adjust the learning rate in proportion to the total batch size (number of GPUs × samples per GPU), e.g. lr=0.01 for 4 GPUs × 2 samples per GPU and lr=0.08 for 16 GPUs × 4 samples per GPU.
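
The rule is just a proportionality between the learning rate and the total batch size; a tiny sketch using the numbers above:

# linear scaling rule: lr scales with total batch size = num_gpus * samples_per_gpu
def scale_lr(base_lr, base_batch, new_batch):
    return base_lr * new_batch / base_batch

print(scale_lr(0.01, 4 * 2, 16 * 4))  # 0.08, matching the 16-GPU example above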

  • Single GPU training: python tools/train.py ${CONFIG_FILE} [optional arguments]
  • Multi-GPU training: ./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]
  • Multi-machine training: ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR}

Launching multiple jobs on a single machine:

CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4

If you use Slurm to submit multiple jobs, you need to modify the configuration files to use different ports:

config1.py: dist_params = dict(backend='nccl', port=29500)

config2.py: dist_params = dict(backend='nccl', port=29501)

Speed benchmark: use python tools/benchmark_inference.py ${MMPOSE_CONFIG_FILE} to measure the pure network forward speed, excluding I/O and preprocessing.

See the tutorials: finetune model, add new dataset, add new modules.

Visualizing the training data is a very important part of debugging the network, but unfortunately MMPose does not provide a script for it, so here is one:

import os
import math
import argparse

import cv2
import mmcv
import numpy as np
import torchvision
from tqdm import tqdm

from mmpose.datasets import build_dataloader, build_dataset


def parse_args():
    parser = argparse.ArgumentParser(description='visualize mmpose training data')
    parser.add_argument(
        '--config',
        default='configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/hrnet_w48_coco_256x192.py')
    return parser.parse_args()


def main():
    args = parse_args()
    cfg = mmcv.Config.fromfile(args.config)
    dataset = build_dataset(cfg.data.train)
    # parent joint of each COCO keypoint, used to draw the limbs
    skeletons = [0, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 0, 0, 11, 12, 13, 14]
    data_loader = build_dataloader(
        dataset, samples_per_gpu=48, workers_per_gpu=0, dist=False, shuffle=False)
    # normalization parameters from the training pipeline (kept for reference;
    # make_grid(normalize=True) below already rescales the images for display)
    for step in cfg.train_pipeline:
        if 'mean' in step:
            mean = np.array(step['mean'])
            std = np.array(step['std'])
    nrow = 8
    padding = 2
    os.makedirs('work_dirs/gt', exist_ok=True)
    for i, data_batch in enumerate(tqdm(data_loader)):
        batch_image = data_batch['img'].data
        metas = data_batch['img_metas'].data[0]
        file_name = 'work_dirs/gt/' + str(i) + '.jpg'
        # arrange the batch into an image grid
        grid = torchvision.utils.make_grid(batch_image, nrow, padding, True)
        ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        ndarr = ndarr.copy()
        nmaps = batch_image.size(0)
        xmaps = min(nrow, nmaps)
        ymaps = int(math.ceil(float(nmaps) / xmaps))
        height = int(batch_image.size(2) + padding)
        width = int(batch_image.size(3) + padding)
        k = 0
        for y in range(ymaps):
            for x in range(xmaps):
                if k >= nmaps:
                    break
                joints = metas[k]['joints_3d']
                joints_vis = metas[k]['joints_3d_visible']
                for j, joint in enumerate(joints):
                    joint_vis = joints_vis[j]
                    # offset the keypoint into the k-th cell of the grid
                    jx = int(x * width + padding + joint[0])
                    jy = int(y * height + padding + joint[1])
                    if joint_vis[0]:
                        cv2.circle(ndarr, (jx, jy), 2, (0, 255, 0), 2)
                        cv2.putText(ndarr, str(j), (jx, jy), 1, 1, (0, 0, 255))
                    # draw a line to the parent joint if both are visible
                    sp = skeletons[j]
                    sp_vis = joints_vis[sp]
                    if joint_vis[0] and sp_vis[0]:
                        sp_joint = joints[sp]
                        spx = int(x * width + padding + sp_joint[0])
                        spy = int(y * height + padding + sp_joint[1])
                        cv2.line(ndarr, (jx, jy), (spx, spy), (255, 0, 255))
                k = k + 1
        cv2.imwrite(file_name, ndarr)


if __name__ == '__main__':
    main()

 

Finding samples where the model makes mistakes and analyzing the causes of those errors is an important way to improve algorithm performance. The code below runs the model on COCO val2017, compares the predictions against the ground-truth keypoints, and moves images with many large errors into an error/ directory:

import os
import cv2
import shutil
import numpy as np
import warnings
from tqdm import tqdm
from argparse import ArgumentParser
from xtcocotools.coco import COCO
from mmpose.apis import (inference_top_down_pose_model, init_pose_model, vis_pose_result)
from mmpose.datasets import DatasetInfo

def get_args():
    parser = ArgumentParser()
    parser.add_argument('--pose_config', default="configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/hrnet_w48_coco_256x192.py")
    parser.add_argument('--pose_checkpoint', default='https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth')
    parser.add_argument('--img-root', default='data/coco/val2017')
    parser.add_argument('--json-file', default='data/coco/annotations/person_keypoints_val2017.json')
    parser.add_argument('--out-img-root', default='vis_coco',)
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--kpt-thr', type=float, default=0.3)
    args = parser.parse_args()
    return args

def main():
    args = get_args()
    # make sure the output directories exist
    os.makedirs(args.out_img_root, exist_ok=True)
    os.makedirs('error', exist_ok=True)
    coco = COCO(args.json_file)
    # build the pose model from a config file and a checkpoint file
    pose_model = init_pose_model(
        args.pose_config, args.pose_checkpoint, device=args.device.lower())

    dataset = pose_model.cfg.data['test']['type']
    dataset_info = pose_model.cfg.data['test'].get('dataset_info', None)
    if dataset_info is None:
        warnings.warn(
            'Please set `dataset_info` in the config.'
            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
            DeprecationWarning)
    else:
        dataset_info = DatasetInfo(dataset_info)

    img_keys = list(coco.imgs.keys())

    # optional
    return_heatmap = False
    skeleton = dataset_info.skeleton
    pose_link_color = dataset_info.pose_link_color
    sigmas = dataset_info.sigmas * 100  # per-keypoint error tolerance in pixels (heuristic: OKS sigmas scaled by 100)
    # e.g. use ('backbone', ) to return backbone feature
    output_layer_names = None

    error_sts = np.zeros((17,), dtype=np.float64)  # accumulated error per keypoint type
    error_cnts = np.zeros((17,), dtype=np.int64)   # number of evaluated keypoints per type
    # process each image
    for i in tqdm(range(len(img_keys))):
        # get bounding box annotations
        image_id = img_keys[i]
        image = coco.loadImgs(image_id)[0]
        image_name = os.path.join(args.img_root, image['file_name'])
        ann_ids = coco.getAnnIds(image_id)

        # make person bounding boxes
        person_results = []
        keypoints = []
        for ann_id in ann_ids:
            person = {}
            ann = coco.anns[ann_id]
            # bbox format is 'xywh'
            person['bbox'] = ann['bbox']
            person_results.append(person)
            keypoints.append(ann['keypoints'])
        if len(person_results) == 0:
            #print("empty img", image_name)
            continue
        # test a single image, with a list of bboxes
        pose_results, returned_outputs = inference_top_down_pose_model(
            pose_model,
            image_name,
            person_results,
            bbox_thr=None,
            format='xywh',
            dataset=dataset,
            dataset_info=dataset_info,
            return_heatmap=return_heatmap,
            outputs=output_layer_names)
        img = cv2.imread(image_name)
        img_h, img_w, _ = img.shape
        valid_instance = 0
        for inst_id in range(len(keypoints)):
            show = img  # draw directly on the original image
            error_num = 0
            preds = pose_results[inst_id]['keypoints']
            pts = np.array(keypoints[inst_id]).reshape(-1, 3)
            valid_num = np.sum(pts[:,2] == 2)
            if valid_num < 10:
                continue
            valid_instance += 1
            error = np.linalg.norm(pts[:,:2] - preds[:,:2],axis=1)
            error[pts[:,2] < 2] = 0
            for sk_id, sk in enumerate(skeleton):
                pos1 = (int(pts[sk[0], 0]), int(pts[sk[0], 1]))
                pos2 = (int(pts[sk[1], 0]), int(pts[sk[1], 1]))
                if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0
                        and pos1[1] < img_h and pos2[0] > 0 and pos2[0] < img_w
                        and pos2[1] > 0 and pos2[1] < img_h
                        and pts[sk[0], 2] > args.kpt_thr
                        and pts[sk[1], 2] > args.kpt_thr):
                    r, g, b = pose_link_color[sk_id]
                    cv2.line(show, pos1, pos2, (int(r),int(g),int(b)))
            for j in range(len(pts)):
                pt = pts[j]
                pred = preds[j]
                if pt[2] > args.kpt_thr:
                    cv2.circle(show, (int(pt[0]), int(pt[1])), 3, (0, 255, 0), -1)
                if error[j] > sigmas[j] * 2:
                    error_num += 1
                    color = (255,0,255)
                else:
                    color = (255,0,0)
                cv2.circle(show,(int(pred[0]), int(pred[1])),2,color)
            error_sts += error
            error_cnts += error > 0
            if valid_num > 0:
                error_sum = np.sum(error) / valid_num
            else:
                error_sum = np.sum(error)
        dstname = image_name.split('/')[-1].split('.')[0]+".jpg"
        out_file = os.path.join(args.out_img_root, dstname)
        if valid_instance > 0:
            cv2.imwrite(out_file, show)
            if error_num > 3:
                print("error ", dstname, valid_num, error_sum, error_num)
                if os.path.exists(out_file):
                    shutil.move(out_file,'error/'+dstname)          
    for i,st in enumerate(error_sts):
        print(i, st/error_cnts[i])

if __name__ == '__main__':
    main()

  

From these error samples it is easy to see that occlusion, side views, and truncation are the main factors limiting the performance of the algorithm.

Analyzing the computational cost is a key step in model optimization. MMPose provides a get_flops script to report the parameter count and FLOPs of a network, but its output is somewhat cluttered and the key numbers are hard to pick out. The script below compares mmcv's complexity analysis with a layer-by-layer summary:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,8,3)
        self.conv2 = nn.Conv2d(8,16,3)
        self.conv3 = nn.Conv2d(16,32,3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x

def mmcv_flops(model, input_shape):
    # parameter count and FLOPs as reported by mmcv
    from mmcv.cnn import get_model_complexity_info
    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')

def pt_flops_summary(model, input_shape):
    # layer-by-layer summary; pt_flops is the author's own helper module,
    # not part of MMPose or mmcv
    from pt_flops import get_model_summary
    print(get_model_summary(model, input_shape))

if __name__ == '__main__':
    # pick one of the models below to analyze
    # model = Net()
    # from mmpose.models.backbones.alexnet import AlexNet
    # model = AlexNet()
    from mmpose.models.backbones.hrnet import HRNet

    # HRNet-W32 backbone configuration
    extra = dict(
        stage1=dict(
            num_modules=1,
            num_branches=1,
            block='BOTTLENECK',
            num_blocks=(4, ),
            num_channels=(64, )),
        stage2=dict(
            num_modules=1,
            num_branches=2,
            block='BASIC',
            num_blocks=(4, 4),
            num_channels=(32, 64)),
        stage3=dict(
            num_modules=4,
            num_branches=3,
            block='BASIC',
            num_blocks=(4, 4, 4),
            num_channels=(32, 64, 128)),
        stage4=dict(
            num_modules=3,
            num_branches=4,
            block='BASIC',
            num_blocks=(4, 4, 4, 4),
            num_channels=(32, 64, 128, 256)))
    model = HRNet(extra)
    input_shape = (3, 256, 192)
    mmcv_flops(model, input_shape)
    pt_flops_summary(model, input_shape)

The output is shown below. Note that mmcv's "G" is 1000-based (10^9) rather than the usual 1024-based unit, so pay attention to the difference when comparing the two reports (see the note after the output).

==============================
Input shape: (3, 256, 192)
Flops: 7.7 GFLOPs
Params: 28.54 M
==============================
Model Summary
Name          Input Size        Output Size       Parameters (count, % of total)    Multiply Adds / FLOPs (count, % of total)
--------------------------------------------------------------------------------------------------------
Conv2d_1           [1, 3, 256, 192]  [1, 64, 128, 96]  1728      0.006      21233664  0.276
BatchNorm2d_1      [1, 64, 128, 96]  [1, 64, 128, 96]  128       0.000       1572864   0.020
Conv2d_2           [1, 64, 128, 96]  [1, 64, 64, 48]   36864     0.129     113246208 1.474
BatchNorm2d_2      [1, 64, 64, 48]   [1, 64, 64, 48]   128       0.000       393216    0.005
Conv2d_3           [1, 64, 64, 48]   [1, 64, 64, 48]   4096      0.014      12582912  0.164
BatchNorm2d_3      [1, 64, 64, 48]   [1, 64, 64, 48]   128       0.000       393216    0.005
Conv2d_4           [1, 64, 64, 48]   [1, 64, 64, 48]   36864     0.129     113246208 1.474
BatchNorm2d_4      [1, 64, 64, 48]   [1, 64, 64, 48]   128       0.000       393216    0.005
Conv2d_5           [1, 64, 64, 48]   [1, 256, 64, 48]  16384     0.057     50331648  0.655
BatchNorm2d_5      [1, 256, 64, 48]  [1, 256, 64, 48]  512       0.002       1572864   0.020
Conv2d_6           [1, 64, 64, 48]   [1, 256, 64, 48]  16384     0.057     50331648  0.655
BatchNorm2d_6      [1, 256, 64, 48]  [1, 256, 64, 48]  512       0.002       1572864   0.020
Conv2d_7           [1, 256, 64, 48]  [1, 64, 64, 48]   16384     0.057     50331648  0.655
BatchNorm2d_7      [1, 64, 64, 48]   [1, 64, 64, 48]   128       0.000       393216    0.005
Conv2d_8           [1, 64, 64, 48]   [1, 64, 64, 48]   36864     0.129     113246208 1.474
BatchNorm2d_8      [1, 64, 64, 48]   [1, 64, 64, 48]   128       0.000       393216    0.005
Conv2d_9           [1, 64, 64, 48]   [1, 256, 64, 48]  16384     0.057     50331648  0.655
BatchNorm2d_9      [1, 256, 64, 48]  [1, 256, 64, 48]  512       0.002       1572864   0.020
...
Conv2d_280         [1, 128, 16, 12]  [1, 128, 16, 12]  147456    0.517    28311552  0.369
BatchNorm2d_280    [1, 128, 16, 12]  [1, 128, 16, 12]  256       0.001       49152     0.001
Conv2d_281         [1, 128, 16, 12]  [1, 128, 16, 12]  147456    0.517    28311552  0.369
BatchNorm2d_281    [1, 128, 16, 12]  [1, 128, 16, 12]  256       0.001       49152     0.001
Conv2d_282         [1, 256, 8, 6]    [1, 256, 8, 6]    589824    2.067    28311552  0.369
BatchNorm2d_282    [1, 256, 8, 6]    [1, 256, 8, 6]    512       0.002       24576     0.000
Conv2d_283         [1, 256, 8, 6]    [1, 256, 8, 6]    589824    2.067    28311552  0.369
BatchNorm2d_283    [1, 256, 8, 6]    [1, 256, 8, 6]    512       0.002       24576     0.000
Conv2d_284         [1, 256, 8, 6]    [1, 256, 8, 6]    589824    2.067    28311552  0.369
BatchNorm2d_284    [1, 256, 8, 6]    [1, 256, 8, 6]    512       0.002       24576     0.000
Conv2d_285         [1, 256, 8, 6]    [1, 256, 8, 6]    589824    2.067    28311552  0.369
BatchNorm2d_285    [1, 256, 8, 6]    [1, 256, 8, 6]    512       0.002       24576     0.000
Conv2d_286         [1, 256, 8, 6]    [1, 256, 8, 6]    589824    2.067    28311552  0.369
BatchNorm2d_286    [1, 256, 8, 6]    [1, 256, 8, 6]    512       0.002       24576     0.000
Conv2d_287         [1, 256, 8, 6]    [1, 256, 8, 6]    589824    2.067    28311552  0.369
BatchNorm2d_287    [1, 256, 8, 6]    [1, 256, 8, 6]    512       0.002       24576     0.000
Conv2d_288         [1, 256, 8, 6]    [1, 256, 8, 6]    589824    2.067    28311552  0.369
BatchNorm2d_288    [1, 256, 8, 6]    [1, 256, 8, 6]    512       0.002       24576     0.000
Conv2d_289         [1, 256, 8, 6]    [1, 256, 8, 6]    589824    2.067    28311552  0.369
BatchNorm2d_289    [1, 256, 8, 6]    [1, 256, 8, 6]    512       0.002       24576     0.000
Conv2d_290         [1, 64, 32, 24]   [1, 32, 32, 24]   2048      0.007      1572864   0.020
BatchNorm2d_290    [1, 32, 32, 24]   [1, 32, 32, 24]   64        0.000        49152     0.001
Conv2d_291         [1, 128, 16, 12]  [1, 32, 16, 12]   4096      0.014      786432    0.010
BatchNorm2d_291    [1, 32, 16, 12]   [1, 32, 16, 12]   64        0.000        12288     0.000
Conv2d_292         [1, 256, 8, 6]    [1, 32, 8, 6]     8192      0.029      393216    0.005
BatchNorm2d_292    [1, 32, 8, 6]     [1, 32, 8, 6]     64        0.000        3072      0.000
--------------------------------------------------------------------------------------------------------
Total Parameters: 27.214M (28,535,552)    Total Multiply Adds: 7.154GFLOPs (7681459200)
Number of Layers
Conv2d : 292 layers   BatchNorm2d : 292 layers   ReLU : 261 layers   Bottleneck : 4 layers   BasicBlock : 104 layers   Upsample : 28 layers   HRModule : 8 layers   
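
The difference between the two totals appears to be purely a matter of units: mmcv divides by 10**9 (and 10**6 for parameters), while the layer-by-layer summary divides by 1024**3 (and 1024**2). A quick check with the numbers above, assuming both tools count essentially the same multiply-add total:

total_madds = 7681459200
total_params = 28535552
print(total_madds / 1e9)       # 7.68  -> mmcv's "7.7 GFLOPs"
print(total_madds / 1024**3)   # 7.15  -> "7.154 GFLOPs"
print(total_params / 1e6)      # 28.54 -> mmcv's "28.54 M" params
print(total_params / 1024**2)  # 27.21 -> "27.214M"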
