Python 将深度学习目标检测的结果框转化为矢量的shapefile

在遥感影像的目标检测中,我们通常希望将检测结果与原始影像进行叠加,以便查看和分析。最简单的方法就是将检测结果输出成shapefile的形式,下面提供一种基于Python的转换方法

import os

import gdal
import geopandas as gpd
import ogr
import osr
import rasterio.features
import shapely


def box_list2shp(det_file, img_file, out_shapefile):
    """
    将一系列坐标点的边界框数据转换成shapefile
    :param det_file: 输入的检测结果文件,每一行为一个检测框(x1, y1, x2, y2, x3, y3, x4, y4, label)
    :param img_file: 输入的原始影像路径
    :param out_shapefile: 输出的矢量路径
    """

    bbox_data, label_data = get_box_from_txt(det_file, img_file)
    # bbox_data, label_data = get_box_from_array(det_file, img_file)
    with rasterio.open(img_file) as raster:   # 从原始影像中获取投影和几何信息
        crs = raster.crs

    polygon_list = []
    for i in range(len(bbox_data)):
        polygon = shapely.geometry.Polygon(bbox_data[i])
        polygon_list.append(polygon)

    out_data = gpd.GeoSeries(polygon_list, index=label_data, crs=crs)
    out_data.to_file(out_shapefile, driver='ESRI Shapefile', encoding='utf-8')
    print("successfully convert box-list to shapefile")

其中,get_box_from_txt函数是从DOTA格式的txt文件中读取检测框坐标和对应标签,代码如下:

import gdal

def get_box_from_txt(txt_file, img_file):
    """
    从txt文件中读取目标检测的边界框坐标点和标签信息
    :param img_file: 原始影像数据,为了获取投影信息
    :return box_data: 由图像坐标点组成的一系列边界框
    :return label_data : 边界框对应的标签信息
    """

    dataset = gdal.Open(img_file)
    with open(txt_file, 'r', encoding='utf-8') as f:
        bbox_data = []
        label_data = []
        for line in f.readlines():
            curLine = line.strip().split(" ")
            x1 = float(curLine[0])
            y1 = float(curLine[1])
            x2 = float(curLine[2])
            y2 = float(curLine[3])
            x3 = float(curLine[4])
            y3 = float(curLine[5])
            x4 = float(curLine[6])
            y4 = float(curLine[7])
            if dataset.GetProjection() is None:  # 没有投影则需要进行这种转换
                box = [(x1, -y1), (x2, -y2), (x3, -y3), (x4, -y4)]
            else:    # 图像坐标转地理坐标
                x1, y1 = imagexy2geo(dataset, y1, x1)
                x2, y2 = imagexy2geo(dataset, y2, x2)
                x3, y3 = imagexy2geo(dataset, y3, x3)
                x4, y4 = imagexy2geo(dataset, y4, x4)
                box = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]
            label = curLine[8]
            bbox_data.append(box)
            label_data.append(label)
    return bbox_data, label_data

get_box_from_array函数是从检测结果的坐标数组中读取检测框和对应标签,代码如下:

def get_box_from_array(result_array, classnames, img_file):
    """
    从输出的list中读取目标检测的边界框坐标点和标签信息
    :param result_array: 检测结果的list文件,格式为[ class_num, 9, obj_num], 9代表四个坐标点和置信度
    :param classnames: 类别列表
    :param img_file: 原始影像数据,为了获取投影信息
    :return box_data: 由图像坐标点组成的一系列边界框
    :return label_data : 边界框对应的标签信息
    """

    dataset = gdal.Open(img_file)
    bbox_data = []
    label_data = []
    for idx, class_result in enumerate(result_array):
        for result in class_result:
            x1 = float(result[0])
            y1 = float(result[1])
            x2 = float(result[2])
            y2 = float(result[3])
            x3 = float(result[4])
            y3 = float(result[5])
            x4 = float(result[6])
            y4 = float(result[7])
            if dataset.GetProjection() is None:
                box = [(x1, -y1), (x2, -y2), (x3, -y3), (x4, -y4)]
            else:
                x1, y1 = imagexy2geo(dataset, y1, x1)
                x2, y2 = imagexy2geo(dataset, y2, x2)
                x3, y3 = imagexy2geo(dataset, y3, x3)
                x4, y4 = imagexy2geo(dataset, y4, x4)
                box = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]
            label = classnames[idx]
            bbox_data.append(box)
            label_data.append(label)
    return bbox_data, label_data

imagexy2geo 图像坐标转地理坐标的代码如下:

import gdal 

def imagexy2geo(dataset, row, col):
    '''
    根据GDAL的六参数模型将影像图上坐标(行列号)转为投影坐标或地理坐标(根据具体数据的坐标系统转换)
    :param dataset: GDAL地理数据,gdal.Open("xxx.tif")
    :param row: 像素的行号
    :param col: 像素的列号
    :return: 行列号(row, col)对应的投影坐标或地理坐标(x, y)
    '''

    trans = dataset.GetGeoTransform()
    px = trans[0] + col * trans[1] + row * trans[2]
    py = trans[3] + col * trans[4] + row * trans[5]
    return px, py

注:记得在文件中import相关包

Supongo que te gusta

Origin blog.csdn.net/MLH7M/article/details/121103082
Recomendado
Clasificación