Article Directory
Preface
In practice, training deep learning models requires a large number of samples, and remote sensing images are usually very large. To make it easy to cut such images into sample slices automatically and in bulk, this blog post briefly introduces one way to do it.
Data overview
Original image
Anchor point
Sample template image
Code
# -*- coding: utf-8 -*-
import glob
import os
import sys
import time

import cv2
import numpy
import numpy as np
from osgeo import ogr
from osgeo import osr

try:
    from osgeo import gdal
except ImportError:  # legacy installs that only ship the top-level module
    import gdal
def del_file(path):
    """Delete every file beneath *path*, leaving the directory tree in place.

    Real sub-directories are descended into recursively. Everything else
    (regular files, symbolic links, dangling links, special files) is
    unlinked directly.

    Args:
        path: Directory whose contents should be emptied of files.

    Raises:
        OSError: If *path* cannot be listed or an entry cannot be removed.
    """
    for entry in os.listdir(path):
        path_file = os.path.join(path, entry)
        # Never follow symlinks: a link to a directory must be removed as a
        # link, otherwise files outside of *path* would be deleted.
        if os.path.isdir(path_file) and not os.path.islink(path_file):
            del_file(path_file)
        else:
            os.remove(path_file)
def _sample_dirname(n):
    """Return the zero-padded directory stem used for sample number *n*."""
    # Numbers 1..999 are padded to seven characters; everything else gets a
    # fixed "000" prefix. This reproduces the established naming scheme so
    # existing output trees keep their names.
    if 0 < n < 1000:
        return str(n).zfill(7)
    return "000" + str(n)


def _sample_kind(sampletype):
    """Return ``"LINE"`` or ``"POLY"`` for *sampletype*, else raise ValueError."""
    lowered = str(sampletype).lower()
    # "line" is tested first so that a value containing both words
    # (e.g. "polyline") consistently yields a LINE file name and layer.
    if "line" in lowered:
        return "LINE"
    if "poly" in lowered:
        return "POLY"
    raise ValueError(
        "sampletype must contain 'line' or 'poly', got {0!r}".format(sampletype))


