Deploying and using segment-anything locally

Preface
Segment Anything Model (SAM) is an advanced image segmentation foundation model released by Meta AI (formerly Facebook AI) in April 2023. It can accurately segment any object in an image from simple input prompts (such as points or boxes), and it adapts to unfamiliar objects and images without additional training. It builds on computer vision techniques and deep learning algorithms, was trained on a huge dataset covering 11 million images and 1.1 billion masks, and exhibits excellent zero-shot performance.

1: Environment and software
win10
Miniconda3 (install it yourself; there are plenty of online tutorials, so it is not covered here)
GPU (rtx 3060TI 8G)

2: Installation
Official Instructions
insert image description here

1> Download segment-anything
Download address: https://github.com/facebookresearch/segment-anything
Unzip it into a directory;
here it is placed in F:\gameai\segment-anything
insert image description here
2> Model data download address
Download the model file into the segment-anything directory.
default or vit_h:
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

vit_l:
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

vit_b:
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

There are three weight files: base (vit_b) is the smallest, large (vit_l) is medium, and huge (vit_h) is the largest. Choose one according to the video memory of your graphics card.
vit_b reportedly needs about 6 GB of video memory and the others about 8 GB; the 6 GB figure has not been tested here (no 6 GB graphics card available).

3> Create a conda environment
<1> Click Anaconda Prompt (Miniconda3) to open the conda command interface
insert image description here
<2> Create a virtual environment (environment name segment-anything)
conda create -n segment-anything python=3.10 #Create environment and install python3.10
conda activate segment-anything #Enter the environment

[1] Go to https://pytorch.org/get-started/locally/ to generate the install command that matches your own configuration.
This machine uses CUDA 11.7.
To check the local CUDA version, open cmd and run:
nvidia-smi
insert image description here

insert image description here

Execute pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 in the segment-anything environment
insert image description here
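To test whether PyTorch can use the GPU, start python and run, for example: import torch; print(torch.cuda.is_available())
If it displays True, everything is OK. insert image description here
To leave the Python prompt, type exit()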

 cd segment-anything
 pip install -e . 
 pip install opencv-python pycocotools matplotlib onnxruntime onnx

Create 2 new directories, input and output, in the segment-anything directory;
the model files are also placed here
insert image description here

Put a picture in the input directory (here segment-anything\assets\notebook2.png is copied there first)


**3: Test**
 If GPU video memory is small, replace  --checkpoint sam_vit_h_4b8939.pth --model-type vit_h  with  --checkpoint sam_vit_b_01ec64.pth --model-type vit_b
 If video memory is even smaller, see https://github.com/gaomingqi/Track-Anything/issues/4
```bash
python scripts/amg.py --checkpoint sam_vit_h_4b8939.pth --model-type vit_h --input F:\gameai\segment-anything\input --output F:\gameai\segment-anything\output
```

Below is a script I found online (written by someone else; I forgot the exact source) that lets you click on the image to select the region to segment.
The content of test.py is as follows:

import cv2
import os
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# --- Configuration and one-time setup --------------------------------------
input_dir = 'input'    # directory scanned for images to segment
output_dir = 'output'  # directory that receives the saved cut-outs
crop_mode = True  # whether saved results are cropped to the mask's minimal extent
# alpha_channel (parameter of apply_mask): whether to keep a transparency channel
# The message below is printed in Chinese; it says: "it is best to press w to
# predict once after every added point".
print('最好是每加一个点就按w键predict一次')
os.makedirs(output_dir, exist_ok=True)
# The name is lower-cased before the suffix test, so the check is already
# case-insensitive. Sorting makes the a/d navigation order deterministic,
# because os.listdir returns entries in arbitrary order.
image_files = sorted(
    f for f in os.listdir(input_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
_ = sam.to(device="cuda")  # comment this line out to run on the CPU, which is much slower
predictor = SamPredictor(sam)  # SAM predictor used for every image
def mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback that records prompt points.

    A left click appends ``[x, y]`` to the global ``input_point`` with label 1
    (foreground) in ``input_label``; a right click appends it with label 0
    (background). While the global ``input_stop`` flag is set, clicks are not
    recorded and a notice is printed instead. Other events are ignored.
    """
    global input_point, input_label, input_stop
    left_pressed = event == cv2.EVENT_LBUTTONDOWN
    right_pressed = event == cv2.EVENT_RBUTTONDOWN
    if not (left_pressed or right_pressed):
        return
    if input_stop:
        # Point input is frozen; tell the user why the click was dropped.
        print('此时不能添加点,按w退出mask选择模式')
        return
    input_point.append([x, y])
    input_label.append(1 if left_pressed else 0)


def apply_mask(image, mask, alpha_channel=True):
    """Cut the masked region out of ``image``.

    When ``alpha_channel`` is true, the first three channels of ``image`` are
    merged with an alpha plane that is 255 where ``mask == 1`` and 0 elsewhere.
    Otherwise every pixel where ``mask`` is not 1 is replaced by 0 and the
    channel layout is left unchanged.
    """
    if not alpha_channel:
        # A trailing axis lets the mask broadcast across the channel axis.
        return np.where(mask[..., None] == 1, image, 0)

    opacity = np.zeros_like(image[..., 0])
    opacity[mask == 1] = 255  # selected pixels become fully opaque
    planes = (image[..., 0], image[..., 1], image[..., 2], opacity)
    return cv2.merge(planes)

def apply_color_mask(image, mask, color, color_dark = 0.5):
    """Tint the masked pixels of ``image`` in place and return ``image``.

    For each of the first three channels, pixels where ``mask == 1`` become
    ``pixel * (1 - color_dark) + color_dark * color[channel]``; all other
    pixels keep their value.
    """
    inside = mask == 1
    keep = 1 - color_dark
    for channel in range(3):
        plane = image[:, :, channel]
        blended = plane * keep + color_dark * color[channel]
        image[:, :, channel] = np.where(inside, blended, plane)
    return image

def get_next_filename(base_path, filename):
    """Return the first unused numbered variant of ``filename``.

    Candidates have the form ``<name>_<i><ext>`` for ``i`` from 1 to 100,
    where ``name`` and ``ext`` come from ``os.path.splitext(filename)``.

    Args:
        base_path: Directory in which existing files are checked.
        filename: Base file name to derive the candidates from.

    Returns:
        The first candidate that does not exist in ``base_path``, or ``None``
        when all 100 candidates are already taken.
    """
    name, ext = os.path.splitext(filename)
    for i in range(1, 101):
        # Keep the replacement fields on one line: spreading them across lines
        # is a SyntaxError on Python versions before 3.12.
        new_name = f"{name}_{i}{ext}"
        if not os.path.exists(os.path.join(base_path, new_name)):
            return new_name
    return None

def save_masked_image(image, mask, output_dir, filename, crop_mode_):
    """Save the masked region of ``image`` as a PNG in ``output_dir``.

    The cut-out is produced by ``apply_mask`` and written under the first free
    ``<stem>_<i>.png`` name returned by ``get_next_filename``.

    Args:
        image: Source image array.
        mask: Mask selecting the region of interest.
        output_dir: Directory the PNG is written to.
        filename: Name of the source image; only its stem is reused.
        crop_mode_: When true, crop image and mask to the bounding box of the
            non-zero mask entries before applying the mask.

    Returns:
        None. The outcome is reported on stdout.
    """
    if crop_mode_:
        y, x = np.where(mask)
        if y.size == 0:
            # An empty mask has no bounding box; min()/max() would raise.
            print("Could not save the image. The mask is empty.")
            return
        y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
        cropped_mask = mask[y_min:y_max+1, x_min:x_max+1]
        cropped_image = image[y_min:y_max+1, x_min:x_max+1]
        masked_image = apply_mask(cropped_image, cropped_mask)
    else:
        masked_image = apply_mask(image, mask)

    # splitext leaves a name without any dot intact, whereas slicing at
    # rfind('.') == -1 would silently drop its last character.
    png_name = os.path.splitext(filename)[0] + '.png'
    new_filename = get_next_filename(output_dir, png_name)
    if not new_filename:
        print("Could not save the image. Too many variations exist.")
        return

    target = os.path.join(output_dir, new_filename)
    if masked_image.shape[-1] == 4:
        # Four channels means an alpha plane is present: use maximum PNG compression.
        written = cv2.imwrite(target, masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
    else:
        written = cv2.imwrite(target, masked_image)

    # cv2.imwrite reports failure through its return value rather than raising.
    if written:
        print(f"Saved as {new_filename}")
    else:
        print(f"Could not save the image. Writing {target} failed.")

# ---------------------------------------------------------------------------
# Interactive annotation loop.
#
# Outer loop: one iteration per displayed image.
# Inner loop: redraws the image with the current prompt points and reacts to
# keys:
#   space  clear points and the current mask
#   w      predict masks from the points, then enter mask selection
#   a / d  previous / next image
#   q      remove the last point
#   s      save the currently selected mask
#   Esc    quit
# ---------------------------------------------------------------------------
current_index = 0

# Prompt state shared with the mouse_click callback through module globals.
input_point = []
input_label = []
input_stop = False

if not image_files:
    # Without this guard, image_files[current_index] raises IndexError.
    raise SystemExit(f"No images found in '{input_dir}'.")

cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)

while True:
    filename = image_files[current_index]
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    if image_orign is None:
        # cv2.imread returns None for files it cannot decode; drop the file
        # from the list instead of crashing on .copy().
        print(f"Could not read {filename}, skipping it.")
        image_files.pop(current_index)
        if not image_files:
            break
        current_index = min(current_index, len(image_files) - 1)
        continue
    image_crop = image_orign.copy()  # untouched copy used when saving
    image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)  # RGB copy handed to the predictor
    selected_mask = None
    logit_input = None
    image_embedded = False  # True once predictor.set_image has run for this image
    while True:
        input_stop = False
        image_display = image_orign.copy()
        display_info = f'{filename} | Press s to save | Press w to predict | Press d to next image | Press a to previous image | Press space to clear | Press q to remove last point '
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        # Green dots mark foreground points, red dots mark background points.
        for point, label in zip(input_point, input_label):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        if selected_mask is not None:
            # A new random colour every frame makes the confirmed mask flicker.
            color = tuple(np.random.randint(0, 256, 3).tolist())
            apply_color_mask(image_display, selected_mask, color)

        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)

        if key == ord(" "):
            input_point = []
            input_label = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True  # freeze point input while masks are being reviewed
            if input_point and input_label:
                if not image_embedded:
                    # set_image only depends on the image, so run it once per
                    # image instead of on every prediction.
                    predictor.set_image(image)
                    image_embedded = True
                # The predictor takes the prompts as arrays, not lists.
                input_point_np = np.array(input_point)
                input_label_np = np.array(input_label)
                # The logits of the previously confirmed mask, if any, are fed
                # back in as mask_input.
                masks, scores, logits = predictor.predict(
                    point_coords=input_point_np,
                    point_labels=input_label_np,
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )

                mask_idx = 0
                num_masks = len(masks)
                cleared = False  # set when the user resets the prompts with space
                while True:
                    color = tuple(np.random.randint(0, 256, 3).tolist())
                    image_select = image_orign.copy()
                    selected_mask = masks[mask_idx]  # a/d cycle through the candidates
                    selected_image = apply_color_mask(image_select, selected_mask, color)
                    mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | Press w to confirm | Press d to next mask | Press a to previous mask | Press q to remove last point | Press s to save'
                    cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
                    cv2.imshow("image", selected_image)

                    key = cv2.waitKey(10)
                    if key == ord('q') and len(input_point) > 0:
                        input_point.pop(-1)
                        input_label.pop(-1)
                    elif key == ord('s'):
                        save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
                    elif key == ord('a'):
                        # Step back, wrapping from the first mask to the last.
                        if mask_idx > 0:
                            mask_idx -= 1
                        else:
                            mask_idx = num_masks - 1
                    elif key == ord('d'):
                        # Step forward, wrapping from the last mask to the first.
                        if mask_idx < num_masks - 1:
                            mask_idx += 1
                        else:
                            mask_idx = 0
                    elif key == ord('w'):
                        break
                    elif key == ord(" "):
                        input_point = []
                        input_label = []
                        selected_mask = None
                        logit_input = None
                        cleared = True
                        break
                if not cleared:
                    # Only a confirmed mask seeds the next prediction; after a
                    # reset, logit_input must stay None.
                    logit_input = logits[mask_idx, :, :]
                    print('max score:', np.argmax(scores), ' select:', mask_idx)

        elif key == ord('a'):
            current_index = max(0, current_index - 1)
            input_point = []
            input_label = []
            break
        elif key == ord('d'):
            current_index = min(len(image_files) - 1, current_index + 1)
            input_point = []
            input_label = []
            break
        elif key == 27:  # Esc
            break
        elif key == ord('q') and len(input_point) > 0:
            input_point.pop(-1)
            input_label.pop(-1)
        elif key == ord('s') and selected_mask is not None:
            save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)

    if key == 27:
        break

# Close the window explicitly once the session ends.
cv2.destroyAllWindows()

The result of the operation is as follows
insert image description here

Operation: first left-click on the object to add a foreground point,
then press w to predict. You can also right-click to mark background points, following the on-screen prompts. While the outline is flashing, press w to confirm it or press d to switch to the next outline; the higher the Score, the better. After confirming, press s to save the outlined region to the output directory (here I chose the dog by clicking on the dog's head). The result is still fairly accurate, but the edges are not very smooth and could be improved a little.






insert image description here

If you find it useful, please like it and add a favorite

Guess you like

Origin blog.csdn.net/yunteng521/article/details/130601007