ModNet マッティング アルゴリズムとカメラのリアルタイム マッティングの例

目次

1. ビデオマットにグリーンスクリーンを使用する理由

1. カメラの色の理由

2. カットアウト効果の理由

3. 経済的コスト

2. カットアウトの背景知識

1、トライマップ

2. カットアウトとは

3. マッティングアルゴリズムの分類

3. ディープイメージマッティングアルゴリズム

1. ネットワーク構成図

2. アルゴリズムの解釈

(1) エンコーダ・デコーダ段

(2) 精製段階

4. ModNet アルゴリズム: リアルタイムでの Trimap-Free ポートレート マッティング

1. ネットワーク構成図

2. アルゴリズムの解釈

5、ModNetマット練習


1. ビデオマットにグリーンスクリーンを使用する理由

1. カメラの色の理由

主流のカメラ センサーは RGB 3 チャンネルであるため、最も正確なマット化を実現するには、3 原色の元の色を使用するのが最適です。また、カメラのCMOSセンサーマトリックスのほとんどはベイヤー配列を採用しており、配列内に緑の感光点が2つあり、赤と青よりも高い位置にあるため、情報が豊富で除去しやすくなっています。

2. カットアウト効果の理由

動画内のキャラクターや肌のほとんどはコントラストの高い緑の補色で構成されているため、レンダリング処理時にコンピュータがエッジや毛の質感を識別しやすくなり、マット処理の負荷が軽減されます。

3. 経済的コスト

背景の緑色は輝度が高く、撮影時は輝度を下げて省電力にできます。

2. カットアウトの背景知識

ポートレートマット化: アルゴリズムの概要とエンジニアリングの実装 (1) - クラウド コミュニティ - HUAWEI CLOUD

1、トライマップ

最も一般的に使用される事前知識は 3 値マップであり、各ピクセルは {0, 128, 255} のいずれかであり、それぞれ前景、未知、背景を表します。

2. カットアウトとは

写真 I の場合、肖像画の関心のある部分は前景 F と呼ばれ、残りは背景 B と呼ばれます。その場合、画像 I は F と B の加重融合とみなすことができます。 I= alpha ∗
F + ( 1−alpha) B alpha形状I一致します。

マッティング タスクは、適切な重みアルファ マトリックスを見つけることです。

上記の式に従って前景画像と背景画像を融合するプロセスは、次のように例示されます。

写真の中央の円の部分が前景、残りが背景であるとします。上の 2 つの画像を式に従って結合すると、中央の円はすべて前景関連のピクセルとなり、円の外側はすべて背景関連のピクセルになります。アルファは前景画像の確率行列に対応します。

アルファトレーニングが完了したら、写真の切り抜きを完成させたい場合は、アルファ*元の写真+ (1-アルファ)*白い背景の写真だけが必要です。

アルファは [0, 1] の間の連続値であり、ポートレート セグメンテーションとは異なり、ピクセルが前景に属する確率として理解できます。ポートレート セグメンテーション タスクでは、アルファは 0 または 1 のみであり、これは本質的に分類タスクですが、マッティングは回帰タスクです。

画像切り取りタスクのグラウンド トゥルースでは、値が 0 と 1 の間に分布していることがわかります。

セマンティック セグメンテーションのグランド トゥルースでは、値が 0 または 1 のいずれかであることがわかります。

3. マッティングアルゴリズムの分類

現在普及しているマッティングアルゴリズムは大きく2つに分類できます。

1 つは事前情報を必要とする Trimap ベースの方法であり、広範な事前情報には Trimap、ラフマスク、無人背景画像、ポーズ情報などが含まれます。ネットワークは事前情報と画像情報を使用して共同でアルファを予測します

もう 1 つは、画像情報のみに基づいてアルファを予測する Trimap フリーの方法です。これは実用的なアプリケーションに適していますが、効果は一般に Trimap ベースの方法ほど良くありません。

現在の主流はトライマップフリーアルゴリズムです。

3. ディープイメージマッティングアルゴリズム

1. ネットワーク構成図

2. アルゴリズムの解釈

ネットワークには、エンコーダ/デコーダ ステージとリファインメント ステージが含まれます。

(1) エンコーダ・デコーダ段