def sampleClip(shp, tif, outputdir, sampletype, size, fieldName='cls', n=None):
    """Cut one square raster tile and one template shapefile per feature of *shp*.

    For every feature of the first layer of *shp*, a ``size`` x ``size`` pixel
    window centred on the feature's X/Y coordinate is read from *tif* and
    written to ``<outputdir>/<number>_V1/<number>.tif``. Next to it a
    shapefile ``<number>_V1_LINE.shp`` or ``<number>_V1_POLY.shp`` is created
    with a single string field named *fieldName*.

    Args:
        shp: Path of the shapefile holding the anchor features. The code reads
            ``GetX()``/``GetY()`` of each geometry, i.e. it expects points.
        tif: Path of the source raster.
        outputdir: Directory that receives one sub-directory per sample.
        sampletype: Text containing "line" or "poly" (case-insensitive).
        size: Tile edge length in pixels (converted with ``int``).
        fieldName: Name of the string attribute created in each shapefile.
        n: Number given to the first sample; defaults to 1.

    Returns:
        The next unused sample number, or ``None`` if an output file could
        not be created.

    Raises:
        ValueError: If *sampletype* is not recognised or *size* is not positive.
        SystemExit: If *tif* or *shp* cannot be opened.

    Features without geometry, or whose window does not lie completely inside
    the raster, are reported and skipped without consuming a sample number.
    """
    kind = _sample_kind(sampletype)
    size = int(size)
    if size <= 0:
        raise ValueError("size must be a positive number of pixels, got {0!r}".format(size))
    if n is None:
        n = 1

    # time.clock() was removed in Python 3.8; perf_counter() is its replacement.
    time_start = time.perf_counter()

    gdal.AllRegister()
    lc = gdal.Open(tif)
    if lc is None:
        print('Could not open ' + str(tif))
        sys.exit(1)
    im_width = lc.RasterXSize
    im_height = lc.RasterYSize
    im_geotrans = lc.GetGeoTransform()
    bandscount = lc.RasterCount
    im_proj = lc.GetProjection()
    # Reuse the source pixel type so signed / 32-bit / 64-bit data is not
    # silently converted when the tile is written.
    datatype = lc.GetRasterBand(1).DataType
    print(im_width, im_height)

    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    shp_driver = ogr.GetDriverByName('ESRI Shapefile')
    if shp_driver is None:
        print("ESRI Shapefile driver is not available")
        return None
    dsshp = shp_driver.Open(shp, 0)
    if dsshp is None:
        print('Could not open ' + str(shp))
        sys.exit(1)
    layer = dsshp.GetLayer()
    # The spatial reference is identical for every output, so fetch it once.
    prosrs = layer.GetSpatialRef()
    m = layer.GetFeatureCount()
    print("tif_bands:{0},samples_nums:{1},sample_type:{2},sample_size:{3}*{3}".format(
        bandscount, m, sampletype, size))

    tif_driver = gdal.GetDriverByName("GTiff")
    # NOTE(review): the window maths below ignores the rotation terms
    # im_geotrans[2] / im_geotrans[4]; it assumes a north-up raster.
    pixel_w = im_geotrans[1]
    pixel_h = im_geotrans[5]
    geom_type = ogr.wkbLineString if kind == "LINE" else ogr.wkbPolygon

    while True:
        feature = layer.GetNextFeature()
        if feature is None:
            break
        geometry = feature.GetGeometryRef()
        if geometry is None:
            print("feature {0} has no geometry, skipped".format(feature.GetFID()))
            continue
        x = geometry.GetX()
        y = geometry.GetY()

        # Upper-left pixel of the window centred on (x, y).
        xoff = int((x - im_geotrans[0]) / pixel_w - size / 2.0)
        yoff = int((y - im_geotrans[3]) / pixel_h - size / 2.0)
        if xoff < 0 or yoff < 0 or xoff + size > im_width or yoff + size > im_height:
            print("sample at ({0}, {1}) does not fit inside the raster, skipped".format(x, y))
            continue

        pBuf = lc.ReadAsArray(xoff, yoff, size, size)
        if pBuf is None:
            print("could not read window at pixel ({0}, {1}), skipped".format(xoff, yoff))
            continue
        if pBuf.ndim == 2:
            # Single-band rasters come back as (rows, cols); add a band axis
            # so that pBuf[i] is always one full band.
            pBuf = pBuf.reshape((1,) + pBuf.shape)

        # Georeference the tile from the integer pixel offset actually read,
        # so the output lines up with the source pixel grid.
        left = im_geotrans[0] + xoff * pixel_w
        top = im_geotrans[3] + yoff * pixel_h
        right = left + size * pixel_w
        bottom = top + size * pixel_h
        tile_geotrans = (left, pixel_w, im_geotrans[2], top, im_geotrans[4], pixel_h)

        dirname = _sample_dirname(n)
        dirpath = os.path.join(outputdir, dirname + "_V1")
        os.makedirs(dirpath, exist_ok=True)

        # ************** create tif **********************
        outtif = os.path.join(dirpath, dirname + ".tif")
        out_ds = tif_driver.Create(outtif, size, size, int(bandscount), datatype, options=[])
        if out_ds is None:
            print("Could not create " + outtif)
            return None
        out_ds.SetProjection(im_proj)
        out_ds.SetGeoTransform(tile_geotrans)
        for i in range(int(bandscount)):
            out_ds.GetRasterBand(i + 1).WriteArray(pBuf[i])
        out_ds.FlushCache()
        out_ds = None  # dropping the reference closes the file

        # ************** create shp **********************
        # NOTE(review): these options are kept from the original code;
        # presumably they target non-ASCII paths/attributes — confirm.
        gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")
        gdal.SetConfigOption("SHAPE_ENCODING", "")
        strVectorFile = os.path.join(dirpath, "{0}_V1_{1}.shp".format(dirname, kind))
        if os.path.exists(strVectorFile):
            # The driver refuses to create over an existing file; replace it
            # just like the tile above is replaced on a re-run.
            shp_driver.DeleteDataSource(strVectorFile)
        oDS = shp_driver.CreateDataSource(strVectorFile)
        if oDS is None:
            print("Could not create " + strVectorFile)
            return None
        oLayer = oDS.CreateLayer("TestPolygon", prosrs, geom_type, [])
        if oLayer is None:
            print("Could not create layer in " + strVectorFile)
            return None
        oFieldName = ogr.FieldDefn(fieldName, ogr.OFTString)
        oFieldName.SetWidth(50)
        oLayer.CreateField(oFieldName, 1)
        if kind == "POLY":
            # Store the tile footprint as the single polygon of the layer.
            oFeatureRectangle = ogr.Feature(oLayer.GetLayerDefn())
            geomRectangle = ogr.CreateGeometryFromWkt(
                "POLYGON (({0} {1},{2} {3},{4} {5},{6} {7},{0} {1}))".format(
                    left, top, right, top, right, bottom, left, bottom))
            oFeatureRectangle.SetGeometry(geomRectangle)
            oLayer.CreateFeature(oFeatureRectangle)
            oFeatureRectangle = None
        # For LINE samples the layer is left empty: a polygon cannot be
        # stored in a LineString shapefile, so the write was always rejected.
        oLayer = None
        oDS = None  # closes and flushes the shapefile

        print("{0} ok".format(dirname))
        n = n + 1

    dsshp = None
    lc = None
    time_end = time.perf_counter()
    print('Process Running time: %s min' % ((time_end - time_start) / 60))
    return n
def mkdir(path):
    """Create directory *path*, including missing parents, if nothing exists there.

    An already existing path (directory or not) is left untouched.

    Args:
        path: Directory to create.
    """
    if os.path.exists(path):
        return
    # exist_ok covers the case where another process creates the directory
    # between the check above and this call.
    os.makedirs(path, exist_ok=True)
if __name__ == "__main__":
    # Example run: clip samples around every feature of train.shp.
    outputdir = './plough'  # output directory
    mkdir(outputdir)
    sampletype = "line"  # sample type: "line" or "poly"
    size = 1000  # sample edge length passed to sampleClip
    n = 1  # number of the first sample
    fieldName = 'cls'  # name of the attribute field
    tif = './Level18/cq.tif'
    shp = 'train.shp'
    n = sampleClip(shp, tif, outputdir, sampletype, size, fieldName, n)
    print(n)