Segment-and-Track Anything: a general video segmentation, tracking, and editing algorithm, with an algorithm walkthrough and source-code deployment

1. Segment Anything

With Meta's release of the Segment Anything Model (SAM) paper and the open-sourcing of the accompanying code, SAM can be seen as playing a role in vision similar to the one GPT-4 plays in language. The paper's goal is to segment everything in a zero-shot manner, and it brings the prompting paradigm of natural language processing (NLP) into computer vision (CV), opening up broader support and deeper research opportunities for CV foundation models.
Segment Anything differs from traditional image segmentation in two major ways:

1. Data collection and active learning methods.

For a huge dataset, for example one containing on the order of a billion annotations, labeling everything by hand is practically infeasible. One solution is to adopt an active-learning approach, which can be broken into the following steps:
Preliminary annotation: first, a portion of the dataset is annotated manually. This can be a small sample, but it should cover a variety of scenes and categories so that the model sees enough diversity.
Semi-supervised learning: next, an initial model is trained on the labeled data and used to predict labels for the unlabeled data.
Manual verification and correction: the labels predicted by the model are verified and corrected manually, by professionals or through crowdsourcing, to ensure their accuracy.
Iterative loop: the steps above are repeated to gradually grow the labeled set. Each iteration improves the model, because it can be trained on more data.
In this way the annotation quality of the dataset can be improved step by step without manually labeling all of the data, and once the dataset is large enough and the model has been trained sufficiently, its performance improves significantly.
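As a rough illustration of this loop, the sketch below shows how such an annotate-train-verify cycle could be organized; the dataset, model, and review helpers are hypothetical placeholders, not part of SAM's released code.

# Hypothetical sketch of the iterative annotation loop described above.
# `label_manually`, `review_and_correct`, and `model` are placeholders, not SAM APIs.

def build_dataset(unlabeled_images, model, seed_size=1000, rounds=5):
    labeled = label_manually(unlabeled_images[:seed_size])    # 1. small manually annotated seed set
    pool = unlabeled_images[seed_size:]

    for _ in range(rounds):
        model.train(labeled)                                   # 2. train on the current labels
        batch, pool = pool[:seed_size], pool[seed_size:]
        predictions = [model.predict(img) for img in batch]    # 3. pseudo-label a new batch
        corrected = review_and_correct(batch, predictions)     # 4. human verification and correction
        labeled.extend(corrected)                              # 5. grow the labeled set and repeat

    return labeled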

2. Prompts

Segment Anything introduces the concept of a prompt: a piece of user input that guides the model toward a specific kind of output. Prompts are just as useful here as in models such as GPT-3; the user provides a question or description that helps the model understand the intent and produce a relevant answer or action.
For example, in SAM you can provide a prompt such as "cat" or "dog" to tell the model that you want the cats or dogs in a photo segmented; the model then detects the objects and produces the corresponding masks. Prompts of this kind restrict the model's task so that it focuses on extracting or operating on specific information.
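In the code Meta actually released, prompting is exposed through point and box inputs via the SamPredictor class (text prompts are only explored in the paper). A minimal point-prompt sketch, assuming a local ViT-B checkpoint and an example image path:

# Minimal point-prompt sketch with the open-source segment-anything package.
# The checkpoint path and image path are example values.
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_b"](checkpoint="./ckpt/sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam.to("cuda"))

image = cv2.cvtColor(cv2.imread("cat.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# A single foreground click (label 1) on the object serves as the prompt.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=True,   # return several candidate masks and pick the best one
)
best_mask = masks[np.argmax(scores)]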
Both of these are important pieces of engineering for handling large-scale data and improving model performance. With a sensible data-collection and active-learning strategy, and with prompts that guide the model, user needs can be met more directly and the model's results can be improved step by step.

2. Segment-and-Track Anything

1. Introduction to the algorithm

The emergence of SAM unified many applications of the segmentation task and showed that large-scale models may be just as viable in CV as they are in NLP. This breakthrough will certainly change research in the field, with many tasks handled in a uniform way; the new dataset and paradigm, combined with strong zero-shot generalization, will have a profound impact on CV. What SAM lacks, however, is support for video data. To address this, researchers from the ReLER Laboratory at Zhejiang University unlocked SAM's video segmentation capability in the open-source SAM-Track project, i.e. Segment-and-Track Anything.
SAM-Track supports target segmentation and tracking across a wide range of spatio-temporal scenes on a single GPU, including street scenes, AR, cells, animation, and aerial footage, and it can track more than 200 objects at the same time, giving users powerful video-editing capabilities. Segment-and-Track Anything combines automatic and interactive methods. Its main components are SAM (Segment Anything Model) for automatic/interactive keyframe segmentation and DeAOT (Decoupling features in Associating Objects with Transformers, NeurIPS 2022) for efficient multi-object tracking and propagation. In the SAM-Track pipeline, SAM dynamically detects and segments newly appearing objects, while DeAOT is responsible for tracking all recognized objects.
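In code, this division of labor amounts to running SAM on keyframes and DeAOT on every other frame. The sketch below is a simplified version of that loop, written against the SegTracker class listed later in this article; the video path is a placeholder and the project's own demo code differs in detail.

# Simplified sketch of the SAM-Track pipeline, using the SegTracker class listed later in this article.
import cv2
from SegTracker import SegTracker
from model_args import segtracker_args, sam_args, aot_args

segtracker = SegTracker(segtracker_args, sam_args, aot_args)

cap = cv2.VideoCapture("input.mp4")   # placeholder video path
frame_idx = 0
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    if frame_idx == 0:
        pred_mask = segtracker.seg(frame)              # SAM segments everything in the first frame
        segtracker.add_reference(frame, pred_mask)     # DeAOT stores it as the tracking reference
    elif frame_idx % segtracker.sam_gap == 0:
        seg_mask = segtracker.seg(frame)               # keyframe: re-run SAM
        track_mask = segtracker.track(frame)           # DeAOT tracks the already-known objects
        new_obj_mask = segtracker.find_new_objs(track_mask, seg_mask)
        pred_mask = track_mask + new_obj_mask          # merge in newly appearing objects
        segtracker.add_reference(frame, pred_mask)
    else:
        pred_mask = segtracker.track(frame, update_memory=True)  # non-keyframe: DeAOT only

    frame_idx += 1
cap.release()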

2. Project characteristics

Automatic/interactive segmentation: the SAM (Segment Anything Model) component of the project provides automatic and interactive keyframe segmentation. Users can rely on fully automatic segmentation or interact with the algorithm to segment any object in the video precisely. This flexibility makes the project suitable for application scenarios with different needs.

Efficient multi-object tracking: the project also introduces the DeAOT algorithm for efficient multi-object tracking and propagation. DeAOT accurately tracks multiple objects in a video and supports propagating and associating masks across frames, so the project performs well in complex scenes and multi-target tracking tasks.

Independent and open: this is a standalone open-source project that can be used directly. It provides documentation and sample code to help users get started quickly and build custom extensions, and it welcomes contributions from the community, so users can share experience and results with other researchers and developers.

Broad applicability: the segmentation and tracking functions can be applied to many video-analysis tasks, including video surveillance, intelligent transportation, and behavior analysis, giving researchers and developers a powerful tool for processing and analyzing video with complex dynamic scenes.

3. Project deployment

Project address: https://github.com/z-x-yang/Segment-and-Track-Anything

1. Deployment environment

The system used for testing and deployment here is Windows 10 with CUDA 11.8 and cuDNN 8.5; the GPU is an RTX 3060 with 8 GB of VRAM, and conda is used to create a virtual environment.
Create and activate a virtual environment:

conda create -n sta python=3.10
conda activate sta

Download project:

git clone https://github.com/z-x-yang/Segment-and-Track-Anything.git
cd Segment-and-Track-Anything
pip install gradio
pip install scikit-image

Because the GPU is needed, install PyTorch separately here:

conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia
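Before going further you can quickly confirm that the CUDA build was installed correctly (the versions below are the ones expected for this setup):

# Quick sanity check that the GPU build of PyTorch is active (run inside the sta environment).
import torch

print(torch.__version__)          # expected: 2.0.0
print(torch.version.cuda)         # expected: 11.8
print(torch.cuda.is_available())  # expected: True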

Because the project's dependencies are installed via an sh script and Windows does not come with bash, m2-base (which provides bash through MSYS2) needs to be installed separately:

conda install m2-base

Install project dependencies:

bash script/install.sh

When the script finishes and prints its success message, the installation has completed.
GroundingDINO may fail to install through this script. If so, it can be installed directly from source:

git clone https://github.com/IDEA-Research/GroundingDINO.git
cd GroundingDINO/
pip install -e .
cd ..

Download the required models:

bash script/download_ckpt.sh

If the automatic download fails, you can open script/download_ckpt.sh, copy the download links from it, and manually download the models into the directories it specifies.
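For instance, the SAM ViT-B checkpoint can be fetched by hand as sketched below; the URL is the one published by Meta for segment-anything, and the remaining checkpoint links (DeAOT, GroundingDINO) should be taken from script/download_ckpt.sh.

# Manually download the SAM ViT-B checkpoint into ./ckpt, the directory the project expects.
# Links for the other checkpoints can be copied from script/download_ckpt.sh.
import os
import urllib.request

os.makedirs("ckpt", exist_ok=True)
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
urllib.request.urlretrieve(url, os.path.join("ckpt", "sam_vit_b_01ec64.pth"))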

2. Run the project

python app.py

Then open http://127.0.0.1:7860/ in a browser, import a video, and choose to track only one of the people in it.

Demo results (videos in the original post): video target tracking, and target segmentation combined with target tracking.

3. Segmentation and tracking code

import sys
sys.path.append("..")
sys.path.append("./sam")
from sam.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from aot_tracker import get_aot
import numpy as np
from tool.segmentor import Segmentor
from tool.detector import Detector
from tool.transfer_tools import draw_outline, draw_points
import cv2
from seg_track_anything import draw_mask


class SegTracker():
    def __init__(self,segtracker_args, sam_args, aot_args) -> None:
        """
         Initialize SAM and AOT.
        """
        self.sam = Segmentor(sam_args)
        self.tracker = get_aot(aot_args)
        self.detector = Detector(self.sam.device)
        self.sam_gap = segtracker_args['sam_gap']
        self.min_area = segtracker_args['min_area']
        self.max_obj_num = segtracker_args['max_obj_num']
        self.min_new_obj_iou = segtracker_args['min_new_obj_iou']
        self.reference_objs_list = []
        self.object_idx = 1
        self.curr_idx = 1
        self.origin_merged_mask = None  # init by segment-everything or update
        self.first_frame_mask = None

        # debug
        self.everything_points = []
        self.everything_labels = []
        print("SegTracker has been initialized")

    def seg(self,frame):
        '''
        Arguments:
            frame: numpy array (h,w,3)
        Return:
            origin_merged_mask: numpy array (h,w)
        '''
        # reverse the channel order before running SAM's automatic mask generator
        frame = frame[:, :, ::-1]
        anns = self.sam.everything_generator.generate(frame)

        # anns is a list recording all predictions in an image
        if len(anns) == 0:
            return
        # merge all predictions into one mask (h,w)
        # note that the merged mask may lose some objects due to overlapping
        self.origin_merged_mask = np.zeros(anns[0]['segmentation'].shape,dtype=np.uint8)
        idx = 1
        for ann in anns:
            if ann['area'] > self.min_area:
                m = ann['segmentation']
                self.origin_merged_mask[m==1] = idx
                idx += 1
                self.everything_points.append(ann["point_coords"][0])
                self.everything_labels.append(1)

        obj_ids = np.unique(self.origin_merged_mask)
        obj_ids = obj_ids[obj_ids!=0]

        self.object_idx = 1
        for id in obj_ids:
            if np.sum(self.origin_merged_mask==id) < self.min_area or self.object_idx > self.max_obj_num:
                self.origin_merged_mask[self.origin_merged_mask==id] = 0
            else:
                self.origin_merged_mask[self.origin_merged_mask==id] = self.object_idx
                self.object_idx += 1

        self.first_frame_mask = self.origin_merged_mask
        return self.origin_merged_mask

    def update_origin_merged_mask(self, updated_merged_mask):
        self.origin_merged_mask = updated_merged_mask
        # obj_ids = np.unique(updated_merged_mask)
        # obj_ids = obj_ids[obj_ids!=0]
        # self.object_idx = int(max(obj_ids)) + 1

    def reset_origin_merged_mask(self, mask, id):
        self.origin_merged_mask = mask
        self.curr_idx = id

    def add_reference(self,frame,mask,frame_step=0):
        '''
        Add objects in a mask for tracking.
        Arguments:
            frame: numpy array (h,w,3)
            mask: numpy array (h,w)
        '''
        self.reference_objs_list.append(np.unique(mask))
        self.curr_idx = self.get_obj_num()
        self.tracker.add_reference_frame(frame,mask, self.curr_idx, frame_step)

    def track(self,frame,update_memory=False):
        '''
        Track all known objects.
        Arguments:
            frame: numpy array (h,w,3)
        Return:
            origin_merged_mask: numpy array (h,w)
        '''
        pred_mask = self.tracker.track(frame)
        if update_memory:
            self.tracker.update_memory(pred_mask)
        return pred_mask.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.uint8)
    
    def get_tracking_objs(self):
        objs = set()
        for ref in self.reference_objs_list:
            objs.update(set(ref))
        objs = list(sorted(list(objs)))
        objs = [i for i in objs if i!=0]
        return objs
    
    def get_obj_num(self):
        objs = self.get_tracking_objs()
        if len(objs) == 0: return 0
        return int(max(objs))

    def find_new_objs(self, track_mask, seg_mask):
        '''
        Compare tracked results from AOT with segmented results from SAM. Select objects from background if they are not tracked.
        Arguments:
            track_mask: numpy array (h,w)
            seg_mask: numpy array (h,w)
        Return:
            new_obj_mask: numpy array (h,w)
        '''
        new_obj_mask = (track_mask==0) * seg_mask
        new_obj_ids = np.unique(new_obj_mask)
        new_obj_ids = new_obj_ids[new_obj_ids!=0]
        # obj_num = self.get_obj_num() + 1
        obj_num = self.curr_idx
        for idx in new_obj_ids:
            new_obj_area = np.sum(new_obj_mask==idx)
            obj_area = np.sum(seg_mask==idx)
            if new_obj_area/obj_area < self.min_new_obj_iou or new_obj_area < self.min_area\
                or obj_num > self.max_obj_num:
                new_obj_mask[new_obj_mask==idx] = 0
            else:
                new_obj_mask[new_obj_mask==idx] = obj_num
                obj_num += 1
        return new_obj_mask
        
    def restart_tracker(self):
        self.tracker.restart()

    def seg_acc_bbox(self, origin_frame: np.ndarray, bbox: np.ndarray,):
        '''
        Use bbox-prompt to get mask
        Parameters:
            origin_frame: H, W, C
            bbox: [[x0, y0], [x1, y1]]
        Return:
            refined_merged_mask: numpy array (h, w)
            masked_frame: numpy array (h, w, c)
        '''
        # get interactive_mask
        interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
        refined_merged_mask = self.add_mask(interactive_mask)

        # draw mask
        masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)

        # draw bbox
        masked_frame = cv2.rectangle(masked_frame, bbox[0], bbox[1], (0, 0, 255))

        return refined_merged_mask, masked_frame

    def seg_acc_click(self, origin_frame: np.ndarray, coords: np.ndarray, modes: np.ndarray, multimask=True):
        '''
        Use point-prompt to get mask
        Parameters:
            origin_frame: H, W, C
            coords: nd.array [[x, y]]
            modes: nd.array [[1]]
        Return:
            refined_merged_mask: numpy array (h, w)
            masked_frame: numpy array (h, w, c)
        '''
        # get interactive_mask
        interactive_mask = self.sam.segment_with_click(origin_frame, coords, modes, multimask)

        refined_merged_mask = self.add_mask(interactive_mask)

        # draw mask
        masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)

        # draw points
        # self.everything_labels = np.array(self.everything_labels).astype(np.int64)
        # self.everything_points = np.array(self.everything_points).astype(np.int64)

        masked_frame = draw_points(coords, modes, masked_frame)

        # draw outline
        masked_frame = draw_outline(interactive_mask, masked_frame)

        return refined_merged_mask, masked_frame

    def add_mask(self, interactive_mask: np.ndarray):
        '''
        Merge interactive mask with self.origin_merged_mask
        Parameters:
            interactive_mask: numpy array (h, w)
        Return:
            refined_merged_mask: numpy array (h, w)
        '''
        if self.origin_merged_mask is None:
            self.origin_merged_mask = np.zeros(interactive_mask.shape,dtype=np.uint8)

        refined_merged_mask = self.origin_merged_mask.copy()
        refined_merged_mask[interactive_mask > 0] = self.curr_idx

        return refined_merged_mask
    
    def detect_and_seg(self, origin_frame: np.ndarray, grounding_caption, box_threshold, text_threshold, box_size_threshold=1, reset_image=False):
        '''
        Use Grounding-DINO to detect objects according to a text prompt
        Return:
            refined_merged_mask: numpy array (h, w)
            annotated_frame: numpy array (h, w, 3)
        '''
        # backup id and origin-merged-mask
        bc_id = self.curr_idx
        bc_mask = self.origin_merged_mask

        # get annotated_frame and boxes
        annotated_frame, boxes = self.detector.run_grounding(origin_frame, grounding_caption, box_threshold, text_threshold)
        for i in range(len(boxes)):
            bbox = boxes[i]
            if (bbox[1][0] - bbox[0][0]) * (bbox[1][1] - bbox[0][1]) > annotated_frame.shape[0] * annotated_frame.shape[1] * box_size_threshold:
                continue
            interactive_mask = self.sam.segment_with_box(origin_frame, bbox, reset_image)[0]
            refined_merged_mask = self.add_mask(interactive_mask)
            self.update_origin_merged_mask(refined_merged_mask)
            self.curr_idx += 1

        # reset origin_mask
        self.reset_origin_merged_mask(bc_mask, bc_id)

        return refined_merged_mask, annotated_frame

if __name__ == '__main__':
    from model_args import segtracker_args,sam_args,aot_args

    Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)
    
    # ------------------ detect test ----------------------
    
    # path to a local test image from the original author's machine; adjust to your own environment
    origin_frame = cv2.imread('/data2/cym/Seg_Tra_any/Segment-and-Track-Anything/debug/point.png')
    origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_BGR2RGB)
    grounding_caption = "swan.water"  # Grounding-DINO text prompt; categories are separated by "."
    box_threshold = 0.25
    text_threshold = 0.25

    predicted_mask, annotated_frame = Seg_Tracker.detect_and_seg(origin_frame, grounding_caption, box_threshold, text_threshold)
    masked_frame = draw_mask(annotated_frame, predicted_mask)
    origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_RGB2BGR)

    cv2.imwrite('./debug/masked_frame.png', masked_frame)
    cv2.imwrite('./debug/x.png', annotated_frame)

4. Troubleshooting

1. Model download problem

requests.exceptions.SSLError: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /bert-base-uncased/resolve/main/tokenizer_config.json (Caused by SSLError(SSLEOFError(8, 'EOF occurred in violation of protocol (_ssl.c:997)')))"), '(Request ID: d4f21f96-45fd-47a1-9afb-b7e4260a6f3b)')

https://huggingface.co/bert-base-uncased/tree/main

If this error appears, you can manually download the model files from the page above and put them in a local directory.
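A hedged sketch of the manual workaround: download the files from the huggingface.co page above into a local folder and check that transformers can load them from there. Depending on the GroundingDINO version, its config (for example groundingdino/config/GroundingDINO_SwinT_OGC.py) has a text_encoder_type entry that can be pointed at this local folder; the path below is only an example.

# Verify that the manually downloaded bert-base-uncased files load from a local directory.
# "./ckpt/bert-base-uncased" is an example path; it must contain config.json, vocab.txt, etc.
from transformers import AutoModel, AutoTokenizer

local_dir = "./ckpt/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(local_dir)
model = AutoModel.from_pretrained(local_dir)
print("loaded bert-base-uncased from", local_dir)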

2. imageio version problem

TypeError: The keyword fps is no longer supported. Use duration(in ms) instead, e.g. fps=50 == duration=20 (1000 * 1/50).

pip uninstall imageio
pip install imageio==2.23.0
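If you would rather keep the newer imageio than downgrade, the call that writes the GIF can pass duration instead of fps, following the conversion suggested by the error message itself. A hedged sketch (the exact call site lives in the project's result-saving code, and frames stands for the list of rendered frames):

# Alternative to downgrading imageio: pass duration (in ms) instead of fps,
# using the conversion from the error message (duration = 1000 / fps).
# `frames` is a placeholder for the list of (h, w, 3) uint8 arrays produced per frame.
import imageio

fps = 30
imageio.mimsave("tracking_result.gif", frames, duration=1000 / fps)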