入力は RGB イメージのパッチとトライマップに対応する concat であるため、4 チャネルが含まれており、エンコードおよびデコード後にシングル チャネルの生のアルファ pred が出力されます。この段階での損失は 2 つの部分で構成されます。

最初の部分は、予測されたアルファと実際のアルファの間の絶対誤差です。L1 損失が 0 で微分できないことを考慮して、シャルボニエ損失を使用して近似します。

2 番目の部分は、予測されたアルファ、実際の前景および実際の背景で構成される RGB イメージと、実際の RGB イメージ間の絶対誤差です。その機能は、ネットワークに制約を課すことであり、シャルボニエ損失も次の近似に使用されます。

最終的な損失は、次の 2 つの部分の加重合計です。

(2) 精製段階

その入力は、Encoder-Decoder ステージによって出力された生の alpha pred と元の RGB 画像 (これも 4 チャネル) の連結であり、元の RGB は調整のための境界詳細情報を提供できます。ポイントは、スキップ接続を使用して、Encoder-Decoder ステージによって出力された生の alpha pred と、Refinement ステージによって出力された洗練された alpha pred に対して加算演算を実行し、最終的な予測結果を出力することです。実際、リファインメント ステージは残差ブロックであり、境界情報は残差学習を通じてモデル化されます。これは、ノイズ除去モデルのノイズ モデリングとまったく同じです。

洗練段階では損失が 1 つだけあります。洗練されたアルファ pred と GT アルファ マットはシャルボニエ損失を計算します。

4. ModNet アルゴリズム: リアルタイムでの Trimap-Free ポートレート マッティング

1. ネットワーク構成図

2. アルゴリズムの解釈

ネットワーク構造は、意味推定ブランチ、詳細予測ブランチ、および意味-詳細融合ブランチで構成されます。

5、ModNetマット練習

参考記事:

[マット化] MODNet: リアルタイム ポートレートマット化モデル-onnx Python 導入_onnx モデルのダウンロード_Dudu Taicai ブログ-CSDN ブログ

オリジナル作成者の onnix モデルのリンク: https://download.csdn.net/download/qq_40035462/85046509

コード例:

import cv2
import time
from tqdm import tqdm
import numpy as np
import onnxruntime as rt


