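# 接上一篇：制作出 RGB 和 nDSM 后，批量合成波段。
# Continued from the previous post: once the RGB (DOM) and nDSM rasters have been produced, batch-stack them into 4-band images.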
from osgeo import gdal
import os
import numpy as np
def readRGBTif(fileName, im_width=6060, im_height=6060):
    """Read the top-left window of a multi-band (RGB) GeoTIFF.

    Args:
        fileName: path of the raster to open with GDAL.
        im_width: number of columns to read from the left edge. Defaults to
            6060, the fixed tile size this script was written for; pass None
            to read the raster's full width (``RasterXSize``).
        im_height: number of rows to read from the top edge. Defaults to
            6060; pass None to read the raster's full height (``RasterYSize``).

    Returns:
        A 10-tuple ``(im_data, im_width, im_height, im_bands, im_geotrans,
        im_proj, im_dtype, im_blueBand, im_greenBand, im_redBand)``, where
        ``im_data`` is the ``(bands, rows, cols)`` array read from the window
        and the three band arrays are its first three bands.
        Returns None (after printing a message) when the file cannot be
        opened or is smaller than the requested window.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        print(fileName+"文件无法打开")
        return None
    if im_width is None:
        im_width = dataset.RasterXSize
    if im_height is None:
        im_height = dataset.RasterYSize
    # Reading past the raster edge makes GDAL fail; report it the same way
    # as an unreadable file instead of crashing further down.
    if im_width > dataset.RasterXSize or im_height > dataset.RasterYSize:
        print("%s is %dx%d, smaller than the requested %dx%d window"
              % (fileName, dataset.RasterXSize, dataset.RasterYSize,
                 im_width, im_height))
        return None
    im_bands = dataset.RasterCount  # number of bands in the file
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()  # projection string
    # Band 0/1/2 are labelled blue/green/red as in the original script.
    # NOTE(review): GDAL returns bands in file order; confirm the DOM tiles
    # really are stored B, G, R. The explicit 3-index form raises IndexError
    # on a single-band (2-D) raster instead of silently returning a row.
    im_blueBand = im_data[0, :, :]
    im_greenBand = im_data[1, :, :]
    im_redBand = im_data[2, :, :]
    im_dtype = im_data.dtype.name
    return (im_data, im_width, im_height, im_bands, im_geotrans, im_proj,
            im_dtype, im_blueBand, im_greenBand, im_redBand)
def readDEMTif(fileName, im_width=6060, im_height=6060):
    """Read the top-left window of a DEM / nDSM GeoTIFF.

    Args:
        fileName: path of the raster to open with GDAL.
        im_width: number of columns to read from the left edge. Defaults to
            6060, the fixed tile size this script was written for; pass None
            to read the raster's full width (``RasterXSize``).
        im_height: number of rows to read from the top edge. Defaults to
            6060; pass None to read the raster's full height (``RasterYSize``).

    Returns:
        An 8-tuple ``(im_data, im_width, im_height, im_bands, im_geotrans,
        im_proj, im_dtype, im_DEMBand)``. ``im_DEMBand`` is the 2-D array of
        the first band (for a single-band file it is ``im_data`` itself).
        Returns None (after printing a message) when the file cannot be
        opened or is smaller than the requested window.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        print(fileName+"文件无法打开")
        return None
    if im_width is None:
        im_width = dataset.RasterXSize
    if im_height is None:
        im_height = dataset.RasterYSize
    # Reading past the raster edge makes GDAL fail; report it the same way
    # as an unreadable file instead of crashing further down.
    if im_width > dataset.RasterXSize or im_height > dataset.RasterYSize:
        print("%s is %dx%d, smaller than the requested %dx%d window"
              % (fileName, dataset.RasterXSize, dataset.RasterYSize,
                 im_width, im_height))
        return None
    im_bands = dataset.RasterCount  # number of bands in the file
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    # Optional min-max stretch to 0-255, left disabled by the original author:
    # im_data = np.round(im_data)
    # n_max = np.max(im_data)
    # n_min = np.min(im_data)
    # im_data = (im_data - n_min) / (n_max - n_min) * 255
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()  # projection string
    # A single-band file reads as 2-D (rows, cols); for a multi-band file use
    # the first band so callers always get a 2-D array.
    im_DEMBand = im_data[0] if im_data.ndim == 3 else im_data
    im_dtype = im_data.dtype.name
    return (im_data, im_width, im_height, im_bands, im_geotrans, im_proj,
            im_dtype, im_DEMBand)
# numpy dtype name -> suffix of the matching gdal.GDT_* constant.
# Any dtype not listed here is written as Float32.
_NUMPY_TO_GDAL_TYPE = {
    'uint8': 'Byte',
    'int8': 'Byte',  # stored as unsigned 8-bit, as before
    'uint16': 'UInt16',
    'int16': 'Int16',  # was UInt16, which wrapped negative values
    'uint32': 'UInt32',
    'int32': 'Int32',  # was Float32, which loses precision above 2**24
}


def _gdal_type_name(dtype_name):
    """Return the GDAL type name (``gdal.GDT_<name>``) for a numpy dtype name."""
    return _NUMPY_TO_GDAL_TYPE.get(dtype_name, 'Float32')


def _as_band_stack(im_data):
    """Return ``(data, bands, height, width)`` with *data* as (bands, rows, cols).

    Raises:
        ValueError: if *im_data* is neither 2-D nor 3-D.
    """
    if im_data.ndim == 3:
        bands, height, width = im_data.shape
        return im_data, bands, height, width
    if im_data.ndim == 2:
        height, width = im_data.shape
        return im_data[np.newaxis, :, :], 1, height, width
    raise ValueError('im_data must be 2-D (rows, cols) or 3-D '
                     '(bands, rows, cols), got shape %s' % (im_data.shape,))


def writeTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a numpy array to *path* as a GeoTIFF.

    Args:
        im_data: 2-D ``(rows, cols)`` array for a single band, or 3-D
            ``(bands, rows, cols)`` array.
        im_width, im_height, im_bands: accepted for backward compatibility;
            the actual dimensions are always taken from ``im_data.shape``.
        im_geotrans: geotransform passed to ``SetGeoTransform``.
        im_proj: projection string (as returned by ``GetProjection``) passed
            to ``SetProjection``.
        path: output file path.

    The GDAL pixel type is derived from ``im_data.dtype``. If GDAL cannot
    create the file, a message is printed and nothing is written.

    Raises:
        ValueError: if *im_data* is neither 2-D nor 3-D.
    """
    datatype = getattr(gdal, 'GDT_' + _gdal_type_name(im_data.dtype.name))
    im_data, im_bands, im_height, im_width = _as_band_stack(im_data)
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print(path + "文件无法创建")
        return
    dataset.SetGeoTransform(im_geotrans)  # affine geotransform
    dataset.SetProjection(im_proj)  # projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset
def get_file_names(data_dir, file_type=('tif', 'tiff')):
    """Recursively collect files under *data_dir* with a wanted extension.

    Args:
        data_dir: root directory, walked with ``os.walk`` (subdirectories
            included).
        file_type: a single extension string (e.g. ``'tif'``) or an iterable
            of extensions, with or without a leading dot. Matching is exact
            and case-sensitive.

    Returns:
        ``(result_dir, result_name)``: parallel lists of full paths
        (``maindir + '/' + filename``) and bare file names. They are ordered
        by directory, then by file name, so the order is reproducible.
    """
    if isinstance(file_type, str):
        # A bare string must not be used with `in` directly: that is a
        # substring test, so e.g. 'f' in 'tif' would be True.
        file_type = (file_type,)
    wanted = {ext.lstrip('.') for ext in file_type}
    result_dir = []
    result_name = []
    for maindir, subdir, file_name_list in os.walk(data_dir):
        # Sorting subdir in place makes os.walk visit subdirectories in order.
        subdir.sort()
        for filename in sorted(file_name_list):
            # Take the extension from the file name only; a dot in a
            # directory name must not affect the match.
            ext = os.path.splitext(filename)[1].lstrip('.')
            if ext in wanted:
                # Keep '/' joining: callers split these paths on '/'.
                result_dir.append(maindir + '/' + filename)
                result_name.append(filename)
    return result_dir, result_name
# Batch-stack each RGB tile (in_dir1) with its nDSM tile (in_dir2) into a
# 4-band GeoTIFF (B, G, R, nDSM) written to out_dir under the RGB file name.
in_dir1 = 'D:/toWuda/ds/DOM'
in_dir2 = 'D:/toWuda/ds/DEM'
out_dir = 'D:/toWuda/ds/four'
file_type = 'tif'
data_dir_list1, _ = get_file_names(in_dir1, file_type)
data_dir_list2, _ = get_file_names(in_dir2, file_type)
# RGB and DEM tiles are matched by position, so both lists must be in the
# same order. get_file_names follows os.walk order, which is not guaranteed
# to be sorted.
# NOTE(review): this assumes corresponding tiles sort identically in both
# folders (e.g. they share file names); confirm for the dataset.
data_dir_list1.sort()
data_dir_list2.sort()
if len(data_dir_list1) != len(data_dir_list2):
    print('warning: %d RGB tiles but %d DEM tiles; only the first %d pairs are combined'
          % (len(data_dir_list1), len(data_dir_list2),
             min(len(data_dir_list1), len(data_dir_list2))))
# GDAL cannot create files in a directory that does not exist.
os.makedirs(out_dir, exist_ok=True)
driver = gdal.GetDriverByName("GTiff")
for each_dir, dem_dir in zip(data_dir_list1, data_dir_list2):
    rgb = readRGBTif(each_dir)
    dem = readDEMTif(dem_dir)
    if rgb is None or dem is None:
        # The readers print a message and return None for unreadable files.
        continue
    (_, width1, height1, _, geotrans1, proj1, dtype1,
     blueband, greenband, redband) = rgb
    *_, DEMband = dem
    print(each_dir)
    print(dtype1)
    # The output pixel type follows the RGB tile's dtype.
    # NOTE(review): the nDSM band is written with this same type, so with
    # uint8 RGB a float nDSM is converted to 8-bit on write; confirm intended.
    if 'int8' in dtype1:
        datatype = gdal.GDT_Byte
    elif dtype1 == 'int16':
        datatype = gdal.GDT_Int16
    elif dtype1 == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    each_out_dir = out_dir + '/' + os.path.basename(each_dir)
    print('each_out_dir: ', each_out_dir)
    out_bands = (blueband, greenband, redband, DEMband)
    # Create exactly as many bands as are written below.
    new_dataset = driver.Create(each_out_dir, width1, height1, len(out_bands), datatype)
    if new_dataset is None:
        print(each_out_dir + "文件无法创建")
        continue
    # Georeferencing is copied from the RGB tile.
    new_dataset.SetGeoTransform(geotrans1)
    new_dataset.SetProjection(proj1)
    for band_number, band in enumerate(out_bands, start=1):
        new_dataset.GetRasterBand(band_number).WriteArray(band)
    # Releasing the last reference closes the dataset and flushes it to disk.
    new_dataset = None
print('combine over')