Cell Segmentation: Real-Time Video Stream Output

Software Environment

  • OpenCV
  • CMake
  • VS2017
  • TensorRT
  • CUDA and cuDNN
  • PyCUDA
  • PyTorch

1. Creating Samples with labelme

  • Install labelme

pip install labelme
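
After installation, launch the annotation GUI on a folder of images and save one JSON file per image (the directory name below is illustrative):

labelme ./raw_images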

  • Convert the JSON annotations to PNG label maps with the script below:
import glob
import io
import json
import os
import shutil

import numpy as np
import PIL.Image
import skimage.io
from labelme import utils


def label_colormap(N=256):

    def bitget(byteval, idx):
        return ((byteval & (1 << idx)) != 0)

    cmap = np.zeros((N, 3))
    for i in range(0, N):
        id = i
        r, g, b = 0, 0, 0
        for j in range(0, 8):
            r = np.bitwise_or(r, (bitget(id, 0) << 7 - j))
            g = np.bitwise_or(g, (bitget(id, 1) << 7 - j))
            b = np.bitwise_or(b, (bitget(id, 2) << 7 - j))
            id = (id >> 3)
        cmap[i, 0] = r
        cmap[i, 1] = g
        cmap[i, 2] = b
    cmap = cmap.astype(np.float32) / 255
    return cmap


# similar to skimage.color.label2rgb
def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
    if n_labels is None:
        n_labels = len(np.unique(lbl))

    cmap = label_colormap(n_labels)
    cmap = (cmap * 255).astype(np.uint8)

    lbl_viz = cmap[lbl]
    lbl_viz[lbl == -1] = (0, 0, 0)  # unlabeled

    if img is not None:
        img_gray = PIL.Image.fromarray(img).convert('LA')
        img_gray = np.asarray(img_gray.convert('RGB'))
        # img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
        lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
        lbl_viz = lbl_viz.astype(np.uint8)

    return lbl_viz


def draw_label(label, img=None, label_names=None, colormap=None):
    import matplotlib.pyplot as plt
    backend_org = plt.rcParams['backend']
    plt.switch_backend('agg')

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                        wspace=0, hspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())

    if label_names is None:
        label_names = [str(l) for l in range(label.max() + 1)]

    if colormap is None:
        colormap = label_colormap(len(label_names))

    label_viz = label2rgb(label, img, n_labels=len(label_names))
    plt.imshow(label_viz)
    plt.axis('off')

    plt_handlers = []
    plt_titles = []
    for label_value, label_name in enumerate(label_names):
        if label_value not in label:
            continue
        if label_name.startswith('_'):
            continue
        fc = colormap[label_value]
        p = plt.Rectangle((0, 0), 1, 1, fc=fc)
        plt_handlers.append(p)
        plt_titles.append('{value}: {name}'
                          .format(value=label_value, name=label_name))
    plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)

    f = io.BytesIO()
    plt.savefig(f, bbox_inches='tight', pad_inches=0)
    plt.cla()
    plt.close()

    plt.switch_backend(backend_org)

    out_size = (label_viz.shape[1], label_viz.shape[0])
    out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
    out = np.asarray(out)
    return out


def mkdir(path):
    if not os.path.exists(path):
        os.mkdir(path)



def DataGeneration(imgRoot, jsonRoot):
    # Sort both lists so images and JSON annotations pair up deterministically.
    imgList = sorted(glob.glob(f"{imgRoot}/*.png"))
    labelList = sorted(glob.glob(f"{jsonRoot}/*.json"))
    outRoot = "./data"
    mkdir(outRoot)
    outLblRoot = os.path.join(outRoot, "train_labels")
    outImgRoot = os.path.join(outRoot, "train_images")
    mkdir(outLblRoot)
    mkdir(outImgRoot)

    for labelPath, imgPath in zip(labelList, imgList):
        if os.path.isfile(labelPath):
            baseName = os.path.basename(labelPath).split(".")[0] + ".png"
            with open(labelPath) as f:
                data = json.load(f)
            # Decode the base64-embedded image and rasterize the polygon
            # shapes into an integer label map.
            # (labelme_shapes_to_label exists in older labelme releases.)
            img = utils.img_b64_to_arr(data['imageData'])
            lbl, lbl_names = utils.labelme_shapes_to_label(img.shape, data['shapes'])
            outName = os.path.join(outLblRoot, baseName)
            skimage.io.imsave(outName, np.array(lbl))
            shutil.copy(imgPath, os.path.join(outImgRoot, baseName))
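
A minimal invocation (paths are illustrative; the training entry point in step 2 also calls this function automatically):

if __name__ == "__main__":
    DataGeneration("./raw_images", "./raw_images")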

2. PyTorch Model Training

Building on the code below, implement your own dataset class, loss function, and related modules.
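
BIPEDDataset, the loss functions, Config, and the schedulers are user-supplied modules. As a reference, here is a minimal sketch of what the dataset class might look like, assuming the train_images/train_labels layout produced in step 1 (the resizing and normalization are illustrative, not the author's exact implementation):

import os
import glob
import cv2 as cv
import numpy as np
import torch
from torch.utils.data import Dataset


class BIPEDDataset(Dataset):
    """Pairs ./data/train_images/*.png with same-named masks in ./data/train_labels."""

    def __init__(self, img_root, crop_size=512):
        self.img_paths = sorted(glob.glob(f"{img_root}/*.png"))
        self.lbl_root = img_root.replace("train_images", "train_labels")
        self.crop_size = crop_size

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        name = os.path.basename(img_path)
        img = cv.imread(img_path)  # HxWx3, BGR
        lbl = cv.imread(os.path.join(self.lbl_root, name), cv.IMREAD_GRAYSCALE)
        # Fixed-size resize; a real implementation might random-crop instead.
        img = cv.resize(img, (self.crop_size, self.crop_size))
        lbl = cv.resize(lbl, (self.crop_size, self.crop_size),
                        interpolation=cv.INTER_NEAREST)
        image = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
        gt = torch.from_numpy(lbl.astype(np.int64))
        return {'image': image, 'gt': gt, 'file_name': os.path.splitext(name)[0]}

The training script itself: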

from torch import nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from dataset import BIPEDDataset
from loss import *
from config import Config
import segmentation_models_pytorch as smp
from cyclicLR import CyclicCosAnnealingLR, LearningRateWarmUP
import torchgeometry as tgm
import numpy as np
import time
import os
import cv2 as cv
import glob
from random import sample

from lookahead import Lookahead
import warnings
warnings.filterwarnings("ignore")


def weight_init(m):
    if isinstance(m, (nn.Conv2d,)):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        if m.weight.data.shape[1] == 1:  # single-input-channel convs
            torch.nn.init.normal_(m.weight, mean=0.0)
        if m.weight.data.shape == torch.Size([1, 6, 1, 1]):
            torch.nn.init.constant_(m.weight, 0.2)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    # for the fusion layer
    if isinstance(m, (nn.ConvTranspose2d,)):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        if m.weight.data.shape[1] == 1:
            torch.nn.init.normal_(m.weight, std=0.1)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


class Trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        #self.model = FishUnet(num_classes=2, in_channels=8, encoder_depth=34).to(self.device).apply(weight_init)
        # self.model = ExtremeC3Net(2).to(self.device)
        self.model = smp.Unet(encoder_name="resnet18",
                              in_channels=3,
                              classes=2).to(self.device)
        #self.model = HighResolutionNet(num_classes=3, in_chs=8).to(self.device)
        self.criterion_seg = WeightedFocalLoss2d()

        optimizer = torch.optim.AdamW([
                {'params': self.model.parameters()},
                # {'params': self.awl.parameters(), 'weight_decay': 0}
            ])
        self.optimizer = Lookahead(optimizer)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=20, verbose=True)
        self.scheduler = LearningRateWarmUP(optimizer=optimizer, target_iteration=10, target_lr=0.0005,
                                            after_scheduler=scheduler)
        mkdir(cfg.model_output)

    def load_net(self, resume):
        self.model = torch.load(resume,  map_location=self.device)
        print('load pre-trained model successfully')

    def build_loader(self):
        imglist = glob.glob(f'{self.cfg.train_root}/*')
        indices = list(range(len(imglist)))
        indices = sample(indices, len(indices))
        split = int(np.floor(0.2 * len(imglist)))
        # Hold out 20% of the shuffled indices for validation.
        train_idx, valid_idx = indices[split:], indices[:split]
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
        print(f'Total images {len(imglist)}')
        print(f'No of train images {len(train_idx)}')
        print(f'No of validation images {len(valid_idx)}')

        train_dataset = BIPEDDataset(self.cfg.train_root, crop_size=self.cfg.img_width)
        valid_dataset = BIPEDDataset(self.cfg.train_root, crop_size=self.cfg.img_width)

        train_loader = DataLoader(train_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=False,
                                  sampler=train_sampler,
                                  drop_last=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=False,
                                  sampler=valid_sampler,
                                  drop_last=True)
        return train_loader, valid_loader

    def validation(self, epoch, dataloader):
        self.model.eval()
        running_loss = []
        with torch.no_grad():
            for batch_id, sample_batched in enumerate(dataloader):
                images = sample_batched['image'].to(self.device)  # BxCxHxW
                labels_seg = sample_batched['gt'].to(self.device)  # BxHxW
                file_name = sample_batched['file_name']

                segments = self.model(images)
                loss = self.criterion_seg(segments, labels_seg)

                print(time.ctime(), 'validation, Epoch: {0} Sample {1}/{2} Loss: {3}'
                      .format(epoch, batch_id, len(dataloader), loss.item()), end='\r')

                self.save_image_batch_to_disk(segments, file_name)
                running_loss.append(loss.detach().item())
        return np.mean(np.array(running_loss))

    def save_image_batch_to_disk(self, tensor, file_names):
        output_dir = self.cfg.valid_output_dir
        mkdir(output_dir)
        assert len(tensor.shape) == 4, tensor.shape
        for tensor_image, file_name in zip(tensor, file_names):
            image_vis = tgm.utils.tensor_to_image(torch.sigmoid(tensor_image))[..., 1]
            image_vis = (255.0 * (1.0 - image_vis)).astype(np.uint8)  #
            output_file_name = os.path.join(output_dir, f"{file_name}.png")
            cv.imwrite(output_file_name, image_vis)

    def train(self):
        train_loader, valid_loader = self.build_loader()
        best_loss = 1000000
        valid_losses = []

        for epoch in range(1, self.cfg.num_epochs):
            self.model.train()
            running_loss = []  # reset each epoch so train_loss is a per-epoch mean
            for batch_id, sample_batched in enumerate(train_loader):

                images = sample_batched['image'].to(self.device)  # BxCxHxW
                labels_seg = sample_batched['gt'].to(self.device)  # BxHxW


                segments = self.model(images)

                loss_seg = self.criterion_seg(segments, labels_seg)

                loss = loss_seg

                self.optimizer.zero_grad()
                torch.autograd.backward([loss_seg])
                # loss.backward()
                self.optimizer.step()
                print(time.ctime(), 'training, Epoch: {0} Sample {1}/{2} Loss: {3}'\
                      .format(epoch, batch_id, len(train_loader), loss.item()), end='\r')
                running_loss.append(loss.detach().item())

            train_loss = np.mean(np.array(running_loss))

            valid_loss = self.validation(epoch, valid_loader)

            if epoch > 10:
                self.scheduler.after_scheduler.step(valid_loss)
            else:
                self.scheduler.step(epoch)

            lr = float(self.scheduler.after_scheduler.optimizer.param_groups[0]['lr'])

            if valid_loss < best_loss:
                torch.save(self.model, os.path.join(self.cfg.model_output, 'best.pth'))
                # modelList = glob.glob(os.path.join(self.cfg.model_output, f'epoch*_model.pth'))
                # if len(modelList) > 3:
                #     modelList = modelList[:-3]
                #     for modelPath in modelList:
                #         os.remove(modelPath)

                print(f'found optimal model, loss {best_loss} ==> {valid_loss} \n')
                best_loss = valid_loss

                # print(f'lr {lr:.8f} \n')
                valid_losses.append([valid_loss, lr])
                np.savetxt(os.path.join(self.cfg.model_output, 'valid_loss.txt'), valid_losses, fmt='%.6f')


        # plt.ioff()
        # plt.show()


if __name__ == '__main__':
    import argparse
    import genDataset
    parser = argparse.ArgumentParser(
        description='''This is a code for training model.''')
    parser.add_argument('--imageRoot', type=str, default=r'D:\BaiduNetdiskDownload\data_3groups\GF3_Yangzhitang_Samples_Feature_sub5', help='path to the root of image')
    parser.add_argument('--jsonRoot', type=str,
                        default=r'D:\BaiduNetdiskDownload\data_3groups\GF3_Yangzhitang_Samples_Feature_sub5',
                        help='path to the root of data')
    parser.add_argument('--in_chs', type=int, default=3, help='input channels')
    parser.add_argument('--num_classes', type=int, default=2, help='the number of class')
    args = parser.parse_args()

    print('The training dataset is preparing... Please wait!')
    genDataset.DataGeneration(args.imageRoot, args.jsonRoot)
    config = Config()
    config.in_chs = args.in_chs
    config.num_classes = args.num_classes

    print("Everything is ok! It's time for training.")
    trainer = Trainer(config)
    trainer.train()
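
With the annotated data in place, training can be launched from the command line (the script name train.py is an assumption; the flags match the argparse definition above):

python train.py --imageRoot ./raw_images --jsonRoot ./raw_images --in_chs 3 --num_classes 2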

3. Model Quantization with torch2trt

This step converts the model's 32-bit weights to 16-bit (FP16) for faster TensorRT inference.

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

from loguru import logger
import torch
import tensorrt as trt
from torch2trt import torch2trt



@logger.catch
def main():
    model = torch.load("./checkpoints/best.pth", map_location='cpu')
    # ckpt = torch.load("resnet18-5c106cde.pth", map_location="cpu")
    # model.load_state_dict(ckpt)

    logger.info("loaded checkpoint done.")
    model.eval()
    model.cuda()
    x = torch.ones(1, 3, 1024, 1792).cuda()

    print("Torch2TensorRT")
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32)
    )
    #torch.save(model_trt.state_dict(), "model_trt.pth")
    logger.info("Converted TensorRT model done.")

    y = model(x)
    y_trt = model_trt(x)

    # check the output against PyTorch
    print(f"difference: {torch.max(torch.abs(y - y_trt))}")

    engine_file = "./checkpoints/model_trt.engine"

    print("generate engine file")
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())

    logger.info("Converted TensorRT model engine file is saved for C++ inference.")


if __name__ == "__main__":
    main()

4. C++ Model Inference

The C++ inference code is adapted from https://github.com/wang-xinyu/tensorrtx/tree/master/unet. Before porting, the serialized engine can be smoke-tested from Python, as sketched below.
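
A rough sketch of that smoke test using pycuda (listed in the software environment), assuming the 1x3x1024x1792 input from step 3, a two-class output, and the pre-8.5 TensorRT bindings API:

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context)
import pycuda.driver as cuda
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
with open("./checkpoints/model_trt.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# One input binding (1x3x1024x1792) and one output binding (1x2x1024x1792).
h_input = np.random.rand(1, 3, 1024, 1792).astype(np.float32)
h_output = np.empty((1, 2, 1024, 1792), dtype=np.float32)
d_input = cuda.mem_alloc(h_input.nbytes)
d_output = cuda.mem_alloc(h_output.nbytes)

cuda.memcpy_htod(d_input, h_input)
context.execute_v2(bindings=[int(d_input), int(d_output)])
cuda.memcpy_dtoh(h_output, d_output)
print("output:", h_output.shape, h_output.min(), h_output.max())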

5. Computing the Cell Count and Area Ratio

  • Count the cells: run cv::Canny on the mask and call cv::findContours on the edges.
cv::Mat Can_img;
cv::Canny(outimg * 255, Can_img, 100, 250);
vector<vector<cv::Point>> contours;
vector<cv::Vec4i> hierarchy;
cv::findContours(Can_img, contours, hierarchy, cv::RETR_TREE, cv::CHAIN_APPROX_SIMPLE, cv::Point());
int count_cell = contours.size();
  • Area ratio: for each pixel, compare the two class scores in prob and count the pixels where the cell class wins.
cv::Mat outimg(INPUT_H, INPUT_W, CV_8UC1);
int count_pixel = 0;
for (int row = 0; row < INPUT_H; ++row)
{
	uchar *uc_pixel = outimg.data + row * outimg.step;
	for (int col = 0; col < INPUT_W; ++col)
	{
		// prob stores the two class score planes back to back:
		// the first INPUT_H*INPUT_W values are class 0 (background),
		// the next INPUT_H*INPUT_W values are class 1 (cell).
		if (prob[row * INPUT_W + col] > prob[INPUT_H * INPUT_W + row * INPUT_W + col])
		{
			uc_pixel[col] = 0;   // background score wins
		}
		else
		{
			uc_pixel[col] = 1;   // cell score wins
			count_pixel++;
		}
	}
}
int percent = (count_pixel * 100) / (INPUT_W * INPUT_H);
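
The same two statistics are easy to prototype in Python with OpenCV before porting to C++. A sketch, assuming mask is the HxW uint8 0/1 map produced by the comparison above (OpenCV 4 findContours signature):

import cv2 as cv
import numpy as np

def cell_stats(mask: np.ndarray):
    # mask: HxW uint8 map, 1 for cell pixels, 0 for background
    edges = cv.Canny(mask * 255, 100, 250)
    contours, _ = cv.findContours(edges, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
    count_cell = len(contours)                    # number of detected contours
    percent = int(mask.sum()) * 100 // mask.size  # foreground area in percent
    return count_cell, percent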

6. Building the C++ Dynamic Library

Header file

#pragma once
#include <string>

// 1. Define an abstract interface class
class SEG
{
public:
	virtual int GetCount() = 0;
	virtual int GetPercent() = 0;
	virtual int Prediction() = 0;
};


// 2. Declare an exported factory function
extern "C" __declspec(dllexport) SEG* GetSeg(std::string engine_file_path, std::string input_video_path);

7. Calling the Dynamic Library

Root directory

lib directory: copy the dynamic library built in the previous section into this directory.

CMakeLists.txt

cmake_minimum_required(VERSION 3.2)
project(seglib)

add_definitions(-std=c++11)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)

# Directories containing the libraries to link against
# (link_directories must come before the target is created)
link_directories(
   D:/2022/3/medicalSeg/code/dll_use/lib
)
link_directories(${TRT_DIR}\\lib)

# Build the executable
add_executable(seglib main.cpp)

# Link the DLL's import library into the executable
target_link_libraries(seglib seg_lib)
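
A typical out-of-source build with the VS2017 toolchain listed above (the standard CMake workflow; adjust paths as needed):

mkdir build && cd build
cmake .. -G "Visual Studio 15 2017 Win64"
cmake --build . --config Debug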

main.cpp

#include <iostream>
#include <thread>
#include "seg_lib.h"
#pragma comment(lib, "seg_lib.lib")

int main(int argc, char** argv) {
	const std::string engine_file_path = argv[1];
	const std::string input_video_path = argv[2];

	SEG* MySeg = GetSeg(engine_file_path, input_video_path);
	MySeg->GetCount();
	MySeg->GetPercent();
	MySeg->Prediction();
	MySeg->GetCount();
	MySeg->GetPercent();

	std::cout << "done....." << std::endl;
	return 0;
}
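
The program takes the engine file and a video path on the command line (file names are illustrative):

seglib.exe ./checkpoints/model_trt.engine ./cells.mp4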

Run result

Source code download

C++ source code

Reposted from blog.csdn.net/weixin_42990464/article/details/123550328