class Matting:
    def __init__(self, model_path='onnx_model\modnet.onnx', input_size=(512, 512)):
        self.model_path = model_path
        self.sess = rt.InferenceSession(self.model_path, providers=['CUDAExecutionProvider'])
        # self.sess = rt.InferenceSession(self.model_path)  # 默认使用cpu
        self.input_name = self.sess.get_inputs()[0].name
        self.label_name = self.sess.get_outputs()[0].name
        self.input_size = input_size
        self.txt_font = cv2.FONT_HERSHEY_PLAIN

    def normalize(self, im, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
        im = im.astype(np.float32, copy=False) / 255.0
        im -= mean
        im /= std
        return im

    def resize(self, im, target_size=608, interp=cv2.INTER_LINEAR):
        if isinstance(target_size, list) or isinstance(target_size, tuple):
            w = target_size[0]
            h = target_size[1]
        else:
            w = target_size
            h = target_size
        im = cv2.resize(im, (w, h), interpolation=interp)
        return im

    def preprocess(self, image, target_size=(512, 512), interp=cv2.INTER_LINEAR):
        image = self.normalize(image)
        image = self.resize(image, target_size=target_size, interp=interp)
        image = np.transpose(image, [2, 0, 1])
        image = image[None, :, :, :]
        return image

    def predict_frame(self, bgr_image):
        assert len(bgr_image.shape) == 3, "Please input RGB image."
        raw_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        h, w, c = raw_image.shape
        image = self.preprocess(raw_image, target_size=self.input_size)

        pred = self.sess.run(
            [self.label_name],
            {self.input_name: image.astype(np.float32)}
        )[0]
        pred = pred[0, 0]
        matte_np = self.resize(pred, target_size=(w, h), interp=cv2.INTER_NEAREST)
        matte_np = np.expand_dims(matte_np, axis=-1)
        return matte_np

    def predict_image(self, source_image_path, save_image_path):
        bgr_image = cv2.imread(source_image_path)
        assert len(bgr_image.shape) == 3, "Please input RGB image."
        matte_np = self.predict_frame(bgr_image)
        matting_frame = matte_np * bgr_image + (1 - matte_np) * np.full(bgr_image.shape, 255.0)
        matting_frame = matting_frame.astype('uint8')
        cv2.imwrite(save_image_path, matting_frame)

    def predict_camera(self):
        cap_video = cv2.VideoCapture(0)
        if not cap_video.isOpened():
            raise IOError("Error opening video stream or file.")
        beg = time.time()
        count = 0
        while cap_video.isOpened():
            ret, raw_frame = cap_video.read()
            if ret:
                count += 1
                matte_np = self.predict_frame(raw_frame)
                matting_frame = matte_np * raw_frame + (1 - matte_np) * np.full(raw_frame.shape, 255.0)
                matting_frame = matting_frame.astype('uint8')

                end = time.time()
                fps = round(count / (end - beg), 2)
                if count >= 50:
                    count = 0
                    beg = end

                cv2.putText(matting_frame, "fps: " + str(fps), (20, 20), self.txt_font, 2, (0, 0, 255), 1)

                cv2.imshow('Matting', matting_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            else:
                break
        cap_video.release()
        cv2.destroyWindow()

    def check_video(self, src_path, dst_path):
        cap1 = cv2.VideoCapture(src_path)
        fps1 = int(cap1.get(cv2.CAP_PROP_FPS))
        number_frames1 = cap1.get(cv2.CAP_PROP_FRAME_COUNT)
        cap2 = cv2.VideoCapture(dst_path)
        fps2 = int(cap2.get(cv2.CAP_PROP_FPS))
        number_frames2 = cap2.get(cv2.CAP_PROP_FRAME_COUNT)
        assert fps1 == fps2 and number_frames1 == number_frames2, "fps or number of frames not equal."

    def predict_video(self, video_path, save_path, threshold=2e-7):
        # 使用odf策略
        time_beg = time.time()
        pre_t2 = None  # 前2步matte
        pre_t1 = None  # 前1步matte

        cap = cv2.VideoCapture(video_path)
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        number_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        print("source video fps: {}, video resolution: {}, video frames: {}".format(fps, size, number_frames))
        videoWriter = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc('I', '4', '2', '0'), fps, size)

        ret, frame = cap.read()
        with tqdm(range(int(number_frames))) as t:
            for c in t:
                matte_np = self.predict_frame(frame)
                if pre_t2 is None:
                    pre_t2 = matte_np
                elif pre_t1 is None:
                    pre_t1 = matte_np
                    # 第一帧写入
                    matting_frame = pre_t2 * frame + (1 - pre_t2) * np.full(frame.shape, 255.0)
                    videoWriter.write(matting_frame.astype('uint8'))
                else:
                    # odf
                    error_interval = np.mean(np.abs(pre_t2 - matte_np))
                    error_neigh = np.mean(np.abs(pre_t1 - pre_t2))
                    if error_interval < threshold < error_neigh:
                        pre_t1 = pre_t2

                    matting_frame = pre_t1 * frame + (1 - pre_t1) * np.full(frame.shape, 255.0)
                    videoWriter.write(matting_frame.astype('uint8'))
                    pre_t2 = pre_t1
                    pre_t1 = matte_np

                ret, frame = cap.read()
            # 最后一帧写入
            matting_frame = pre_t1 * frame + (1 - pre_t1) * np.full(frame.shape, 255.0)
            videoWriter.write(matting_frame.astype('uint8'))
            cap.release()
        print("video matting over, time consume: {}, fps: {}".format(time.time() - time_beg, number_frames / (time.time() - time_beg)))


if __name__ == '__main__':
    model = Matting(model_path='onnx_model\modnet.onnx', input_size=(512, 512))
    model.predict_camera()
    # model.predict_image('images\\1.jpeg', 'output\\1.png')
    # model.predict_image('images\\2.jpeg', 'output\\2.png')
    # model.predict_image('images\\3.jpeg', 'output\\3.png')
    # model.predict_image('images\\4.jpeg', 'output\\4.png')
    # model.predict_video("video\dance.avi", "output\dance_matting.avi")

コードに含まれる modnet.onnx ファイルについては、上部の添付ファイルを参照してください。 

おすすめ

転載: blog.csdn.net/benben044/article/details/131136506