Batch production of remote sensing image sample templates with a Python script

Preface

In practice, training a deep learning model requires a large number of samples, and remote sensing images are usually too large to be used whole. This post briefly introduces a script that automates the batch production of image slices (sample templates) cut around anchor points.

Data overview

Original image

Insert picture description here

Anchor point

Insert picture description here

Sample template image

Insert picture description here

Code

# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from osgeo import ogr
import os, sys
import numpy as np
import cv2
import numpy
import gdal
import time
import glob
from osgeo import osr


def del_file(path):
    """Recursively remove every file found under ``path``.

    Directories are kept; only entries for which ``os.path.isfile`` is true
    are deleted. Every other entry is descended into.
    """
    for entry in os.listdir(path):
        target = os.path.join(path, entry)
        if not os.path.isfile(target):
            # Not a file: recurse into it and clear its contents.
            del_file(target)
            continue
        os.remove(target)


def _sample_label(n):
    """Return the zero-padded, seven-character label for sample number ``n``."""
    return str(n).zfill(7)


def _is_line_sample(sampletype):
    """Return True for a line sample and False for a polygon sample.

    Matching is case-insensitive. "line" takes precedence when both words
    occur (e.g. "polyline"), so the file suffix and the layer geometry type
    always agree.

    Raises:
        ValueError: if ``sampletype`` names neither a line nor a polygon.
    """
    kind = str(sampletype).lower()
    if "line" in kind:
        return True
    if "poly" in kind:
        return False
    raise ValueError(
        "sampletype must contain 'line' or 'poly', got {0!r}".format(sampletype))


def _write_chip_tif(src, out_path, col, row, size, geotransform, projection):
    """Copy a ``size`` x ``size`` pixel window of ``src`` into a new GeoTIFF."""
    pixels = src.ReadAsArray(col, row, size, size)
    if pixels is None:
        raise RuntimeError("Could not read window ({0}, {1}, {2}, {2}) from raster".format(col, row, size))
    if pixels.ndim == 2:
        # A single-band read comes back 2-D; add a band axis so pixels[i] is a band.
        pixels = pixels.reshape((1,) + pixels.shape)

    band_count = src.RasterCount
    # Reuse the source pixel type instead of guessing it from the numpy dtype name.
    datatype = src.GetRasterBand(1).DataType
    out = gdal.GetDriverByName("GTiff").Create(out_path, size, size, band_count, datatype, options=[])
    if out is None:
        raise RuntimeError("Could not create " + out_path)
    out.SetProjection(projection)
    out.SetGeoTransform(geotransform)
    for index in range(band_count):
        out.GetRasterBand(index + 1).WriteArray(pixels[index])
    out.FlushCache()
    out = None  # drop the reference so GDAL closes the file


def _write_template_shp(out_path, srs, is_line, fieldName, ring):
    """Create a template shapefile holding the chip frame and one string field.

    ``ring`` is the closed list of (x, y) frame corners. It is stored as a
    LINESTRING in a line layer and as a POLYGON in a polygon layer, because
    the geometry must match the layer type to be written.
    """
    # NOTE(review): these two options are process-wide and stay set after the
    # first sample; confirm they are still wanted on Python 3.
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")
    gdal.SetConfigOption("SHAPE_ENCODING", "")
    ogr.RegisterAll()
    driver = ogr.GetDriverByName("ESRI Shapefile")
    if driver is None:
        raise RuntimeError("ESRI Shapefile driver is not available")
    datasource = driver.CreateDataSource(out_path)
    if datasource is None:
        raise RuntimeError("Could not create " + out_path)

    geom_type = ogr.wkbLineString if is_line else ogr.wkbPolygon
    out_layer = datasource.CreateLayer("TestPolygon", srs, geom_type, [])
    if out_layer is None:
        raise RuntimeError("Could not create layer in " + out_path)

    field = ogr.FieldDefn(fieldName, ogr.OFTString)
    field.SetWidth(50)
    out_layer.CreateField(field, 1)

    coords = ",".join("{0} {1}".format(px, py) for px, py in ring)
    wkt = ("LINESTRING ({0})" if is_line else "POLYGON (({0}))").format(coords)
    frame = ogr.Feature(out_layer.GetLayerDefn())
    frame.SetGeometry(ogr.CreateGeometryFromWkt(wkt))
    if out_layer.CreateFeature(frame) != ogr.OGRERR_NONE:
        raise RuntimeError("Could not write frame feature to " + out_path)
    # Release in dependency order so the shapefile is flushed and closed.
    frame = None
    out_layer = None
    datasource = None


