Personal notes: run log for extracting buildings with deep learning

Original code: https://github.com/zlckanata/DeepGlobe-Road-Extraction-Challenge

These notes are only meant to make it easier for me to run the pipeline on my own remote sensing data.

1. Prepare the data (crop the images in sat_train and mask_train into 512*512 tiles, save them to train, and delete the blank ones)

clip.py


import cv2
import os

# Cutting the input image to h*w blocks

inPath = "./dataset/sat_train/"
outPath = "./dataset/train/"
inPath2 = "./dataset/mask_train/"

for f in os.listdir(inPath):
    path = inPath + f.strip()
    print(path)
    img = cv2.imread(path) 
    height = img.shape[0]
    width = img.shape[1]
    # The size of block that you want to cut
    heightBlock = 512
    widthBlock = 512
    heightCutNum = int(height / heightBlock)
    widthCutNum = int(width / widthBlock)
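    # note: any leftover strip narrower than 512 px at the right/bottom edge is discarded by the integer division above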
    l = 0
    for i in range(0,heightCutNum):
        for j in range(0,widthCutNum):
            cutImage = img[i*heightBlock:(i+1)*heightBlock, j*widthBlock:(j+1)*widthBlock]
            savePath = outPath + f.strip()[:-4]+'({},{})@{:04d}_sat.tif'.format(i, j, l)
            l+=1
            cv2.imwrite(savePath,cutImage)
            print(savePath)

for f in os.listdir(inPath2):
    path = inPath2 + f.strip()
    print(path)
    img = cv2.imread(path) 
    height = img.shape[0]
    width = img.shape[1]
    # The size of block that you want to cut
    heightBlock = 512
    widthBlock = 512
    heightCutNum = int(height / heightBlock)
    widthCutNum = int(width / widthBlock)
    l = 0
    for i in range(0,heightCutNum):
        for j in range(0,widthCutNum):
            cutImage = img[i*heightBlock:(i+1)*heightBlock, j*widthBlock:(j+1)*widthBlock]
            savePath = outPath + f.strip()[:-4]+'({},{})@{:04d}_mask.png'.format(i, j, l)
            l+=1
            cv2.imwrite(savePath,cutImage)
            print(savePath)
print("finish!") 
 
mask_names = filter(lambda x: x.find('mask')!=-1, os.listdir(outPath))
# sat_names = filter(lambda x: x.find('sat')!=-1, os.listdir(tar))
#trainlist = list(map(lambda x: x[:-8], imagelist))
for f in mask_names:
    path = outPath + f.strip()
    if not os.path.exists(path):
        continue
    img = cv2.imread(path, 0)
    if cv2.countNonZero(img) == 0:
        print(f + ' is blank, removing it and its image')
        path2 = f[:-9]          # strip "_mask.png"
        os.remove(path)
        os.remove(outPath + path2 + "_sat.tif")

41,897 building tile samples in total.

Blank masks and their corresponding satellite tiles are deleted.

2. Start training

train.py

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V

import cv2
import os
import numpy as np

from time import time
import datetime
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
from framework import MyFrame
from loss import dice_bce_loss
from data import ImageFolder

import matplotlib.pyplot as plt

SHAPE = (512,512)
ROOT = 'dataset/train/'
imagelist = filter(lambda x: x.find('sat')!=-1, os.listdir(ROOT))
trainlist = list(map(lambda x: x[:-8], imagelist))
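# x[:-8] strips the "_sat.tif" suffix (8 characters), leaving the tile ID shared by each image/mask pair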

NAME = 'log01_Dink101'
BATCHSIZE_PER_CARD = 16


solver = MyFrame(DinkNet34, dice_bce_loss, 2e-4)
batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

dataset = ImageFolder(trainlist, ROOT)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batchsize,
    shuffle=True)


mylog = open('logs/'+NAME+'.log','w')
tic = time()
time1 = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
no_optim = 0
total_epoch = 200
train_epoch_best_loss = 150
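# start from a deliberately large "best" loss so the first epoch always counts as an improvement and saves a checkpoint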
#draw=0
Loss_list = []

print('start!!')
for epoch in range(1, total_epoch + 1):
    data_loader_iter = iter(data_loader)
    train_epoch_loss = 0
    for img, mask in data_loader_iter:
        solver.set_input(img, mask)
        train_loss = solver.optimize()
        train_epoch_loss += train_loss
    train_epoch_loss /= len(data_loader_iter)
    Loss_list.append(train_epoch_loss)

    mylog.write( '********')
    #mylog.write('epoch:' + str(epoch) + '    time:'+ str(int(time()-tic)))
    mylog.write('epoch:' + str(epoch) + '    time:'+ datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    mylog.write('    train_loss:' + str(train_epoch_loss))
    mylog.write('    SHAPE:' + SHAPE.__str__())
    mylog.write('\n')
    print ('********')
    #print ('epoch:',epoch,'    time:',int(time()-tic))
    print ('epoch:',epoch,'    time:',datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    print ('train_loss:',train_epoch_loss)
    print ('SHAPE:',SHAPE)
    #draw=draw+1

    if train_epoch_loss >= train_epoch_best_loss:
        no_optim += 1
    else:
        no_optim = 0
        train_epoch_best_loss = train_epoch_loss
        solver.save('weights/'+NAME+'.th')
    if no_optim > 6:
        mylog.write('early stop at %d epoch' % epoch)
        print ('early stop at %d epoch' % epoch)
        solver.save('weights/'+NAME+'_earlystop_%d.th'% epoch)
        #break
    if epoch%5==0:
        solver.save('weights/'+NAME+'_five_%d.th'% epoch)
    
    if no_optim > 3:
        if solver.old_lr < 5e-7:
            mylog.write('solver.old_lr < 5e-7 at %d epoch' % epoch)
            #break
        solver.load('weights/'+NAME+'.th')
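        # reload the best checkpoint and shrink the learning rate; in the original repo's framework.py,
        # factor=True makes update_lr divide the current lr by the given value (here: lr / 5)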
        solver.update_lr(5.0, factor = True, mylog = mylog)
    mylog.flush()

mylog.write('Finish!')
print ('Finish!')

time2 = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print("train started at:")
# print the formatted timestamps recorded at the start and end of training
print(time1)
print("train finished at:")
print(time2)

mylog.close()
x1 = range(0, len(Loss_list))    
y1 = Loss_list
plt.plot(x1, y1, 'o-')
plt.title('Model loss vs. epochs')
plt.ylabel('Model loss')
plt.savefig("model_loss.jpg")
plt.show()

3. Prediction

There is a small issue I was too lazy to fix: lines 218 and 219 of the code below should be left uncommented the first time it is run, and commented out the second time.

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V

import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
import shutil
from matplotlib.pyplot import MultipleLocator
# MultipleLocator (imported from pyplot) is used to set the axis tick interval


from time import time
from PIL import Image
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU

from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool

BATCHSIZE_PER_CARD = 32

class TTAFrame():
    def __init__(self, net):
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
        
    def test_one_img_from_path(self, path, evalmode = True):
        if evalmode:
            self.net.eval()
        batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
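        # dispatch on batch capacity: _1 pushes all 8 flipped/rotated TTA views through in one forward pass,
        # _2 uses two passes of 4, _4 uses four passes of 2; the _8 variant below is never selected here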
        if batchsize >= 8:
            return self.test_one_img_from_path_1(path)
        elif batchsize >= 4:
            return self.test_one_img_from_path_2(path)
        elif batchsize >= 2:
            return self.test_one_img_from_path_4(path)

    def test_one_img_from_path_8(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.array(img1)[:,:,::-1]
        img4 = np.array(img2)[:,:,::-1]
        
        img1 = img1.transpose(0,3,1,2)
        img2 = img2.transpose(0,3,1,2)
        img3 = img3.transpose(0,3,1,2)
        img4 = img4.transpose(0,3,1,2)
        
        img1 = V(torch.Tensor(np.array(img1, np.float32)/255.0 * 3.2 -1.6).cuda())
        img2 = V(torch.Tensor(np.array(img2, np.float32)/255.0 * 3.2 -1.6).cuda())
        img3 = V(torch.Tensor(np.array(img3, np.float32)/255.0 * 3.2 -1.6).cuda())
        img4 = V(torch.Tensor(np.array(img4, np.float32)/255.0 * 3.2 -1.6).cuda())
        
        maska = self.net.forward(img1).squeeze().cpu().data.numpy()
        maskb = self.net.forward(img2).squeeze().cpu().data.numpy()
        maskc = self.net.forward(img3).squeeze().cpu().data.numpy()
        maskd = self.net.forward(img4).squeeze().cpu().data.numpy()
        
        mask1 = maska + maskb[:,::-1] + maskc[:,:,::-1] + maskd[:,::-1,::-1]
        mask2 = mask1[0] + np.rot90(mask1[1])[::-1,::-1]
        
        return mask2

    def test_one_img_from_path_4(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.array(img1)[:,:,::-1]
        img4 = np.array(img2)[:,:,::-1]
        
        img1 = img1.transpose(0,3,1,2)
        img2 = img2.transpose(0,3,1,2)
        img3 = img3.transpose(0,3,1,2)
        img4 = img4.transpose(0,3,1,2)
        
        img1 = V(torch.Tensor(np.array(img1, np.float32)/255.0 * 3.2 -1.6).cuda())
        img2 = V(torch.Tensor(np.array(img2, np.float32)/255.0 * 3.2 -1.6).cuda())
        img3 = V(torch.Tensor(np.array(img3, np.float32)/255.0 * 3.2 -1.6).cuda())
        img4 = V(torch.Tensor(np.array(img4, np.float32)/255.0 * 3.2 -1.6).cuda())
        
        maska = self.net.forward(img1).squeeze().cpu().data.numpy()
        maskb = self.net.forward(img2).squeeze().cpu().data.numpy()
        maskc = self.net.forward(img3).squeeze().cpu().data.numpy()
        maskd = self.net.forward(img4).squeeze().cpu().data.numpy()
        
        mask1 = maska + maskb[:,::-1] + maskc[:,:,::-1] + maskd[:,::-1,::-1]
        mask2 = mask1[0] + np.rot90(mask1[1])[::-1,::-1]
        
        return mask2
    
    def test_one_img_from_path_2(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.concatenate([img1,img2])
        img4 = np.array(img3)[:,:,::-1]
        img5 = img3.transpose(0,3,1,2)
        img5 = np.array(img5, np.float32)/255.0 * 3.2 -1.6
        img5 = V(torch.Tensor(img5).cuda())
        img6 = img4.transpose(0,3,1,2)
        img6 = np.array(img6, np.float32)/255.0 * 3.2 -1.6
        img6 = V(torch.Tensor(img6).cuda())
        
        maska = self.net.forward(img5).squeeze().cpu().data.numpy()#.squeeze(1)
        maskb = self.net.forward(img6).squeeze().cpu().data.numpy()
        
        mask1 = maska + maskb[:,:,::-1]
        mask2 = mask1[:2] + mask1[2:,::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1,::-1]
        
        return mask3
    
    def test_one_img_from_path_1(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.concatenate([img1,img2])
        img4 = np.array(img3)[:,:,::-1]
        img5 = np.concatenate([img3,img4]).transpose(0,3,1,2)
        img5 = np.array(img5, np.float32)/255.0 * 3.2 -1.6
        img5 = V(torch.Tensor(img5).cuda())
        
        mask = self.net.forward(img5).squeeze().cpu().data.numpy()#.squeeze(1)
        mask1 = mask[:4] + mask[4:,:,::-1]
        mask2 = mask1[:2] + mask1[2:,::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1,::-1]
        
        return mask3

    def load(self, path):
     #   new_state_dict = OrderedDict()
      #  for key, value in torch.load(path).items():
       #     name = 'module.' + key
        #    new_state_dict[name] = value
        #model.load_state_dict(new_state_dict)
        #model = torch.load(path)
        #model.pop('module.finaldeconv1.weight')
        #model.pop('module.finalconv3.weight')
        #self.net.load_state_dict(model,strict=False)
        self.net.load_state_dict(torch.load(path))
# source = 'dataset/test/'

def saveList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")
        f.close

def savetrainList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt_train.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")
        f.close

def dirList(gt_dir,path_list):
    for i in range(0, len(path_list)):
        path = os.path.join(gt_dir, path_list[i])
        if os.path.isdir(path):
            saveList(os.listdir(path))

def trainsample(train_dir,test_num):
    save_valid_dir='./dataset/valid_train/'
    os.mkdir(save_valid_dir)
    save_real_dir='./dataset/real_train/'
    os.mkdir(save_real_dir)
    file = open('./dataset/gt_train.txt', 'w')
    file.close()
    tif_list = filter(lambda x: x.find('sat')!=-1, os.listdir(train_dir))
    random.seed(0)
    train_tif_list = random.sample(list(tif_list), test_num)
    # train_png_list=[]
    for tif_name in train_tif_list:
        shutil.copy(train_dir+tif_name, save_valid_dir+tif_name)
        png_name=tif_name[:-7]+'mask.png'
        shutil.copy(train_dir+png_name, save_real_dir+png_name)
        # train_png_list.append(png_name)
    print("Randomly sampled " + str(test_num) + " training tiles for IoU evaluation")
    
    
    
    # # print (sample)
    # for name in train_name:
    #     shutil.copy(file_dir+name, save_dir+name)
    # return train_name

train_n = 1  # 0 = skip the train-set evaluation, 1 = run it
test_n = 1   # 0 = skip the test-set evaluation, 1 = run it
draw_n = 1   # whether to plot the train and test curves in one figure

print("Start!")
# source = 'dataset/test/'

solver = TTAFrame(DinkNet34)
# solver = TTAFrame(DinkNet50)
weight_dir      =  "./weights/"

weight_list = os.listdir(weight_dir)
weight_list.sort(key=lambda x:int(x[19:-3]))


save_valid_dir='./dataset/valid_train/'
test_num=len(os.listdir('./dataset/valid/'))
# if not os.path.exists(save_valid_dir):
#    trainsample('./dataset/train/',test_num)

mylog = open('submits/count_low_pic.log','w')

if test_n==1:
    source = 'dataset/valid/'
    test_valid = os.listdir(source)
    test_num=len(os.listdir('./dataset/valid/'))

if train_n==1:    
    train_source = 'dataset/valid_train/'
    train_valid=os.listdir(train_source)
    # this builds the training-sample set; comment it out if the valid_train folder already exists
    save_valid_dir='./dataset/valid_train/'

for weight_name in weight_list:
    weight_path=os.path.join(weight_dir,weight_name )
    # solver.load('weights/log01_Dink34.th')
    solver.load(weight_path)
    tic = time()
    tar=os.path.join('./submits/',weight_name[:-3])
    if not os.path.exists(tar):
        os.mkdir(tar)
    if test_n==1:
        target = os.path.join('./submits/',weight_name[:-3]+'/'+'test_iou/')
        lower_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'lower_iou/')
        higher_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'higher_iou/')
        if not os.path.exists(target):
            os.mkdir(target)
        if not os.path.exists(lower_iou):  
            os.mkdir(lower_iou)
        if not os.path.exists(higher_iou):  
            os.mkdir(higher_iou)

        test_pre_img_dir=os.path.join(target,'test_pre_img/')
        if not os.path.exists(test_pre_img_dir):
            os.mkdir(test_pre_img_dir)
            print('Predicting test samples; results go to the test_iou/test_pre_img folder')
            for i,name in enumerate(test_valid):
                # if i%10 == 0:
                # print (i/10, '    ','%.2f'%(time()-tic))
                # print (name)
                mask = solver.test_one_img_from_path(source+name)
                mask[mask>4.0] = 255
                mask[mask<=4.0] = 0
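                # the TTA result is a sum over 8 flipped/rotated views, so the 4.0 threshold corresponds to an average sigmoid output of 0.5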
                mask = np.concatenate([mask[:,:,None],mask[:,:,None],mask[:,:,None]],axis=2)
                cv2.imwrite(test_pre_img_dir+name[:-7]+'mask.png',mask.astype(np.uint8))
            print('Finished predicting test samples; results are in test_iou/test_pre_img')

       
   

    if train_n==1:    
        target_train = os.path.join('./submits/',weight_name[:-3]+'/'+'train_iou/')
        lower_train_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'lower_train_iou/')
        higher_train_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'higher_train_iou/')
        if not os.path.exists(target_train):
            os.mkdir(target_train)
        if not os.path.exists(lower_train_iou):
            os.mkdir(lower_train_iou)
        if not os.path.exists(higher_train_iou):
            os.mkdir(higher_train_iou)
        train_pre_img_dir=os.path.join(target_train,'train_pre_img/')
        if not os.path.exists(train_pre_img_dir):
            os.mkdir(train_pre_img_dir)
            print('Predicting training samples; results go to the train_iou/train_pre_img folder')
            for i,name in enumerate(train_valid):
            # if i%10 == 0:
            # print (i/10, '    ','%.2f'%(time()-tic))
            # print (name)
                mask = solver.test_one_img_from_path(train_source+name)
                mask[mask>4.0] = 255
                mask[mask<=4.0] = 0
                mask = np.concatenate([mask[:,:,None],mask[:,:,None],mask[:,:,None]],axis=2)
                cv2.imwrite(train_pre_img_dir+name[:-7]+'mask.png',mask.astype(np.uint8))
            print('Finished predicting training samples; results are in train_iou/train_pre_img')

    if draw_n==1:  
        iou_pic=os.path.join('./submits/','train_test_mioupic/')
        pa_pic=os.path.join('./submits/','train_test_mpapic/')
        if not os.path.exists(iou_pic):
            os.mkdir(iou_pic)
        if not os.path.exists(pa_pic):
            os.mkdir(pa_pic)

    
    
    # wtn: accuracy computation
    miou_mode       = 2
    #------------------------------#
    #   number of classes + 1, e.g. 2+1
    #------------------------------#
    num_classes     = 2
    #--------------------------------------------#
    #   class names, the same as in json_to_dataset
    #--------------------------------------------#
    # name_classes    = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    name_classes    = ["nonwater","water"]
    #-------------------------------------------------------#
    #   path to the dataset folder
    #   (defaults to the VOC-style dataset under the project root)
    #-------------------------------------------------------#
    data_path  = './dataset/'
    data_train_path='./dataset/'

    if test_n==1:
        f=open("./dataset/gt.txt", 'w')
        gt_dir      = os.path.join(data_path, "real/")
        pred_dir    = test_pre_img_dir
        path_list = os.listdir(gt_dir)
        path_list.sort()
        dirList(gt_dir,path_list)
        saveList(path_list)
        image_ids   = open(os.path.join(data_path, "gt.txt"),'r').read().splitlines() 

    if train_n==1:
        f=open("./dataset/gt_train.txt", 'w')
        gt_train_dir=os.path.join(data_train_path, "real_train/")
        pred_train_dir = train_pre_img_dir
        path_train_list = os.listdir(gt_train_dir)
        path_train_list.sort()
        dirList(gt_train_dir,path_train_list)
        savetrainList(path_train_list)
        image_train_ids=open(os.path.join(data_path, "gt_train.txt"),'r').read().splitlines()
   
    train_mIou=[]
    train_mPA=[]
    test_mIou=[]
    test_mPA=[]

    if miou_mode == 0 or miou_mode == 2:
        
        mylog.write(str(weight_name[:-3]))

        print(weight_name +"  Get miou.")

        if train_n==1:

            print('Computing training mIoU')
            train_mIou,train_mPA,train_miou,train_mpa=compute_mIoU(gt_train_dir, pred_train_dir, image_train_ids, num_classes, name_classes,weight_name)  # run the mIoU computation
            mylog.write('  train_mIoU:  '+str(train_miou))
            mylog.write('  train_mPA:  '+str(train_mpa))

        if test_n==1:

            print('Computing test mIoU')
            test_mIou,test_mPA,test_miou,test_mpa=compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes,weight_name)  # run the mIoU computation
            mylog.write('  test_mIoU:  '+str(test_miou))
            mylog.write('  test_mPA:  '+str(test_mpa))
      
        if draw_n==1:
            x=np.arange(len(train_mIou))
            plt.figure()
            plt.plot(x,train_mIou)
            plt.plot(x,test_mIou)
            plt.grid(True)

            y_major_locator=MultipleLocator(10)  # set the y-axis tick interval to 10
            ax = plt.gca()
            ax.yaxis.set_major_locator(y_major_locator)
            ax.set_ylim(0,100)

            plt.xlabel('image')
            plt.ylabel('mIOU')
            plt.legend(['validation','test'],loc="lower right")
            plt.title(weight_name[:-3]+'  mIOU')
            plt.savefig(os.path.join(iou_pic, weight_name[:-3]+"_contra_miou.png"))

            plt.figure()
            plt.plot(x,train_mPA)
            plt.plot(x,test_mPA)
            plt.grid(True)
            y_major_locator=MultipleLocator(10)  # set the y-axis tick interval to 10
            ax = plt.gca()
            ax.yaxis.set_major_locator(y_major_locator)
            ax.set_ylim(0,100)
            plt.xlabel('image')
            plt.ylabel('mPA')
            plt.title(weight_name[:-3]+'  mPA')
            plt.legend(['validation','test'],loc="lower right")
            plt.savefig(os.path.join(pa_pic, weight_name[:-3]+"_contra_mpa.png"))

            plt.figure()
            plt.plot(x,train_mIou)
            plt.plot(x,test_mIou)
            plt.grid(True)
            plt.xlabel('image')
            plt.ylabel('mIOU')
            plt.legend(['validation','test'],loc="upper right")
            plt.title(weight_name[:-3]+'  mIOU')
            plt.savefig(os.path.join(iou_pic, 'a_'+weight_name[:-3]+"contra_miou.png"))

            plt.figure()
            plt.plot(x,train_mPA)
            plt.plot(x,test_mPA)
            plt.grid(True)
            plt.xlabel('image')
            plt.ylabel('mPA')
            plt.title(weight_name[:-3]+'  mPA')
            plt.legend(['validation','test'],loc="upper right")
            plt.savefig(os.path.join(pa_pic,'a_'+ weight_name[:-3]+"contra_mpa.png"))

        if test_n==1:
            count=0
            print('Computing per-image IoU for the test samples')
            count=compute_IoU(gt_dir, pred_dir, image_ids, num_classes, lower_iou,higher_iou,weight_name,count)  # run the per-image IoU computation
            mylog.write('  low-iou test picture num:  '+str(count))
            print(weight_name + "Get miou done.")

        if train_n==1:
            count2=0
            print('Computing per-image IoU for the training samples')
            count2=compute_IoU(gt_train_dir, pred_train_dir,  image_train_ids, num_classes, lower_train_iou,higher_train_iou,weight_name,count2)  # run the per-image IoU computation
            print(weight_name + "Get miou done.")
            print("low_train pictures: %d"%count2)
            mylog.write('  low-iou train picture num:  '+str(count2)+'\n')
            mylog.flush()

mylog.write('Finish!')
print ('Finish!')
mylog.close()

Dilate (pad) the big image first, predict with a sliding window, then compute the IoU
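Before the full script, here is a minimal sketch of the overlap-tile arithmetic it relies on (slide_window = 512, step_length = 256 and the hard-coded 6060 x 6060 image size are the script's own values; the sketch is only an illustration, not part of the pipeline):

slide_window = 512                                     # prediction window size
step_length = 256                                      # stride between windows (50% overlap)
width = height = 6060                                  # image size hard-coded in the script below

right_fill = step_length - (width % step_length)       # 84 px of padding on the right
bottom_fill = step_length - (height % step_length)     # 84 px of padding on the bottom

tiles_per_row = (width + right_fill) // step_length    # 24 tiles horizontally
tiles_per_col = (height + bottom_fill) // step_length  # 24 tiles vertically

# The script additionally pads a half-stride (128 px) border on every side, predicts each
# 512x512 window, keeps only its central 256x256, stitches the 24x24 grid into a
# 6144x6144 mosaic, and finally crops the padding away to get back to 6060x6060.
print(right_fill, tiles_per_row, tiles_per_col)        # -> 84 24 24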

import torch
from torch.autograd import Variable as V

import cv2
import os
import shutil
from PIL import Image
from PIL import ImageFile
import numpy as np
import matplotlib.pyplot as plt

from osgeo import gdal,ogr,osr

from time import time
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU

from networks.dinknet import DinkNet34
from framework import MyFrame
from loss import dice_bce_loss
from data import ImageFolder

from PyQt5 import QtCore, QtGui
import sys
from PyQt5.QtWidgets import QMainWindow,QApplication,QWidget
from PyQt5.QtCore import QEventLoop, QTimer, QThread
BATCHSIZE_PER_CARD = 16
def saveList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")
        f.close

def dirList(gt_dir,path_list):
    for i in range(0, len(path_list)):
        path = os.path.join(gt_dir, path_list[i])
        if os.path.isdir(path):
            saveList(os.listdir(path))

def DeleteShp(layer,count):
    for i in range(count):
        feature = layer.GetFeature(i) 
        code = feature.GetField('value')
        if(code==0):
            id = feature.GetFID()
            layer.DeleteFeature(int(id))

def GridToShp(input_path,Outshp_path):
    inraster = gdal.Open(input_path)
    im_data = inraster.GetRasterBand(1)    
    driver = ogr.GetDriverByName("ESRI Shapefile")
    if os.access(Outshp_path,os.F_OK):  
        driver.DeleteDataSource(Outshp_path)
    ds = driver.CreateDataSource(Outshp_path)  
    spatialref = osr.SpatialReference()
    # proj = osr.SpatialReference(wkt = inraster.GetProjection())
    # epsg = int(proj.GetAttrValue("AUTHORITY",1))  
    # spatialref.ImportFromEPSG(epsg) 
    spatialref.ImportFromWkt(inraster.GetProjection())  
    geomtype = ogr.wkbMultiPolygon  
  
    layer = ds.CreateLayer(Outshp_path[:-4],srs=spatialref,geom_type=geomtype) 
    layer.CreateField(ogr.FieldDefn('value',ogr.OFTReal))
    gdal.FPolygonize(im_data,im_data,layer,0,[],None)
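    # FPolygonize traces band 1 (also passed as its own validity mask) into polygons, writing each
    # polygon's raster value into the 'value' field; zero-valued polygons are removed by DeleteShp afterwards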
    ds.SyncToDisk()
    ds.Destroy()
    
    ds = ogr.Open(Outshp_path,True)
    Layer = ds.GetLayer(0)
    Count = Layer.GetFeatureCount()
    DeleteShp(Layer,Count)
    ds.Destroy()
class TTAFrame():
    def __init__(self, net):
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
        
    def test_one_img_from_path(self, path, evalmode = True):
        if evalmode:
            self.net.eval()
        batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
        if batchsize >= 8:
            return self.test_one_img_from_path_1(path)
        elif batchsize >= 4:
            return self.test_one_img_from_path_2(path)
        elif batchsize >= 2:
            return self.test_one_img_from_path_4(path)

    def test_one_img_from_path_8(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.array(img1)[:,:,::-1]
        img4 = np.array(img2)[:,:,::-1]
        
        img1 = img1.transpose(0,3,1,2)
        img2 = img2.transpose(0,3,1,2)
        img3 = img3.transpose(0,3,1,2)
        img4 = img4.transpose(0,3,1,2)
        
        img1 = V(torch.Tensor(np.array(img1, np.float32)/255.0 * 3.2 -1.6).cuda())
        img2 = V(torch.Tensor(np.array(img2, np.float32)/255.0 * 3.2 -1.6).cuda())
        img3 = V(torch.Tensor(np.array(img3, np.float32)/255.0 * 3.2 -1.6).cuda())
        img4 = V(torch.Tensor(np.array(img4, np.float32)/255.0 * 3.2 -1.6).cuda())
        
        maska = self.net.forward(img1).squeeze().cpu().data.numpy()
        maskb = self.net.forward(img2).squeeze().cpu().data.numpy()
        maskc = self.net.forward(img3).squeeze().cpu().data.numpy()
        maskd = self.net.forward(img4).squeeze().cpu().data.numpy()
        
        mask1 = maska + maskb[:,::-1] + maskc[:,:,::-1] + maskd[:,::-1,::-1]
        mask2 = mask1[0] + np.rot90(mask1[1])[::-1,::-1]
        
        return mask2

    def test_one_img_from_path_4(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.array(img1)[:,:,::-1]
        img4 = np.array(img2)[:,:,::-1]
        
        img1 = img1.transpose(0,3,1,2)
        img2 = img2.transpose(0,3,1,2)
        img3 = img3.transpose(0,3,1,2)
        img4 = img4.transpose(0,3,1,2)
        
        img1 = V(torch.Tensor(np.array(img1, np.float32)/255.0 * 3.2 -1.6).cuda())
        img2 = V(torch.Tensor(np.array(img2, np.float32)/255.0 * 3.2 -1.6).cuda())
        img3 = V(torch.Tensor(np.array(img3, np.float32)/255.0 * 3.2 -1.6).cuda())
        img4 = V(torch.Tensor(np.array(img4, np.float32)/255.0 * 3.2 -1.6).cuda())
        
        maska = self.net.forward(img1).squeeze().cpu().data.numpy()
        maskb = self.net.forward(img2).squeeze().cpu().data.numpy()
        maskc = self.net.forward(img3).squeeze().cpu().data.numpy()
        maskd = self.net.forward(img4).squeeze().cpu().data.numpy()
        
        mask1 = maska + maskb[:,::-1] + maskc[:,:,::-1] + maskd[:,::-1,::-1]
        mask2 = mask1[0] + np.rot90(mask1[1])[::-1,::-1]
        
        return mask2
    
    def test_one_img_from_path_2(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.concatenate([img1,img2])
        img4 = np.array(img3)[:,:,::-1]
        img5 = img3.transpose(0,3,1,2)
        img5 = np.array(img5, np.float32)/255.0 * 3.2 -1.6
        img5 = V(torch.Tensor(img5).cuda())
        img6 = img4.transpose(0,3,1,2)
        img6 = np.array(img6, np.float32)/255.0 * 3.2 -1.6
        img6 = V(torch.Tensor(img6).cuda())
        
        maska = self.net.forward(img5).squeeze().cpu().data.numpy()#.squeeze(1)
        maskb = self.net.forward(img6).squeeze().cpu().data.numpy()
        
        mask1 = maska + maskb[:,:,::-1]
        mask2 = mask1[:2] + mask1[2:,::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1,::-1]
        
        return mask3
    
    def test_one_img_from_path_1(self, path):
        img = cv2.imread(path)#.transpose(2,0,1)[None]
        
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None],img90[None]])
        img2 = np.array(img1)[:,::-1]
        img3 = np.concatenate([img1,img2])
        img4 = np.array(img3)[:,:,::-1]
        img5 = np.concatenate([img3,img4]).transpose(0,3,1,2)
        img5 = np.array(img5, np.float32)/255.0 * 3.2 -1.6
        img5 = V(torch.Tensor(img5).cuda())
        
        mask = self.net.forward(img5).squeeze().cpu().data.numpy()#.squeeze(1)
        mask1 = mask[:4] + mask[4:,:,::-1]
        mask2 = mask1[:2] + mask1[2:,::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1,::-1]
        
        return mask3

    def load(self, path):
    #   new_state_dict = OrderedDict()
    #  for key, value in torch.load(path).items():
    #     name = 'module.' + key
        #    new_state_dict[name] = value
        #model.load_state_dict(new_state_dict)
        #model = torch.load(path)
        #model.pop('module.finaldeconv1.weight')
        #model.pop('module.finalconv3.weight')
        #self.net.load_state_dict(model,strict=False)
        self.net.load_state_dict(torch.load(path))

#os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # make only the first GPU visible

# config.gpu_options.per_process_gpu_memory_fraction = 0.7  # the program may use at most this fraction of the GPU memory; comment this out on the server

ImageFile.LOAD_TRUNCATED_IMAGES = True

Image.MAX_IMAGE_PIXELS = None

slide_window = 512  # size of the large sliding window
step_length = 256   # stride of the sliding window
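# Note: 512-px windows advanced by a 256-px step overlap by 50%; only the central 256x256 of each
# prediction is kept later when the tiles are stitched back together.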

# 1. Pad (dilate) the image
print('Starting dilated prediction!')
source = 'dataset/sat_test/'
sat_path ="./dataset/sat_test/"
test_path ="./dataset/test/"
if os.path.exists(test_path):
    shutil.rmtree(test_path)  # recursively delete the folder and everything in it
os.mkdir(test_path)
file_names = filter(lambda x: x.find('tif')!=-1, os.listdir(sat_path))
original='./Big_Image_Predict_Result/'
if os.path.exists(original):
    shutil.rmtree(original)  # recursively delete the folder and everything in it
os.mkdir(original)


for file in file_names:
    img = os.path.join(sat_path, file)
    fname, ext = os.path.splitext(img)
    base_name = os.path.basename(fname)
    change_name=test_path+base_name+'_sat.tif'
    if not os.path.isfile(change_name):
    #if not os.path.isfile(filePath+'\\'+base_name[0]+'_mask.png'):
            print('Copied and renamed original image: "'+file+'"')
            shutil.copy(sat_path+file,change_name)

test_list = os.listdir(test_path) 
# # print(file_list)

for file in test_list:
    Image_Path = test_path+file

    #print(Image_Path)
    #print(Mask_Path)

    image = Image.open(Image_Path)
    image_name = file[:-4]
    width = 6060  # image width (hard-coded here)
    height = 6060  # image height (hard-coded here)

    right_fill = step_length - (width % step_length)
    bottom_fill = step_length - (height % step_length)

    width_path_number = int((width + right_fill) / step_length)  # number of tiles per row
    height_path_number = int((height + bottom_fill) / step_length)  # number of tiles per column

    image = np.array(image)
    image=image[:6060,:6060,:]
    # image[:,:,3]=image[:,:,1]

    image = cv2.copyMakeBorder(image, top=0, bottom=bottom_fill, left=0, right=right_fill,
                            borderType=cv2.BORDER_CONSTANT, value=0)

    image = cv2.copyMakeBorder(image, top=step_length // 2, bottom=step_length // 2, left=step_length // 2,
                            right=step_length // 2,
                            borderType=cv2.BORDER_CONSTANT, value=0)  # pad an extra half-stride border on every side
    # cv2.namedWindow('swell',cv2.WINDOW_NORMAL)
    # cv2.imshow('swell',image) 
    # cv2.waitKey(0)
    print('Image padding (dilation) step done!')


    # 2. Crop the padded image with a sliding window
    crop_source = './dataset/'
    #tar=os.path.join('./dataset/',file[:-8]+'/'+'Image_Crop_Result/')
    tar=os.path.join('./dataset/',file[:-8])
    #shutil.rmtree(r"C:\Users\Administrator\Desktop\DeepGlobe-Road-Extraction-link34\dataset\Image_Crop_Result")  # recursively delete the folder and everything in it
    if os.path.exists(tar):
        shutil.rmtree(tar)  # recursively delete the folder and everything in it
    os.mkdir(tar)
    target=os.path.join(tar,'Image_Crop_Result/')
    os.mkdir(target)
    image_crop_addr = target  # folder where the cropped tiles are stored
    image = Image.fromarray(image)  # convert back from a numpy array to a PIL image
    l = 0
    for j in range(height_path_number):
        for i in range(width_path_number):
            box = (i * step_length, j * step_length, i * step_length + slide_window, j * step_length + slide_window)
            small_image = image.crop(box)
            small_image.save(
                image_crop_addr + image_name[:-4] + '({},{})@{:09d}_sat.tif'.format(j, i, l), quality=95)
            l = l + 1

    print('Sliding-window cropping of the padded image done!')
    

    # 3. Run inference on the small tiles cropped above
    #targ=os.path.join(sat_path,os.path.pardir)
    print('Starting prediction!')
    test=os.path.join(tar,'Image_Predict_Result/')
    if os.path.exists(test):
        shutil.rmtree(test)
    os.mkdir(test)
    # path =target+ '*.tif'
    # print(path)
    # expanded_images_crop = glob.glob(path)
    # print(expanded_images_crop)

    # # model = load_model("unet_cancer_2021-11-16__01_27_12.h5")  # load the model
    # # predict each small tile one by one, then save the prediction as a colour-indexed image
    # for k in expanded_images_crop:
    
    val = os.listdir(target)
    #solver = TTAFrame(LinkNet34)
    solver = TTAFrame(DinkNet34)
    solver.load('weights/log01_Dink101_five_100.th')
    tic = time()
    for i,name in enumerate(val):
        if i%10 == 0:
            print(str(i/10)+'     %.2f'%(time()-tic))
        mask = solver.test_one_img_from_path(target+name)
        mask[mask>4.0] = 255
        mask[mask<=4.0] = 0
        mask = np.concatenate([mask[:,:,None],mask[:,:,None],mask[:,:,None]],axis=2)
        cv2.imwrite(test+name[:-7]+'mask.png',mask.astype(np.uint8))
    
    print('Prediction step done!')
    

    # 4. Crop the padded predictions back to the original tile size
    recover=os.path.join(tar,'Image_Recover/')
    if os.path.exists(recover):
        shutil.rmtree(recover)  # recursively delete the folder and everything in it
    os.mkdir(recover)

    val = os.listdir(test)
    for i,expanded_image in enumerate(val):
        img = Image.open(test+expanded_image)
        img_name = os.path.basename(test+expanded_image)
        box = (128, 128, 384, 384)
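        # keep only the central 256x256 of each 512x512 prediction tile; the overlapping borders are discarded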
        original_image = img.crop(box)
        original_image.save(recover + img_name, quality=95)
    print('Cropping back to the original size done!')

    # 5. Stitch the tiles back together
    IMAGES_PATH = recover  # folder with the tiles to stitch
    IMAGES_FORMAT = ['.png']  # image format
    IMAGE_SIZE =  256 # size of each small tile
    original='./Big_Image_Predict_Result/'
    # if not os.path.exists(original):
    #     #shutil.rmtree(original)  # recursively delete the folder and everything in it
    #     os.mkdir(original)

    # collect all image names in the tile folder
    image_names = [name for name in os.listdir(IMAGES_PATH) for item in IMAGES_FORMAT if
                os.path.splitext(name)[1] == item]

    image_names.sort(key=lambda x:int(x[-18:-9]))  # required: os.listdir returns files in arbitrary order, so sort by the 9-digit tile index before "_mask.png"
    
    IMAGE_ROW = int(height_path_number)  # number of rows in the stitched image
    IMAGE_COLUMN = int(width_path_number)  # number of columns in the stitched image
    # sanity check: the number of tiles must match the expected grid
    if len(image_names) != IMAGE_ROW * IMAGE_COLUMN:
        raise ValueError("The stitching parameters do not match the number of tiles!")

    to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE, IMAGE_ROW * IMAGE_SIZE))  # create the blank output image
    # paste each tile into its position, in order
    for y in range(1, IMAGE_ROW + 1):
        for x in range(1, IMAGE_COLUMN + 1):
            #print(image_names[IMAGE_COLUMN * (y - 1) + x - 1])
            from_image = Image.open(IMAGES_PATH + image_names[IMAGE_COLUMN * (y - 1) + x - 1]).resize(
                (IMAGE_SIZE, IMAGE_SIZE), Image.ANTIALIAS)
            to_image.paste(from_image, ((x - 1) * IMAGE_SIZE, (y - 1) * IMAGE_SIZE))

    # the right and bottom of the stitched image still contain the padding and must be cropped off
    box2 = (0, 0, int(to_image.size[0] - right_fill / (width + right_fill) * to_image.size[0]),
            int(to_image.size[1] - bottom_fill / (height + bottom_fill) * to_image.size[1]))
    original_mask = to_image.crop(box2)

    original_mask.save(original + image_name[:-4] + "_mask.png", quality=95)  # save the stitched mask

    print('Image stitching step done!')
    shutil.rmtree(tar)  # recursively delete the temporary folder and everything in it

print('********')
print("Dilated prediction finished!") 
print('********')

Attach geographic coordinates to the prediction, then convert it to vector (shapefile)

 
from torch.autograd import Variable as V
 
 
import os
import shutil
from PIL import Image
 
 
from osgeo import gdal,ogr,osr
 
from time import time
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU
 
 
BATCHSIZE_PER_CARD = 16
def saveList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")
        f.close
 
def dirList(gt_dir,path_list):
    for i in range(0, len(path_list)):
        path = os.path.join(gt_dir, path_list[i])
        if os.path.isdir(path):
            saveList(os.listdir(path))
 
def DeleteShp(layer,count):
    for i in range(count):
        feature = layer.GetFeature(i) 
        code = feature.GetField('value')
        if(code==0):
            id = feature.GetFID()
            layer.DeleteFeature(int(id))
 
def GridToShp(input_path,Outshp_path):
    inraster = gdal.Open(input_path)
    im_data = inraster.GetRasterBand(1)    
    driver = ogr.GetDriverByName("ESRI Shapefile")
    if os.access(Outshp_path,os.F_OK):  
        driver.DeleteDataSource(Outshp_path)
    ds = driver.CreateDataSource(Outshp_path)  
    spatialref = osr.SpatialReference()
    # proj = osr.SpatialReference(wkt = inraster.GetProjection())
    # epsg = int(proj.GetAttrValue("AUTHORITY",1))  
    # spatialref.ImportFromEPSG(epsg) 
    spatialref.ImportFromWkt(inraster.GetProjection())  
    geomtype = ogr.wkbMultiPolygon  
  
    layer = ds.CreateLayer(Outshp_path[:-4],srs=spatialref,geom_type=geomtype) 
    layer.CreateField(ogr.FieldDefn('value',ogr.OFTReal))
    gdal.FPolygonize(im_data,im_data,layer,0,[],None)
    ds.SyncToDisk()
    ds.Destroy()
    
    ds = ogr.Open(Outshp_path,True)
    Layer = ds.GetLayer(0)
    Count = Layer.GetFeatureCount()
    DeleteShp(Layer,Count)
    ds.Destroy()
sat_path ="/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/sat_test/"
original='./Big_Image_Predict_Result/'
 
addcor_original='./dataset/Add_Coordinate/'
shp='./Final_Result_SHP/'
if not os.path.exists(addcor_original):
        os.mkdir(addcor_original)
if not os.path.exists(shp):
        os.mkdir(shp)
original_names = filter(lambda x: x.find('mask')!=-1, os.listdir(original))
for f in original_names:
    CorimgPath = sat_path+ f[:-9] + '.tif'
    if os.path.exists(CorimgPath):
        path = original + f.strip()
        print(str(path))
 
        in_ds = gdal.Open(CorimgPath)
        in_ds2= gdal.Open(path)
 
        bands_num = in_ds2.RasterCount
        png_image = Image.open(path)
        block_xsize = png_image.size[0]
        block_ysize = png_image.size[1]
 
        # read each band of the mask image (band indices start at 1; the first three bands by default)
        in_band1 = in_ds2.GetRasterBand(1)
        in_band2 = in_ds2.GetRasterBand(2)
        in_band3 = in_ds2.GetRasterBand(3)

        gtif_driver = gdal.GetDriverByName("GTiff")  # a driver/data type is needed so GDAL can allocate the output; is GTiff the only option here?
        filename = addcor_original + f[:-3] + 'tif'  # output file name
 
        out_band1 = in_band1.ReadAsArray(0, 0, block_xsize, block_ysize)
        out_band2 = in_band2.ReadAsArray(0, 0, block_xsize, block_ysize)
        out_band3 = in_band3.ReadAsArray(0, 0, block_xsize, block_ysize)
 
        # take the origin/geotransform from the original satellite image
        ori_transform = in_ds.GetGeoTransform()
        out_ds = gtif_driver.Create(filename, block_xsize, block_ysize, 3, in_band1.DataType)  # data type follows the source image
        out_ds.SetGeoTransform(ori_transform)

        # set the SRS (projection) information
        out_ds.SetProjection(in_ds.GetProjection())

        # write the bands to the target file (adjust here if the number of bands changes)
        out_ds.GetRasterBand(1).WriteArray(out_band1)
        out_ds.GetRasterBand(2).WriteArray(out_band2)
        out_ds.GetRasterBand(3).WriteArray(out_band3)
 
        # flush the cache to disk (the file is written directly to the working folder)
        # out_ds.FlushCache()
        print(str(filename))
        del out_ds
for img in os.listdir(addcor_original):
    input_path=addcor_original + img.strip()
    Outshp_path=shp + img[:-4] +'.shp'
    print(str(Outshp_path))
    GridToShp(input_path,Outshp_path)
 
print('********')
print('Finished attaching geographic coordinates and converting to vector!') 
print('********')

Crop the big prediction results (Big_Image_Predict_Result) into 512*512 tiles

import cv2
import os
 
# Cutting the input image to h*w blocks
 
 
outPath = "./dataset/tmp/"
inPath2 = "./dataset/Big_Image_Predict_Result/"
if not os.path.exists(outPath):
    os.mkdir(outPath) 
 
 
for f in os.listdir(inPath2):
    path = inPath2 + f.strip()
    print(path)
    img = cv2.imread(path) 
    height = img.shape[0]
    width = img.shape[1]
    # The size of block that you want to cut
    heightBlock = 512
    widthBlock = 512
    heightCutNum = int(height / heightBlock)
    widthCutNum = int(width / widthBlock)
    l = 0
    for i in range(0,heightCutNum):
        for j in range(0,widthCutNum):
            cutImage = img[i*heightBlock:(i+1)*heightBlock, j*widthBlock:(j+1)*widthBlock]
            savePath = outPath + f.strip()[:-9]+'({},{})@{:04d}_mask.png'.format(i, j, l)
            l+=1
            cv2.imwrite(savePath,cutImage)
            print(savePath)
print("finish!") 
 
 
 

Delete tiles in tmp that have no counterpart in real

import os
import cv2
# source = 'dataset/sat_train/'
real_path ="./dataset/real/"
pre_path ="./dataset/tmp/"

real_names = filter(lambda x: x.find('mask')!=-1, os.listdir(real_path))
pre_names = filter(lambda x: x.find('mask')!=-1, os.listdir(pre_path))
#trainlist = list(map(lambda x: x[:-8], imagelist))
for f in pre_names:
    real_name = real_path + f.strip()
    if not os.path.exists(real_name):
        os.remove(pre_path + f.strip())
        print(pre_path + f.strip())
# for f in sat_names:
#     mask_path = tar + f.strip()[:-8] + "_mask.png"
#     if not os.path.exists(mask_path):
#         os.remove(tar + f.strip())
#         print(tar + f.strip())

Create 3+1 new folders for prediction
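The note above does not list the folders explicitly; a loudly hedged guess, based only on the paths the evaluation script below reads and writes (an assumption, not from the source), is three dataset folders plus submits/ for the outputs:

import os

# Assumed "3+1" folder set: three dataset folders used by the evaluation plus submits/ for its logs and outputs.
for d in ['./dataset/valid/', './dataset/real/', './dataset/valid_train/', './submits/']:
    os.makedirs(d, exist_ok=True)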

 

 

Run:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V

import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
import shutil
from matplotlib.pyplot import MultipleLocator
# MultipleLocator (imported from pyplot) is used to set the axis tick interval


from time import time
from PIL import Image
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU

from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool

BATCHSIZE_PER_CARD = 32

# class TTAFrame():
#     def __init__(self, net):
#         self.net = net().cuda()
#         self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))

#     def load(self, path):
     #   new_state_dict = OrderedDict()
      #  for key, value in torch.load(path).items():
       #     name = 'module.' + key
        #    new_state_dict[name] = value
        #model.load_state_dict(new_state_dict)
        #model = torch.load(path)
        #model.pop('module.finaldeconv1.weight')
        #model.pop('module.finalconv3.weight')
        #self.net.load_state_dict(model,strict=False)
        # self.net.load_state_dict(torch.load(path))
# source = 'dataset/test/'

def saveList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")
        f.close

def savetrainList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt_train.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")
        f.close

def dirList(gt_dir,path_list):
    for i in range(0, len(path_list)):
        path = os.path.join(gt_dir, path_list[i])
        if os.path.isdir(path):
            saveList(os.listdir(path))



print("Start!")
# source = 'dataset/test/'

# solver = TTAFrame(DinkNet34)
# solver = TTAFrame(DinkNet50)
weight_dir      =  "./weights/"

weight_list = os.listdir(weight_dir)
weight_list.sort(key=lambda x:int(x[19:-3]))


save_valid_dir='./dataset/valid_train/'
test_num=len(os.listdir('./dataset/valid/'))
# if not os.path.exists(save_valid_dir):
#    trainsample('./dataset/train/',test_num)

mylog = open('submits/count_low_pic.log','w')


source = 'dataset/valid/'
test_valid = os.listdir(source)
test_num=len(os.listdir('./dataset/valid/'))

for weight_name in weight_list:
    weight_path=os.path.join(weight_dir,weight_name )
    # solver.load('weights/log01_Dink34.th')
    # solver.load(weight_path)
    tic = time()
    tar=os.path.join('./submits/',weight_name[:-3])
    
    
    target = os.path.join('./submits/',weight_name[:-3]+'/'+'test_iou/')
    lower_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'lower_iou/')
    higher_iou = os.path.join('./submits/',weight_name[:-3]+'/'+'higher_iou/')
    test_pre_img_dir=os.path.join(target,'test_pre_img/')


    
    
    # wtn: accuracy computation
    miou_mode       = 2
    #------------------------------#
    #   number of classes + 1, e.g. 2+1
    #------------------------------#
    num_classes     = 2
    #--------------------------------------------#
    #   class names, the same as in json_to_dataset
    #--------------------------------------------#
    # name_classes    = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    name_classes    = ["nonwater","water"]
    #-------------------------------------------------------#
    #   path to the dataset folder
    #   (defaults to the VOC-style dataset under the project root)
    #-------------------------------------------------------#
    data_path  = './dataset/'
    data_train_path='./dataset/'

 
    f=open("./dataset/gt.txt", 'w')
    gt_dir      = os.path.join(data_path, "real/")
    pred_dir    = test_pre_img_dir
    path_list = os.listdir(gt_dir)
    path_list.sort()
    dirList(gt_dir,path_list)
    saveList(path_list)
    image_ids   = open(os.path.join(data_path, "gt.txt"),'r').read().splitlines() 
   
    train_mIou=[]
    train_mPA=[]
    test_mIou=[]
    test_mPA=[]

    if miou_mode == 0 or miou_mode == 2:
        
        mylog.write(str(weight_name[:-3]))

        print(weight_name +"  Get miou.")

    

        

        print('Computing test mIoU')
        test_mIou,test_mPA,test_miou,test_mpa=compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes,weight_name)  # run the mIoU computation
        mylog.write('  test_mIoU:  '+str(test_miou))
        mylog.write('  test_mPA:  '+str(test_mpa))
        print('  test_mIoU:  '+str(test_miou))
      
      
        
        count=0
        print('Computing per-image IoU for the test samples')
        count=compute_IoU(gt_dir, pred_dir, image_ids, num_classes, lower_iou,higher_iou,weight_name,count)  # run the per-image IoU computation
        mylog.write('  low-iou test picture num:  '+str(count))
        print(weight_name + "Get miou done.")

       

mylog.write('Finish!')
print ('Finish!')
mylog.close()

4. Analyze the results

Change the '75' (the checkpoint epoch) in two places in the script below
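A small optional convenience (not part of the original workflow): keep the epoch number in one variable so it only has to be edited once instead of in the two places mentioned above.

# Hypothetical helper: derive both hard-coded names of the script below from a single EPOCH value.
EPOCH = 75
data_path = './submits/log01_Dink101_five_{}/test_iou/'.format(EPOCH)
excel_name = 'log01_Dink101_five_{}_excel.txt'.format(EPOCH)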

import os
import shutil
data_path='./submits/log01_Dink101_five_75/test_iou/'
data=open(os.path.join(data_path, "log01_Dink101_five_75_excel.txt"),'r').read().splitlines()
valid_path='./dataset/valid/'
real_path='./dataset/real/'

iou_100=os.path.join(data_path,'iou_100/')
iou_80=os.path.join(data_path,'iou_80/')
iou_50=os.path.join(data_path,'iou_50/')
iou_30=os.path.join(data_path,'iou_30/')
if not os.path.exists(iou_100):
    os.mkdir(iou_100)
    os.mkdir(iou_80)
    os.mkdir(iou_50)
    os.mkdir(iou_30)
for n in data:
    name=n.split()[1]
    iou=float(n.split()[2])
    if iou>=80:
        img_path=os.path.join(data_path,'test_pre_img/'+name+'.png')
        
        valid_name=os.path.join(valid_path,name[:-4]+'sat.tif')
        real_name=os.path.join(real_path,name[:-4]+'mask.png')
        
        shutil.copy(img_path,iou_100)
        file_name=os.path.join(iou_100,name+'.png')
        new_name=os.path.join(iou_100,name+'_'+str(iou)+'.png')
        os.rename(file_name,new_name)
        
        shutil.copy(valid_name,iou_100)
        shutil.copy(real_name,iou_100)
        
        print(name,iou)
        continue
    elif iou>=50:
        img_path=os.path.join(data_path,'test_pre_img/'+name+'.png')
        valid_name=os.path.join(valid_path,name[:-4]+'sat.tif')
        real_name=os.path.join(real_path,name[:-4]+'mask.png')
        shutil.copy(img_path,iou_80)
        file_name=os.path.join(iou_80,name+'.png')
        new_name=os.path.join(iou_80,name+'_'+str(iou)+'.png')
        os.rename(file_name,new_name)
        shutil.copy(valid_name,iou_80)
        shutil.copy(real_name,iou_80)
        print(name,iou)
        continue
    elif iou>=30:
        img_path=os.path.join(data_path,'test_pre_img/'+name+'.png')
        valid_name=os.path.join(valid_path,name[:-4]+'sat.tif')
        real_name=os.path.join(real_path,name[:-4]+'mask.png')
        shutil.copy(img_path,iou_50)
        file_name=os.path.join(iou_50,name+'.png')
        a=os.path.exists(file_name)
        new_name=os.path.join(iou_50,name+'_'+str(iou)+'.png')
        os.rename(file_name,new_name)
        shutil.copy(valid_name,iou_50)
        shutil.copy(real_name,iou_50)
        print(name,iou)
        continue
    else:
        img_path=os.path.join(data_path,'test_pre_img/'+name+'.png')
        valid_name=os.path.join(valid_path,name[:-4]+'sat.tif')
        real_name=os.path.join(real_path,name[:-4]+'mask.png')
        shutil.copy(img_path,iou_30)
        file_name=os.path.join(iou_30,name+'.png')
        new_name=os.path.join(iou_30,name+'_'+str(iou)+'.png')
        os.rename(file_name,new_name)
        shutil.copy(valid_name,iou_30)
        shutil.copy(real_name,iou_30)
        print(name,iou)
        continue


print('Finish')

Delete the tiles with IoU below 30; move the rest to real and mask
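This step is only described in prose; below is a loudly hedged sketch of one possible reading of its first half (assumption: tiles binned into iou_30/, i.e. IoU below 30, are removed from ./dataset/valid/ and ./dataset/real/; where exactly "the rest" should be moved is left as in the original note).

import os

# Assumption: drop every sample whose prediction landed in the iou_30 bin from the evaluation set.
iou_30 = './submits/log01_Dink101_five_75/test_iou/iou_30/'
for f in os.listdir(iou_30):
    if f.endswith('_mask.png'):
        for victim in ('./dataset/real/' + f, './dataset/valid/' + f[:-8] + 'sat.tif'):
            if os.path.exists(victim):
                os.remove(victim)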

Delete tiles whose mask has fewer than 10,000 white pixels


import cv2
import os

# The folder path of input and output
outPath = './submits/log01_Dink101_five_75/lower_iou/'
delPath='./dataset/valid/'
# outPath = './dataset/real_train/'
# delPath='./dataset/valid_train/'
#outPath = "C:/Users/Administrator/Desktop/t/"
for f in os.listdir(outPath):
    path = outPath + f.strip()
    if not os.path.exists(path):
        continue
    img = cv2.imread(path,0)
    whi = len(img[img==255])
    if whi < 10000:
        print("Fewer than 10000 white pixels, removing:")
        print(path)
        os.remove(path)
        path2=delPath +f.strip()[:-8] + "sat.tif"
        path3='./dataset/real/'+f.strip()
        if not os.path.exists(path2):
            print(path2 + " does not exist")
            continue
        os.remove(path2)
        os.remove(path3)

DSM analysis

import os
import shutil
data_path='/mnt/sdb1/fenghaixia/four/submits/log01_Dink101_five_100/test_iou/'
data=open(os.path.join(data_path, "log01_Dink101_five_100_excel.txt"),'r').read().splitlines()
valid_path='/mnt/sdb1/fenghaixia/four/dataset/valid_all/'
real_path='/mnt/sdb1/fenghaixia/four/dataset/real512/'
 
iou_100=os.path.join(data_path,'iou_100/')
iou_80=os.path.join(data_path,'iou_80/')
iou_50=os.path.join(data_path,'iou_50/')
iou_30=os.path.join(data_path,'iou_30/')
if not os.path.exists(iou_100):
    os.mkdir(iou_100)
    os.mkdir(iou_80)
    os.mkdir(iou_50)
    os.mkdir(iou_30)
for n in data:
    name=n.split()[1]
    iou=float(n.split()[2])
    if iou>=80:
        img_path=os.path.join(data_path,'test_pre_img/'+name+'.png')
        
        valid_name=os.path.join(valid_path,name[:-4]+'sat.tif')
        real_name=os.path.join(real_path,name[:-4]+'mask.png')
        
        shutil.copy(img_path,iou_100)
        file_name=os.path.join(iou_100,name+'.png')
        new_name=os.path.join(iou_100,name+'_'+str(iou)+'.png')
        os.rename(file_name,new_name)
        
        shutil.copy(valid_name,iou_100)
        shutil.copy(real_name,iou_100)
        
        print(name,iou)
        continue
    elif iou>=50:
        img_path=os.path.join(data_path,'test_pre_img/'+name+'.png')
        valid_name=os.path.join(valid_path,name[:-4]+'sat.tif')
        real_name=os.path.join(real_path,name[:-4]+'mask.png')
        shutil.copy(img_path,iou_80)
        file_name=os.path.join(iou_80,name+'.png')
        new_name=os.path.join(iou_80,name+'_'+str(iou)+'.png')
        os.rename(file_name,new_name)
        shutil.copy(valid_name,iou_80)
        shutil.copy(real_name,iou_80)
        print(name,iou)
        continue
    elif iou>=30:
        img_path=os.path.join(data_path,'test_pre_img/'+name+'.png')
        valid_name=os.path.join(valid_path,name[:-4]+'sat.tif')
        real_name=os.path.join(real_path,name[:-4]+'mask.png')
        shutil.copy(img_path,iou_50)
        file_name=os.path.join(iou_50,name+'.png')
        a=os.path.exists(file_name)
        new_name=os.path.join(iou_50,name+'_'+str(iou)+'.png')
        os.rename(file_name,new_name)
        shutil.copy(valid_name,iou_50)
        shutil.copy(real_name,iou_50)
        print(name,iou)
        continue
    else:
        img_path=os.path.join(data_path,'test_pre_img/'+name+'.png')
        valid_name=os.path.join(valid_path,name[:-4]+'sat.tif')
        real_name=os.path.join(real_path,name[:-4]+'mask.png')
        shutil.copy(img_path,iou_30)
        file_name=os.path.join(iou_30,name+'.png')
        new_name=os.path.join(iou_30,name+'_'+str(iou)+'.png')
        os.rename(file_name,new_name)
        shutil.copy(valid_name,iou_30)
        shutil.copy(real_name,iou_30)
        print(name,iou)
        continue
 
 
print('Finish')

Package into an executable with PyInstaller:

pyinstaller -D  -w act_deepl.py


Reposted from blog.csdn.net/weixin_61235989/article/details/129256925