UNet - Prediction with predict (segmenting multiple images)

Table of contents

1. Introduction

2. Segmenting images with predict

3. Result display

4. Complete code


1. Introduction

Download link for the complete project: Image Segmentation by UNet Network

Earlier posts covered the UNet network module, data loading with Dataset, and training with train. This post uses the trained UNet to segment images. The earlier posts are linked below:

unet network: UNet - unet network

dataset data processing: UNet - data loading Dataset

train network training: UNet - training data train

The images to be segmented are shown below:

They are stored in the predict folder of the U-net project.

The goal is to segment every image in the predict folder and save the results in the result folder, named after the source files:

2. Segmenting images with predict

First, define the image preprocessing. It must match the preprocessing used in the dataset.

Then build the network and load its trained parameters.
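The loading pattern can be sketched as follows. Because the project's UNet class is not available here, a hypothetical one-layer TinyNet stands in for it, and the weights file is written to a temporary folder only so the example runs on its own. In the real script the class is UNet and the file is Unet.pth.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the project's UNet (1 channel in, 1 channel out)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Write a weights file so there is something to load (illustration only).
weights_path = os.path.join(tempfile.mkdtemp(), 'tiny.pth')
torch.save(TinyNet().state_dict(), weights_path)

net = TinyNet()
# map_location lets weights saved on a GPU load on a CPU-only machine.
net.load_state_dict(torch.load(weights_path, map_location=device))
net.to(device)
net.eval()   # switch layers such as dropout and batch norm to inference behavior
```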

Then collect all the images waiting in predict.

Note that os.listdir returns only the file names in the folder, not their full paths. The contents of tests_path are shown in the comment in the complete code in section 4.
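A small sketch of this behavior, using a temporary folder with empty files as a stand-in for predict (the folder and file names are assumptions for illustration). The documentation for os.listdir states that the order of the returned list is arbitrary, so the names are sorted explicitly here:

```python
import os
import tempfile

# Temporary folder standing in for './predict/' (illustration only).
predict_dir = tempfile.mkdtemp()
for name in ['0.png', '1.png', '10.png', '2.png']:
    open(os.path.join(predict_dir, name), 'w').close()

# listdir returns bare file names; sort them because the order is not guaranteed.
names = sorted(os.listdir(predict_dir))

# The folder has to be joined back on before a file can be opened.
full_paths = [os.path.join(predict_dir, n) for n in names]

print(names)  # ['0.png', '1.png', '10.png', '2.png']
```

The sort is lexicographic, which is why '10.png' comes before '2.png', just as in the listing from the original script.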

Now the images can be segmented.

Each entry in tests_path has the form x.png, that is, file name plus suffix. Splitting on '.' gives the name x and the suffix png, and index [-2] takes the second-to-last element, which is the file name x. Joining it with the result folder path gives the path where the output is saved.
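The path construction above can be sketched like this. The helper name result_path is hypothetical and exists only for this example. The last two lines show a limitation of [-2] for names containing more than one dot, where os.path.splitext is a possible alternative:

```python
import os


def result_path(file_name, result_dir='./result/'):
    """Hypothetical helper: build the save path the same way the script does."""
    stem = file_name.split('.')[-2]       # '12.png' -> ['12', 'png'] -> '12'
    return result_dir + stem + '_res.png'


print(result_path('12.png'))                # ./result/12_res.png

# With more than one dot, [-2] keeps only the piece before the suffix.
print('scan.v2.png'.split('.')[-2])         # v2
print(os.path.splitext('scan.v2.png')[0])   # scan.v2
```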

When opening an image, remember that test_path only iterates over the file names in tests_path, so the predict folder path must be prepended to read each image correctly.

Because preprocessing resizes every image to 480*480, save the original image size before network prediction if the output should match the input size. (Note: the orders are reversed. PIL's size is (width, height), whereas the shape of an OpenCV/NumPy array is (height, width).)
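A quick check of the two orderings, using a synthetic 640x360 image (an assumption for illustration):

```python
import numpy as np
from PIL import Image

img = Image.new('L', (640, 360))    # PIL takes (width, height)
width, height = img.size            # PIL also reports (width, height)

arr = np.array(img)                 # the NumPy array is indexed as (row, column)

print(img.size)    # (640, 360)
print(arr.shape)   # (360, 640)
```

cv2.resize takes its target size as (width, height), so the values read from img.size can be passed to it in that order.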

If anything here is unclear, step through with a debugger and print each variable to inspect its contents.

Next comes network prediction. The output has shape (batch, channel, height, width). The batch size is 1 and, because the output is a single-channel grayscale mask, the channel dimension is also 1. squeeze removes these dimensions of size 1, leaving only the image height and width.

The following is the usage of squeeze
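A short sketch of torch.squeeze on a tensor with the same shape as the network output here (the zero tensors are placeholders, not real predictions):

```python
import torch

pred = torch.zeros(1, 1, 480, 480)          # (batch, channel, height, width)

# Without dim, every dimension of size 1 is removed.
print(torch.squeeze(pred).shape)            # torch.Size([480, 480])

# With dim, only that dimension is removed (and only if its size is 1).
print(torch.squeeze(pred, dim=0).shape)     # torch.Size([1, 480, 480])

# Dimensions larger than 1 are never touched.
x = torch.zeros(2, 1, 480, 480)
print(torch.squeeze(x).shape)               # torch.Size([2, 480, 480])
```

One consequence is that squeeze without dim only yields a bare (height, width) result when the batch size is 1, as it is in this script.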

To save the image, the result must first be moved to the CPU. NumPy can only work with data in host memory, so converting a tensor that still lives on the GPU raises an error unless .cpu() is called first.
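A minimal sketch of the conversion, with a random tensor standing in for the squeezed network output (an assumption for illustration). It runs on either device, since .cpu() simply returns the tensor unchanged when it is already in host memory:

```python
import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in for the squeezed prediction, created on whichever device is in use.
pred = torch.randn(480, 480, device=device)

# detach() drops any autograd history, cpu() moves the data to host memory,
# and numpy() exposes it as a NumPy array.
pred_np = pred.detach().cpu().numpy()

print(type(pred_np).__name__, pred_np.shape)  # ndarray (480, 480)
```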

The last step is to save the image: binarize the network output, resize it back to the original size, and write it to disk.
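The binarization can be sketched on a tiny array (the values are made up for illustration). The script thresholds at 0, which is equivalent to thresholding a sigmoid probability at 0.5 if the network outputs raw logits. That is an assumption here, since it depends on how the model was trained.

```python
import numpy as np

# Made-up 2x2 "prediction" standing in for the network output.
logits = np.array([[-2.0, -0.1],
                   [0.0, 3.5]], dtype=np.float32)

# Values >= 0 become white (255), everything else black (0).
mask = np.where(logits >= 0, 255, 0).astype(np.uint8)

print(mask.tolist())  # [[0, 0], [255, 255]]
```

np.where builds the mask in one pass and gives the same result as the two in-place assignments used in the script.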

3. Result display

The picture to be predicted in predict

The segmented image in result

The following is the segmentation result from the author of the reference article:

Comparing the two shows that some fine details are lost, but the overall outline is segmented.

4. Complete code

The complete project can be downloaded here

import numpy as np
import torch
import cv2
from model import UNet
from torchvision import transforms
from PIL import Image
import os


# Preprocessing
transform = transforms.Compose([
    transforms.Resize((480, 480)),       # resize the image
    transforms.ToTensor(),
])

# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(in_channels=1, num_classes=1)
net.load_state_dict(torch.load('Unet.pth', map_location=device))
net.to(device)

# Evaluation mode
net.eval()
# Read all image paths
tests_path = os.listdir('./predict/')   # list every file under './predict/'; only file names are returned, not full paths
'''
print(tests_path)
['0.png', '1.png', '10.png', '11.png', '12.png', '13.png', '14.png', 
'15.png', '16.png', '17.png', '18.png', '19.png', '2.png', '20.png', 
'21.png', '22.png', '23.png', '24.png', '25.png', '26.png', '27.png',
 '28.png', '29.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png', '9.png']
'''


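os.makedirs('./result/', exist_ok=True)  # make sure the output folder exists; cv2.imwrite does not create it

with torch.no_grad():                   # no gradients are needed for prediction
    for test_path in tests_path:        # iterate over every file in predict
        save_pre_path = './result/' + test_path.split('.')[-2] + '_res.png'    # save path built from the source file name
        img = Image.open('./predict/' + test_path)          # full path of the image to predict
        width, height = img.size[0], img.size[1]            # remember the original size (PIL order: width, height)
        img = transform(img)
        img = torch.unsqueeze(img, dim=0)                   # add the batch dimension

        pred = net(img.to(device))                          # network prediction
        pred = torch.squeeze(pred)                          # drop the batch and channel dimensions
        pred = np.array(pred.data.cpu())                    # move to the CPU before converting to NumPy

        pred[pred >= 0] = 255                               # binarize the result
        pred[pred < 0] = 0

        pred = np.uint8(pred)                               # convert to an 8-bit image
        pred = cv2.resize(pred, (width, height), interpolation=cv2.INTER_CUBIC)  # restore the original size
        cv2.imwrite(save_pre_path, pred)                    # save the image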

Origin blog.csdn.net/qq_44886601/article/details/127920076#comments_27562159