def sampleClip(shp, tif, outputdir, sampletype, size, fieldName='cls', n=None):
    """Cut one square image chip and one template shapefile per anchor point.

    For every feature of ``shp`` a directory ``<label>_V1`` is created under
    ``outputdir``. It receives ``<label>.tif``, a ``size`` x ``size`` pixel
    window of ``tif`` centred on the feature, and ``<label>_V1_LINE.shp`` or
    ``<label>_V1_POLY.shp``, which holds the chip frame and an empty string
    field named ``fieldName``. ``<label>`` is the sequence number zero-padded
    to seven characters.

    Anchor points whose window is not fully inside the raster, and features
    without geometry, are skipped without consuming a sequence number.

    Assumes a north-up raster (rotation terms of the geotransform are
    ignored) and point features in the raster's coordinate system.

    Args:
        shp: path of the anchor shapefile.
        tif: path of the source raster.
        outputdir: directory that receives the per-sample sub-directories.
        sampletype: text containing "line" or "poly" (case-insensitive).
        size: chip edge length in pixels.
        fieldName: name of the string attribute field in the template.
        n: first sequence number; defaults to 1.

    Returns:
        The next unused sequence number.

    Raises:
        ValueError: if ``sampletype`` or ``size`` is invalid.
        RuntimeError: if the raster cannot be opened or an output cannot be written.
        SystemExit: if ``shp`` cannot be opened (original behaviour kept).
    """
    # time.clock() was removed in Python 3.8; perf_counter() is its replacement.
    started = time.perf_counter()

    # Validate arguments before any file is touched.
    size = int(size)
    if size <= 0:
        raise ValueError("size must be a positive number of pixels, got {0}".format(size))
    is_line = _is_line_sample(sampletype)
    if n is None:
        n = 1

    gdal.AllRegister()
    lc = gdal.Open(tif)
    if lc is None:
        raise RuntimeError('Could not open ' + str(tif))
    im_width = lc.RasterXSize
    im_height = lc.RasterYSize
    im_geotrans = lc.GetGeoTransform()
    bandscount = lc.RasterCount
    im_proj = lc.GetProjection()
    print(im_width, im_height)
    gdal.SetConfigOption("gdal_FILENAME_IS_UTF8", "YES")

    driver = ogr.GetDriverByName('ESRI Shapefile')
    dsshp = driver.Open(shp, 0)
    if dsshp is None:
        print('Could not open ' + str(shp))
        sys.exit(1)
    layer = dsshp.GetLayer()
    # The spatial reference is the same for every sample, so read it once.
    prosrs = layer.GetSpatialRef()
    m = layer.GetFeatureCount()
    print("tif_bands:{0},samples_nums:{1},sample_type:{2},sample_size:{3}*{3}".format(bandscount, m, sampletype,
                                                                                      size))

    origin_x, pixel_w = im_geotrans[0], im_geotrans[1]
    origin_y, pixel_h = im_geotrans[3], im_geotrans[5]
    suffix = "_V1_LINE.shp" if is_line else "_V1_POLY.shp"

    layer.ResetReading()
    while True:
        feature = layer.GetNextFeature()
        if feature is None:
            break
        geometry = feature.GetGeometryRef()
        if geometry is None:
            print("skip feature without geometry")
            continue
        x = geometry.GetX()
        y = geometry.GetY()

        # Fractional pixel offset of the window's upper-left corner.
        col_f = (x - size / 2.0 * pixel_w - origin_x) / pixel_w
        row_f = (y - size / 2.0 * pixel_h - origin_y) / pixel_h
        if col_f < 0 or row_f < 0:
            print("skip ({0}, {1}): window starts outside the raster".format(x, y))
            continue
        col, row = int(col_f), int(row_f)
        if col + size > im_width or row + size > im_height:
            print("skip ({0}, {1}): window ends outside the raster".format(x, y))
            continue

        # Snap the chip origin to the pixel actually read, so the chip's
        # georeference matches its pixel content exactly.
        left = origin_x + col * pixel_w
        top = origin_y + row * pixel_h
        right = left + size * pixel_w
        bottom = top + size * pixel_h
        chip_geotrans = (left, pixel_w, im_geotrans[2], top, im_geotrans[4], pixel_h)
        ring = [(left, top), (right, top), (right, bottom), (left, bottom), (left, top)]

        dirname = _sample_label(n)
        dirpath = os.path.join(outputdir, dirname + "_V1")
        os.makedirs(dirpath, exist_ok=True)

        _write_chip_tif(lc, os.path.join(dirpath, dirname + ".tif"), col, row, size, chip_geotrans, im_proj)
        _write_template_shp(os.path.join(dirpath, dirname + suffix), prosrs, is_line, fieldName, ring)

        print("{0} ok".format(dirname))
        n = n + 1

    finished = time.perf_counter()
    print('Process Running time: %s min' % ((finished - started) / 60))

    return n


def mkdir(path):
    """Create the directory ``path`` unless something already exists there."""
    if os.path.exists(path):
        return
    os.mkdir(path)


# Example run: cut 1000-pixel line-sample templates around the points in train.shp.
if __name__ == "__main__":
    outputdir = './plough'  # output directory
    mkdir(outputdir)
    sampletype = "line"  # sample type: "line" or "poly"
    size = 1000  # sample size
    n = 1  # starting sequence number
    fieldName = 'cls'  # attribute field name
    tif = './Level18/cq.tif'
    shp = 'train.shp'
    # sampleClip returns the next unused sequence number, so a later call can continue from it.
    n = sampleClip(shp, tif, outputdir, sampletype, size, fieldName, n)
    print(n)

Guess you like

Origin blog.csdn.net/weixin_42990464/article/details/111187260