DSM Training

1. Cropping: delete erroneous DSM images and blank labels

 
import cv2
import os
 
# Cut the input image into h*w blocks
 
inPath = "./dataset/sat_train/"
outPath = "./dataset/train/"
inPath2 = "./dataset/mask_train/"
 
for f in os.listdir(inPath):
    path = inPath + f.strip()
    print(path)
    img = cv2.imread(path) 
    height = img.shape[0]
    width = img.shape[1]
    # The size of block that you want to cut
    heightBlock = 512
    widthBlock = 512
    heightCutNum = int(height / heightBlock)
    widthCutNum = int(width / widthBlock)
    l = 0
    for i in range(0,heightCutNum):
        for j in range(0,widthCutNum):
            cutImage = img[i*heightBlock:(i+1)*heightBlock, j*widthBlock:(j+1)*widthBlock]
            savePath = outPath + f.strip()[:-4]+'({},{})@{:04d}_sat.tif'.format(i, j, l)
            l+=1
            cv2.imwrite(savePath,cutImage)
            print(savePath)
 
for f in os.listdir(inPath2):
    path = inPath2 + f.strip()
    print(path)
    img = cv2.imread(path) 
    height = img.shape[0]
    width = img.shape[1]
    # The size of block that you want to cut
    heightBlock = 512
    widthBlock = 512
    heightCutNum = int(height / heightBlock)
    widthCutNum = int(width / widthBlock)
    l = 0
    for i in range(0,heightCutNum):
        for j in range(0,widthCutNum):
            cutImage = img[i*heightBlock:(i+1)*heightBlock, j*widthBlock:(j+1)*widthBlock]
            savePath = outPath + f.strip()[:-4]+'({},{})@{:04d}_mask.png'.format(i, j, l)
            l+=1
            cv2.imwrite(savePath,cutImage)
            print(savePath)
print("finish!") 
 
mask_names = filter(lambda x: x.find('mask')!=-1, os.listdir(outPath))
# sat_names = filter(lambda x: x.find('sat')!=-1, os.listdir(tar))
#trainlist = list(map(lambda x: x[:-8], imagelist))
for f in mask_names:
    path = outPath + f.strip()
    if not os.path.exists(path):
        continue
    img = cv2.imread(path, 0)                  # read the mask as grayscale
    if cv2.countNonZero(img) == 0:             # the label is completely black (blank)
        print(f + ' is blank, removing it and its image')
        stem = f[:-9]                          # strip the "_mask.png" suffix
        os.remove(path)
        os.remove(outPath + stem + "_sat.tif")
 

2. Analysis: bin test predictions by IoU

import os
import shutil
data_path='./submits/log01_Dink101_five_100/test_iou/'
data=open(os.path.join(data_path, "log01_Dink101_five_100_excel.txt"),'r').read().splitlines()
valid_path='./dataset/valid/'
rgb_path='./dataset/valid_all/'
real_path='./dataset/real/'
 
iou_100=os.path.join(data_path,'iou_100/')
iou_80=os.path.join(data_path,'iou_80/')
iou_50=os.path.join(data_path,'iou_50/')
iou_30=os.path.join(data_path,'iou_30/')
for p in (iou_100, iou_80, iou_50, iou_30):
    os.makedirs(p, exist_ok=True)
for n in data:
    name=n.split()[1]
    iou=float(n.split()[2])
    img_path=os.path.join(data_path,'test_pre_img/'+name+'.png')
    valid_name=os.path.join(valid_path,name[:-4]+'sat.tif')
    rgb_name=os.path.join(rgb_path,name[:-4]+'sat.tif')
    real_name=os.path.join(real_path,name[:-4]+'mask.png')
    if iou>=80:
        shutil.copy(img_path,iou_100)
        file_name=os.path.join(iou_100,name+'.png')
        new_name=os.path.join(iou_100,name[:-4]+'tmask_'+str(iou)+'.png')
        os.rename(file_name,new_name)

        shutil.copy(rgb_name,iou_100)
        file_name=os.path.join(iou_100,name[:-4]+'sat.tif')
        new_name=os.path.join(iou_100,name[:-4]+'rgb.tif')
        os.rename(file_name,new_name)

        shutil.copy(valid_name,iou_100)
        
        shutil.copy(real_name,iou_100)
        file_name=os.path.join(iou_100,name[:-4]+'mask.png')
        new_name=os.path.join(iou_100,name[:-4]+'tmask.png')
        os.rename(file_name,new_name)
        
        print(name,iou)
        continue
    elif iou>=50:
        shutil.copy(img_path,iou_80)
        file_name=os.path.join(iou_80,name+'.png')
        new_name=os.path.join(iou_80,name[:-4]+'tmask_'+str(iou)+'.png')
        os.rename(file_name,new_name)

        shutil.copy(rgb_name,iou_80)
        file_name=os.path.join(iou_80,name[:-4]+'sat.tif')
        new_name=os.path.join(iou_80,name[:-4]+'rgb.tif')
        os.rename(file_name,new_name)

        shutil.copy(valid_name,iou_80)
        
        shutil.copy(real_name,iou_80)
        file_name=os.path.join(iou_80,name[:-4]+'mask.png')
        new_name=os.path.join(iou_80,name[:-4]+'tmask.png')
        os.rename(file_name,new_name)
        
        print(name,iou)
        continue
    elif iou>=30:
        shutil.copy(img_path,iou_50)
        file_name=os.path.join(iou_50,name+'.png')
        new_name=os.path.join(iou_50,name[:-4]+'tmask_'+str(iou)+'.png')
        os.rename(file_name,new_name)

        shutil.copy(rgb_name,iou_50)
        file_name=os.path.join(iou_50,name[:-4]+'sat.tif')
        new_name=os.path.join(iou_50,name[:-4]+'rgb.tif')
        os.rename(file_name,new_name)

        shutil.copy(valid_name,iou_50)
        
        shutil.copy(real_name,iou_50)
        file_name=os.path.join(iou_50,name[:-4]+'mask.png')
        new_name=os.path.join(iou_50,name[:-4]+'tmask.png')
        os.rename(file_name,new_name)
        
        print(name,iou)
        continue
    else:
        shutil.copy(img_path,iou_30)
        file_name=os.path.join(iou_30,name+'.png')
        new_name=os.path.join(iou_30,name[:-4]+'tmask_'+str(iou)+'.png')
        os.rename(file_name,new_name)

        shutil.copy(rgb_name,iou_30)
        file_name=os.path.join(iou_30,name[:-4]+'sat.tif')
        new_name=os.path.join(iou_30,name[:-4]+'rgb.tif')
        os.rename(file_name,new_name)

        shutil.copy(valid_name,iou_30)
        
        shutil.copy(real_name,iou_30)
        file_name=os.path.join(iou_30,name[:-4]+'mask.png')
        new_name=os.path.join(iou_30,name[:-4]+'tmask.png')
        os.rename(file_name,new_name)
        
        print(name,iou)
        continue
 
 
print('Finish')
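
The four branches above repeat the same copy-and-rename steps; only the destination folder changes. A condensed sketch of the threshold-to-folder mapping (note that the folder names sit one bin above their thresholds, e.g. predictions with IoU >= 80 land in iou_100):

def pick_bin(iou):
    # Same thresholds and folders as the script above.
    if iou >= 80:
        return iou_100   # 80 <= IoU
    elif iou >= 50:
        return iou_80    # 50 <= IoU < 80
    elif iou >= 30:
        return iou_50    # 30 <= IoU < 50
    return iou_30        # IoU < 30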

3. Select predictions with IoU >= 30

import os
import shutil
data_path='./submits/log01_Dink101_five_100/test_iou/'
data=open(os.path.join(data_path, "log01_Dink101_five_100_excel.txt"),'r').read().splitlines()

iou_100=os.path.join(data_path,'test_pre_img/')

if not os.path.exists(iou_100):
    os.mkdir(iou_100)
for n in data:
    name=n.split()[1]
    iou=float(n.split()[2])
    img_path=os.path.join(data_path,'test_pre_img87.24/'+name+'.png')

    if iou>=30:
        shutil.copy(img_path,iou_100)
        print(name,iou)
        continue
 
print('Finish')

Delete real masks without matching predictions

import os
import cv2
# source = 'dataset/sat_train/'
real_path ="./dataset/real/"
pre_path ="./submits/log01_Dink101_five_100/test_iou/test_pre_img/"
 
real_names = filter(lambda x: x.find('mask')!=-1, os.listdir(real_path))
pre_names = filter(lambda x: x.find('mask')!=-1, os.listdir(pre_path))
#trainlist = list(map(lambda x: x[:-8], imagelist))
for f in real_names:
    pre_name = pre_path + f.strip()
    if not os.path.exists(pre_name):
        os.remove(real_path + f.strip())
        print(real_path + f.strip())
# for f in sat_names:
#     mask_path = tar + f.strip()[:-8] + "_mask.png"
#     if not os.path.exists(mask_path):
#         os.remove(tar + f.strip())
#         print(tar + f.strip())

predict

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
 
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
import shutil
from matplotlib.pyplot import MultipleLocator
# MultipleLocator (imported above) is used to set axis tick intervals
 
 
from time import time
from PIL import Image
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU
 
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
 
BATCHSIZE_PER_CARD = 32
 
# class TTAFrame():
#     def __init__(self, net):
#         self.net = net().cuda()
#         self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
 
#     def load(self, path):
     #   new_state_dict = OrderedDict()
      #  for key, value in torch.load(path).items():
       #     name = 'module.' + key
        #    new_state_dict[name] = value
        #model.load_state_dict(new_state_dict)
        #model = torch.load(path)
        #model.pop('module.finaldeconv1.weight')
        #model.pop('module.finalconv3.weight')
        #self.net.load_state_dict(model,strict=False)
        # self.net.load_state_dict(torch.load(path))
# source = 'dataset/test/'
 
def saveList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")

def savetrainList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt_train.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")
 
def dirList(gt_dir, path_list):
    for i in range(0, len(path_list)):
        path = os.path.join(gt_dir, path_list[i])
        if os.path.isdir(path):   # check each entry inside the loop
            saveList(os.listdir(path))
 
 
 
print("开始运行!")
# source = 'dataset/test/'
 
# solver = TTAFrame(DinkNet34)
# solver = TTAFrame(DinkNet50)
weight_dir      =  "./weights/"
 
weight_list = os.listdir(weight_dir)
# Sort checkpoints by the number at the end of the filename (x[19:-3], e.g. "..._100.th" -> 100)
weight_list.sort(key=lambda x: int(x[19:-3]))
 
 
save_valid_dir='./dataset/valid_train/'
test_num=len(os.listdir('./dataset/valid/'))
# if not os.path.exists(save_valid_dir):
#    trainsample('./dataset/train/',test_num)
 
mylog = open('submits/count_low_pic.log','w')
 
 
source = 'dataset/valid/'
test_valid = os.listdir(source)
test_num=len(os.listdir('./dataset/valid/'))
 
for weight_name in weight_list:
    weight_path=os.path.join(weight_dir,weight_name )
    # solver.load('weights/log01_Dink34.th')
    # solver.load(weight_path)
    tic = time()
    tar=os.path.join('./submits/',weight_name[:-3])
    
    
    target = os.path.join('./submits/',weight_name[:-3]+'/'+'test_iou/')
    lower_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'lower_iou/')
    higher_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'higher_iou/')
    test_pre_img_dir=os.path.join(target,'test_pre_img/')
 
 
    
    
    # wtn: accuracy evaluation
    miou_mode       = 2
    #------------------------------#
    #   number of classes + 1, e.g. 2 + 1
    #------------------------------#
    num_classes     = 2
    #--------------------------------------------#
    #   class names, the same as in json_to_dataset
    #--------------------------------------------#
    # name_classes    = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    name_classes    = ["nonwater","water"]
    #-------------------------------------------------------#
    #   Path to the folder containing the VOC-format dataset
    #   (defaults to the VOC dataset under the project root)
    #-------------------------------------------------------#
    data_path  = './dataset/'
    data_train_path='./dataset/'
 
 
    f=open("./dataset/gt.txt", 'w')
    gt_dir      = os.path.join(data_path, "real/")
    pred_dir    = test_pre_img_dir
    path_list = os.listdir(gt_dir)
    path_list.sort()
    dirList(gt_dir,path_list)
    saveList(path_list)
    image_ids   = open(os.path.join(data_path, "gt.txt"),'r').read().splitlines() 
   
    train_mIou=[]
    train_mPA=[]
    test_mIou=[]
    test_mPA=[]
 
    if miou_mode == 0 or miou_mode == 2:

        mylog.write(str(weight_name[:-3]))
        print(weight_name + "  Get miou.")

        print('Computing test mIoU')
        test_mIou, test_mPA, test_miou, test_mpa = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes, weight_name)  # run the mIoU computation
        mylog.write('  test_mIoU:  ' + str(test_miou))
        mylog.write('  test_mPA:  ' + str(test_mpa))
        print('  test_mIoU:  ' + str(test_miou))

        count = 0
        print('Computing per-image IoU on the test set')
        count = compute_IoU(gt_dir, pred_dir, image_ids, num_classes, lower_iou, higher_iou, weight_name, count)  # run the per-image IoU computation
        mylog.write('  low-iou test picture num:  ' + str(count))
        print(weight_name + "  Get miou done.")
 
       
 
mylog.write('Finish!')
print ('Finish!')
mylog.close()
 

Closing operation

Dilation first, then erosion.

Used to fill small holes inside foreground objects or to remove small dark spots on them.
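
A tiny self-contained check of that equivalence (the 9x9 mask below is made up purely for illustration); the batch script that post-processes the predictions follows:

import cv2 as cv
import numpy as np

mask = np.zeros((9, 9), np.uint8)
mask[2:7, 2:7] = 255
mask[4, 4] = 0                                  # a one-pixel hole in the foreground
k = np.ones((3, 3), np.uint8)

closed = cv.morphologyEx(mask, cv.MORPH_CLOSE, k)
manual = cv.erode(cv.dilate(mask, k), k)        # dilate first, then erode
print((closed == manual).all(), closed[4, 4])   # True 255 -> the hole has been filled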

import cv2 as cv
import numpy as np
import os

pre_path = '/mnt/sdb1/fenghaixia/dsm/test_pre_img89.49/'
out_root = '/mnt/sdb1/fenghaixia/dsm/submits/log01_Dink101_five_100/test_iou/'
for f in os.listdir(pre_path):
    image = cv.imread(pre_path + f)
    k = np.ones((5, 5), np.uint8)
    opened = cv.morphologyEx(image, cv.MORPH_OPEN, k)      # opening: erosion then dilation
    cv.imwrite(out_root + '1/' + f, opened)
    closed = cv.morphologyEx(image, cv.MORPH_CLOSE, k)     # closing: dilation then erosion
    cv.imwrite(out_root + '2/' + f, closed)
    closed = cv.morphologyEx(opened, cv.MORPH_CLOSE, k)    # opening followed by closing
    cv.imwrite(out_root + '3/' + f, closed)
print('finish')

Only compute IoU

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V

import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
import shutil
from matplotlib.pyplot import MultipleLocator
# MultipleLocator (imported above) is used to set axis tick intervals


from time import time
from PIL import Image
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU

from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool

BATCHSIZE_PER_CARD = 16

class TTAFrame():
    def __init__(self, net):
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
        
    def test_one_img_from_path(self, path, evalmode = True):
        if evalmode:
            self.net.eval()
        batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
        if batchsize >= 8:
            return self.test_one_img_from_path_1(path)

    def test_one_img_from_path_1(self, path):
        img = cv2.imread(path)  # .transpose(2,0,1)[None]

        # 8-fold test-time augmentation: 90-degree rotation plus vertical/horizontal flips.
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None], img90[None]])
        img2 = np.array(img1)[:, ::-1]
        img3 = np.concatenate([img1, img2])
        img4 = np.array(img3)[:, :, ::-1]
        img5 = np.concatenate([img3, img4]).transpose(0, 3, 1, 2)
        img5 = np.array(img5, np.float32) / 255.0 * 3.2 - 1.6   # scale pixels to roughly [-1.6, 1.6]
        img5 = V(torch.Tensor(img5).cuda())

        mask = self.net.forward(img5).squeeze().cpu().data.numpy()  # .squeeze(1)
        # Undo the flips/rotation and sum the eight predictions into one mask.
        mask1 = mask[:4] + mask[4:, :, ::-1]
        mask2 = mask1[:2] + mask1[2:, ::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1, ::-1]

        return mask3

    def load(self, path):
     #   new_state_dict = OrderedDict()
      #  for key, value in torch.load(path).items():
       #     name = 'module.' + key
        #    new_state_dict[name] = value
        #model.load_state_dict(new_state_dict)
        #model = torch.load(path)
        #model.pop('module.finaldeconv1.weight')
        #model.pop('module.finalconv3.weight')
        #self.net.load_state_dict(model,strict=False)
        self.net.load_state_dict(torch.load(path))
# source = 'dataset/test/'
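# A sketch of how TTAFrame is typically driven (it mirrors the commented-out solver
# lines in the "predict" script above; the checkpoint name below is a placeholder):
#   solver = TTAFrame(DinkNet101)
#   solver.load('./weights/log01_Dink101_five_100.th')
#   mask = solver.test_one_img_from_path('./dataset/valid/xxx_sat.tif')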

def saveList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")

def savetrainList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt_train.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")

def dirList(gt_dir, path_list):
    for i in range(0, len(path_list)):
        path = os.path.join(gt_dir, path_list[i])
        if os.path.isdir(path):   # check each entry inside the loop
            saveList(os.listdir(path))


print("开始运行!")

weight_dir      =  "./weights/"

weight_list = os.listdir(weight_dir)
# Sort checkpoints by the number at the end of the filename (x[19:-3], e.g. "..._100.th" -> 100)
weight_list.sort(key=lambda x: int(x[19:-3]))


mylog = open('submits/count_low_pic.log','w')



for weight_name in weight_list:
    weight_path=os.path.join(weight_dir,weight_name )
    # solver.load('weights/log01_Dink34.th')
    # solver.load(weight_path)
    tic = time()
    tar=os.path.join('./submits/',weight_name[:-3])
    if not os.path.exists(tar):
        os.mkdir(tar)
    
    target = os.path.join('./submits/',weight_name[:-3]+'/'+'test_iou/')
    lower_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'lower_iou/')
    higher_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'higher_iou/')
    if not os.path.exists(target):
        os.mkdir(target)
    if not os.path.exists(lower_iou):  
        os.mkdir(lower_iou)
    if not os.path.exists(higher_iou):  
        os.mkdir(higher_iou)

    test_pre_img_dir=os.path.join(target,'3/')   # '3/' holds the morphologically post-processed predictions (see the closing-operation script above)


       

    
    # wtn: accuracy evaluation
    miou_mode       = 2
    #------------------------------#
    #   number of classes + 1, e.g. 2 + 1
    #------------------------------#
    num_classes     = 2
    #--------------------------------------------#
    #   class names, the same as in json_to_dataset
    #--------------------------------------------#
    # name_classes    = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    name_classes    = ["nonwater","water"]
    #-------------------------------------------------------#
    #   Path to the folder containing the VOC-format dataset
    #   (defaults to the VOC dataset under the project root)
    #-------------------------------------------------------#
    data_path  = './dataset/'
    data_train_path='./dataset/'


    f=open("./dataset/gt.txt", 'w')
    gt_dir      = os.path.join(data_path, "real/")
    pred_dir    = test_pre_img_dir
    path_list = os.listdir(gt_dir)
    path_list.sort()
    dirList(gt_dir,path_list)
    saveList(path_list)
    image_ids   = open(os.path.join(data_path, "gt.txt"),'r').read().splitlines() 


    train_mIou=[]
    train_mPA=[]
    test_mIou=[]
    test_mPA=[]

    if miou_mode == 0 or miou_mode == 2:

        mylog.write(str(weight_name[:-3]))
        print(weight_name + "  Get miou.")

        print('Computing test mIoU')
        test_mIou, test_mPA, test_miou, test_mpa = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes, weight_name)  # run the mIoU computation
        mylog.write('  test_mIoU:  ' + str(test_miou))
        mylog.write('  test_mPA:  ' + str(test_mpa))
        print('  test_mIoU:  ' + str(test_miou))
            
    
        # count=0
        # print('计算测试样本单张iou')
        # count=compute_IoU(gt_dir, pred_dir, image_ids, num_classes, lower_iou,higher_iou,weight_name,count)  # 执行计算mIoU的函数
        # mylog.write('  low-iou test picture num:  '+str(count))
        # print(weight_name + "Get miou done.")

     

mylog.write('Finish!')
print ('Finish!')
mylog.close()


Reposted from blog.csdn.net/weixin_61235989/article/details/130109432