UNet - 预测数据predict(多个图像的分割)

目录

1. 介绍

2. predict 预测分割图片

3. 结果展示

4. 完整代码


1. 介绍

项目完整下载地址:UNet 网络对图像的分割

之前已经将unet的网络模块、dataset数据加载和train训练数据已经解决了,这次要将unet网络去分割图像,下面是之前的链接

unet 网络:UNet - unet网络

dataset 数据处理:UNet - 数据加载 Dataset

train 网络训练:UNet - 训练数据train

待分割的图像如下:

 存放的路径在U-net项目的predict里面

我们的目标是将predict里面所有的图片分割出来,按照名称顺序保存在result文件夹里面:

2. predict 预测分割图片

首先定义图片的预处理,按照dataset里面相同的方式进行预处理

然后是加载网络的模型和网络参数

 然后加载predict里面所有待处理图片的路径

需要注意的是,os.listdir 加载的只是里面每个图片,并不是图片的具体路径。tests_path 里面的内容如下面的注释所示:

接下来就可以分割图片了

因为tests_path 里面每个文件是 x.png 即文件名+后缀的方式。通过split的 '.' 分割成x和后缀名png的形式,[-2]代表取倒数第二个值,就可以将每个文件名x取出来,然后将路径拼接就可以存放到result里面

open图像的时候,也要注意,test_path 只是遍历tests_path 里面的文件,需要加上之前的predict路径才能正确的读取到每个待分割的图片

因为这里处理图像会改变size成480*480的形式,想要将输出的结果保持不变的话,在网络预测前将图像的大小保存下来就可以了。(注:这里的size和opencv里面的shape返回值是反过来的

这里不清楚的可以通过调试,打印每个变量的内容看一下就可以了

接下来就是网络预测的部分,这里输出的size是(batch,channel,height,width),因为这里的batch是1,channel 灰度图片因此也是1,这里通过squeeze将1的维度删去,只需要图像的大小

下面是squeeze的用法

然后图像保存的话,要转到cpu上面 ,这一步不知道为啥,但是不加这一步会报错

 最后就是保存图像了,将网络的结果二值化后,还原图像再保存就可以了

3. 结果展示

predict里面待预测的图片

result 里面分割好的图片

下面是 参考文章 博主的分割结果

 

对比发现,有些小的细节会丢失,但是大概的轮廓分割出来了

4. 完整代码

完整的项目可以在 这里 下载

import numpy as np
import torch
import cv2
from model import UNet
from torchvision import transforms
from PIL import Image
import os


# 预处理
transform = transforms.Compose([
    transforms.Resize((480,480)),        # 缩放图像
    transforms.ToTensor(),
])

# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(in_channels=1, num_classes=1)
net.load_state_dict(torch.load('Unet.pth', map_location=device))
net.to(device)

# 测试模式
net.eval()
# 读取所有图片路径
tests_path = os.listdir('./predict/')   # 获取 './predict/' 路径下所有文件,这里的路径只是里面文件的路径
''''
print(tests_path)
['0.png', '1.png', '10.png', '11.png', '12.png', '13.png', '14.png', 
'15.png', '16.png', '17.png', '18.png', '19.png', '2.png', '20.png', 
'21.png', '22.png', '23.png', '24.png', '25.png', '26.png', '27.png',
 '28.png', '29.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png', '9.png']
'''


with torch.no_grad():                   # 预测的时候不需要计算梯度
    for test_path in tests_path:        # 遍历每个predict的文件
        save_pre_path = './result/'+test_path.split('.')[-2] + '_res.png'    # 将保存的路径按照原图像的后缀,按照数字排序保存
        img = Image.open('./predict/' +test_path)           # 预测图片的路径
        width,height = img.size[0],img.size[1]              # 保存图像的大小
        img = transform(img)
        img = torch.unsqueeze(img,dim = 0)                  # 扩展图像的维度

        pred = net(img.to(device))                          # 网络预测
        pred = torch.squeeze(pred)                          # 将(batch、channel)维度去掉
        pred = np.array(pred.data.cpu())                    # 保存图片需要转为cpu处理

        pred[pred >= 0] = 255                               # 处理结果二值化
        pred[pred < 0] = 0

        pred = np.uint8(pred)                               # 转为图片的形式
        pred = cv2.resize(pred,(width,height),cv2.INTER_CUBIC)          # 还原图像的size
        cv2.imwrite(save_pre_path, pred)                    # 保存图片

猜你喜欢

转载自blog.csdn.net/qq_44886601/article/details/127920076#comments_27562159