监督分类:机器学习实现不带坐标系图像即普通图片的多分类

1. 安装 labelme 用于标注。对机器学习来说,标注就是选取像素样本。这一步对监督分类来说是无法避免的(深度学习同样需要画样本),只不过这里选取样本比较随意,一点都不麻烦。
pip install labelme
安装好后直接在命令行输入labelme按enter,接着工具就会弹出来了
labelme

2.获取样本
样本

放大看:
样本可以取任意形状,只要每个类别的区域内不包含其它类别即可。这里我分了两个类别。请注意:在弹出的类别名称输入框中,务必按 0, 1, 2, 3… 的顺序从 0 开始命名类别,这是为了方便后面的处理,不照做会报错。
放大看
样本选取好后,点击那个保存按钮,图像目录下会自动生成一个.json文件。
产生文件

3.标签转换
参考链接:https://zhuanlan.zhihu.com/p/116023772

import os
import subprocess

# Folder holding the images and their labelme .json files (the folder shown
# in the screenshot above).
json_folder = r"C:\Users\Administrator\Desktop\data\test"

# NOTE: the original called os.system("activate labelme") here. Every
# os.system() call runs its command in a separate subshell, so that activation
# never reached the conversion commands below. Run this script from inside the
# labelme environment instead (labelme_json_to_dataset must be on PATH).
for file_name in os.listdir(json_folder):
    # Only labelme annotation files are converted.
    if os.path.splitext(file_name)[1] != ".json":
        continue
    json_file = os.path.join(json_folder, file_name)
    # Convert the .json annotation with labelme's own tool. Passing an
    # argument list (no shell) keeps paths containing spaces intact.
    result = subprocess.run(["labelme_json_to_dataset", json_file])
    if result.returncode != 0:
        print("labelme_json_to_dataset failed for %s" % json_file)

运行以后会产生一个文件夹
转换
里面有需要用到的一些图片
用到

这一步利用 labelme 自带的转换工具生成了所需的数据(图像、标签等),后面的分类也是在转换结果的基础上进行的。如果担心转换后的图像有质量损失,可以在配置文件中把需要预测的图换成原图(看完代码就能明白)。不过我对比过,效果差不多,不必另外改代码。

4.训练预测代码
这里通过配置文件读取各个路径,配置文件的内容在代码后面给出。

# -*- coding: utf-8 -*-
from osgeo import ogr, osr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import copy
from tqdm import tqdm
import numpy as np
import cv2
from PIL import Image
from collections import Counter
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from skimage import morphology, filters
import pickle
# import numba
# from numba import jit

import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian


def getValues(img_path, mask_path, save_debug=True):
    """Collect the RGB values of the labelled pixels of a sample image, per class.

    Args:
        img_path: sample image, read with OpenCV.
        mask_path: label image (e.g. label.png from labelme_json_to_dataset).
            It is converted to palette mode and each palette index is used as
            a class value; value 0 (background) is skipped.
        save_debug: when True (the original behaviour), write ``<k>.png`` into
            the current working directory showing only the pixels of class k.

    Returns:
        dict mapping each non-zero mask value k to an (n_k, 3) int array of
        RGB triples. Pixels are listed column by column (x outer, y inner),
        the same order as the original per-pixel loop. Pure black (0, 0, 0)
        pixels are excluded, as before.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
        ValueError: if the image and the mask have different sizes.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError('cannot read image: %s' % img_path)
    with Image.open(mask_path) as mask_img:
        mask = np.array(mask_img.convert('P'))
    if mask.shape != img.shape[:2]:
        raise ValueError('image size %s and mask size %s differ'
                         % (img.shape[:2], mask.shape))

    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # (x, y, channel) view: boolean selection then yields column-major order,
    # matching the original getpixel((x, y)) loop with x outer.
    rgb_xy = rgb.transpose(1, 0, 2)

    lab_dict = Counter(mask.flatten())
    # 0 is background; pop() instead of del so a mask without any background
    # pixel no longer raises KeyError.
    lab_dict.pop(0, None)

    class_list = {}
    for k in lab_dict:
        selected = mask == k
        if save_debug:
            # Keep BGR order here: cv2.imwrite expects BGR (the original wrote
            # RGB data, so the debug images had swapped colours).
            masked = img.copy()
            masked[~selected] = 0
            cv2.imwrite(str(k) + '.png', masked)
        pixels = rgb_xy[selected.T]
        # Black pixels were indistinguishable from "outside the mask" in the
        # original loop and were dropped; keep that behaviour.
        pixels = pixels[pixels.any(axis=1)]
        class_list[k] = pixels.astype(int)

    return class_list


def svm_train(class_list, img_arr, model_path):
    """Train a random-forest pixel classifier, or reuse a cached one.

    Args:
        class_list: dict mapping a mask value to an (n_i, 3) array of RGB
            samples, as returned by getValues.
        img_arr: unused; kept so existing callers keep working.
        model_path: pickle file for the model. If it already exists the model
            is NOT retrained: the cached file is reused even if the samples
            changed. Delete it to force retraining.

    Returns:
        (array_num, class_final): the number of classes, and a dict mapping
        each mask value to the contiguous training label (0..array_num-1) it
        was given, assigned in ascending order of mask value.

    Raises:
        ValueError: if training is needed but ``class_list`` is empty.
    """
    ordered_keys = sorted(class_list)
    class_final = {key: idx for idx, key in enumerate(ordered_keys)}

    if not os.path.exists(model_path):
        if not ordered_keys:
            raise ValueError('no labelled samples to train on')
        # Stack all samples once instead of growing an array that starts with
        # a dummy row which is deleted afterwards.
        RGB_arr = np.concatenate([class_list[key] for key in ordered_keys], axis=0)
        label = np.concatenate([np.full(len(class_list[key]), float(idx))
                                for idx, key in enumerate(ordered_keys)])
        rf = RandomForestClassifier(n_estimators=500, max_depth=10, n_jobs=14)
        rf.fit(RGB_arr, label)
        with open(model_path, 'wb') as f:
            pickle.dump(rf, f)

    return len(class_list), class_final

def get_model(model_path):
    """Unpickle and return the classifier stored at ``model_path``.

    pickle can execute arbitrary code while loading, so only open model files
    written by this script itself (see svm_train).
    """
    with open(model_path, 'rb') as model_file:
        return pickle.load(model_file)

def svm_predict(svc, img_arr, array_num, outPath):
    """Classify every pixel of a tile and return the class map.

    Args:
        svc: fitted classifier exposing ``predict`` (one row per pixel, one
            column per band).
        img_arr: tile laid out as (columns, rows, bands), i.e. the GDAL
            ReadAsArray result transposed with (2, 1, 0) as done in
            partDivisionForBoundary. Any number of bands is accepted.
        array_num: number of classes; predictions equal to 0..array_num-1
            are written into the result.
        outPath: unused; kept so existing callers keep working.

    Returns:
        (rows, columns) array with img_arr's dtype holding the predicted class
        index per pixel. A pixel whose prediction is outside 0..array_num-1
        keeps its first-band value, as in the original implementation.
        ``img_arr`` itself is no longer modified.
    """
    n_cols, n_rows, n_bands = img_arr.shape
    predict = svc.predict(img_arr.reshape(n_cols * n_rows, n_bands))
    # Back to (rows, columns), the layout GDAL's WriteArray expects.
    pred_map = np.asarray(predict).reshape(n_cols, n_rows).T
    class_map = img_arr[:, :, 0].T.copy()
    for j in range(array_num):
        # np.float was removed in NumPy 1.24; the builtin float is identical.
        class_map[pred_map == float(j)] = j
    return class_map


def read_img(filename):
    """Read a whole raster with GDAL.

    Returns:
        (projection, geotransform, width, height, data), where data is the
        ReadAsArray result for the full extent.

    Raises:
        OSError: if GDAL cannot open ``filename`` (gdal.Open returned None,
            which previously surfaced as an obscure AttributeError).
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        raise OSError('GDAL cannot open raster: %s' % filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    dataset = None  # release the GDAL handle
    return im_proj, im_geotrans, im_width, im_height, im_data


def _gdal_datatype(im_data):
    """Pick the GDAL pixel type for ``im_data``.

    uint8/int8 -> Byte, int16 -> Int16, uint16 -> UInt16, anything else ->
    Float32. The original mapped signed int16 to UInt16, so negative values
    wrapped around.
    """
    name = im_data.dtype.name
    if 'int8' in name:
        return gdal.GDT_Byte
    if name == 'int16':
        return gdal.GDT_Int16
    if name == 'uint16':
        return gdal.GDT_UInt16
    return gdal.GDT_Float32


def _write_raster(filename, im_proj, im_geotrans, im_data, datatype):
    """Write a 2-D (single band) or 3-D (bands, rows, cols) array as a GeoTIFF."""
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise OSError('GDAL cannot create raster: %s' % filename)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        # A (1, rows, cols) array must be passed to WriteArray as 2-D.
        dataset.GetRasterBand(1).WriteArray(im_data if im_data.ndim == 2 else im_data[0])
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    dataset = None  # close the dataset so it is flushed to disk


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write ``im_data`` as a GeoTIFF, choosing the pixel type from its dtype.

    Args:
        filename: output .tif path.
        im_proj: projection string to store.
        im_geotrans: geotransform to store.
        im_data: 2-D (rows, cols) or 3-D (bands, rows, cols) array.
    """
    _write_raster(filename, im_proj, im_geotrans, im_data, _gdal_datatype(im_data))


def write_img_(filename, im_proj, im_geotrans, im_data):
    """Like write_img, but always stores the pixels as 8-bit (GDT_Byte)."""
    _write_raster(filename, im_proj, im_geotrans, im_data, gdal.GDT_Byte)



def array_change(inlist, outlist):
    """Append the columns of the 2-D list ``inlist`` to ``outlist`` as rows.

    This is a transpose. The column count comes from the first row;
    ``outlist`` is mutated and also returned.
    """
    column_count = len(inlist[0])
    outlist.extend([row[col] for row in inlist] for col in range(column_count))
    return outlist

def array_change2(inlist, outlist):
    """Flatten one level of nesting into ``outlist``.

    Every element of every sub-sequence of ``inlist`` is appended in order;
    ``outlist`` is mutated and also returned.
    """
    outlist.extend(item for group in inlist for item in group)
    return outlist

def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch ``bands`` into the range [img_min, img_max].

    The value at the ``lower_percent`` percentile maps to img_min and the
    value at the ``higher_percent`` percentile maps to img_max; the result is
    clipped to [img_min, img_max].

    If the two percentiles are equal (e.g. a constant band), the linear map
    is undefined. The original divided by zero and returned NaN. Here, pixels
    at or below that value map to img_min and pixels above it map to img_max.

    Returns:
        float32 array with the same shape as ``bands``.
    """
    a, b = img_min, img_max
    values = np.asarray(bands)
    c = np.percentile(values, lower_percent)
    d = np.percentile(values, higher_percent)
    if d == c:
        t = np.where(values > d, b, a)
    else:
        t = a + (values - c) * (b - a) / (d - c)
    return np.clip(t, a, b).astype(np.float32)

def getTifSize(tif):
    """Return basic metadata of a raster readable by GDAL.

    Returns:
        (width, height, band_count, geotransform, projection)

    Raises:
        OSError: if GDAL cannot open ``tif`` (gdal.Open returned None).
    """
    dataSet = gdal.Open(tif)
    if dataSet is None:
        raise OSError('GDAL cannot open raster: %s' % tif)
    width = dataSet.RasterXSize
    height = dataSet.RasterYSize
    bands = dataSet.RasterCount
    geoTrans = dataSet.GetGeoTransform()
    proj = dataSet.GetProjection()
    dataSet = None  # release the GDAL handle
    return width, height, bands, geoTrans, proj


def _tile_count(length, tile_size):
    """Number of tiles of ``tile_size`` pixels needed to cover ``length`` pixels."""
    return (length + tile_size - 1) // tile_size


def _tile_window(index, tile_size, length):
    """(start, size) of tile ``index`` along one axis; the last tile is clipped to ``length``."""
    start = index * tile_size
    return start, min(tile_size, length - start)


def _tile_path(temp_dir, raster_path, row, col):
    """Path of temporary tile (row, col) of ``raster_path`` inside ``temp_dir``.

    Row and column are separated by underscores. The original name
    ``name + str(i) + str(j)`` made tiles such as (1, 11) and (11, 1) share a
    file, so one was skipped as "already done" and its neighbour's data was
    stitched into both places.
    """
    stem = os.path.splitext(os.path.basename(raster_path))[0]
    return os.path.join(temp_dir, '%s_%d_%d.tif' % (stem, row, col))


def _tile_geotransform(geo_trans, start_x, start_y):
    """Geotransform of a tile whose top-left pixel is (start_x, start_y) in the source."""
    x0, dx, rx, y0, ry, dy = geo_trans
    return (x0 + start_x * dx + start_y * rx, dx, rx,
            y0 + start_x * ry + start_y * dy, ry, dy)


def partDivisionForBoundary(model, array_num, tif1, divisionSize, tempPath):
    """Classify ``tif1`` tile by tile, saving each prediction as a GeoTIFF in ``tempPath``.

    Tiles are divisionSize x divisionSize pixels (edge tiles are smaller).
    A tile whose file already exists is skipped, so an interrupted run can be
    resumed. Stale tiles left over from a different model or image with the
    same name are reused too, so clear ``tempPath`` when the inputs change.

    Returns:
        1 (kept for compatibility).
    """
    width, height, bands, geoTrans, proj = getTifSize(tif1)
    heightNum = _tile_count(height, divisionSize)
    widthNum = _tile_count(width, divisionSize)

    src = gdal.Open(tif1)
    driver = gdal.GetDriverByName("GTiff")
    for i in tqdm(range(heightNum), desc='Processing'):
        startY, realPartHeight = _tile_window(i, divisionSize, height)
        for j in range(widthNum):
            startX, realPartWidth = _tile_window(j, divisionSize, width)
            outPath = _tile_path(tempPath, tif1, i, j)
            if os.path.exists(outPath):
                continue

            # ReadAsArray gives (bands, rows, cols); svm_predict expects (cols, rows, bands).
            data1 = src.ReadAsArray(startX, startY, realPartWidth, realPartHeight)
            svmData = svm_predict(model, data1.transpose((2, 1, 0)), array_num, outPath)

            # Create the tile only after prediction succeeded, so a crash does
            # not leave an empty file that a resumed run would skip.
            outTif = driver.Create(outPath, realPartWidth, realPartHeight, 1, gdal.GDT_Float32)
            if outTif is None:
                raise OSError('GDAL cannot create tile: %s' % outPath)
            outTif.SetGeoTransform(_tile_geotransform(geoTrans, startX, startY))
            outTif.SetProjection(proj)
            outTif.GetRasterBand(1).WriteArray(svmData)
            outTif = None  # close the tile so it is flushed to disk
    src = None
    return 1


def partStretch(tif1, divisionSize, outStratchPath, tempPath):
    """Mosaic the tiles written by partDivisionForBoundary into one GeoTIFF.

    The output has the size, geotransform and projection of ``tif1`` and a
    single Byte band.

    Raises:
        OSError: if the output cannot be created or a tile is missing.
    """
    width, height, bands, geoTrans, proj = getTifSize(tif1)
    heightNum = _tile_count(height, divisionSize)
    widthNum = _tile_count(width, divisionSize)

    driver = gdal.GetDriverByName("GTiff")
    outTif = driver.Create(outStratchPath, width, height, 1, gdal.GDT_Byte)
    if outTif is None:
        raise OSError('GDAL cannot create raster: %s' % outStratchPath)
    outTif.SetGeoTransform(geoTrans)
    outTif.SetProjection(proj)
    outBand = outTif.GetRasterBand(1)

    for i in range(heightNum):
        startY, realPartHeight = _tile_window(i, divisionSize, height)
        for j in range(widthNum):
            startX, realPartWidth = _tile_window(j, divisionSize, width)
            partTifPath = _tile_path(tempPath, tif1, i, j)
            divisionImg = gdal.Open(partTifPath)
            if divisionImg is None:
                raise OSError('missing tile %s (run partDivisionForBoundary first)' % partTifPath)
            data1 = divisionImg.GetRasterBand(1).ReadAsArray(0, 0, realPartWidth, realPartHeight)
            outBand.WriteArray(data1, startX, startY)
            divisionImg = None

    # Release the band before its dataset, then close the dataset to flush it.
    outBand = None
    outTif = None

def DoesDriverHandleExtension(drv, ext):
    """Return True if GDAL driver ``drv`` lists ``ext`` among its file extensions.

    DMD_EXTENSIONS is a space-separated list of extensions. Whole tokens are
    compared so that a short extension (e.g. "son") no longer matches as a
    mere substring of another one ("json"), as the original ``find`` did.
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    return exts is not None and ext.lower() in exts.lower().split()


def GetExtension(filename):
    """Return the extension of ``filename`` without its leading dot ('' if none)."""
    suffix = os.path.splitext(filename)[1]
    return suffix[1:] if suffix.startswith('.') else suffix


def GetOutputDriversFor(filename):
    """List the short names of GDAL vector drivers that can write ``filename``.

    A driver qualifies when it supports Create or CreateCopy, is a vector
    driver, and either claims the file's extension or, failing that, has a
    connection prefix that ``filename`` starts with (case-insensitive).
    """
    matches = []
    ext = GetExtension(filename)
    for idx in range(gdal.GetDriverCount()):
        driver = gdal.GetDriver(idx)
        can_write = (driver.GetMetadataItem(gdal.DCAP_CREATE) is not None
                     or driver.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_write or driver.GetMetadataItem(gdal.DCAP_VECTOR) is None:
            continue
        if ext and DoesDriverHandleExtension(driver, ext):
            matches.append(driver.ShortName)
            continue
        prefix = driver.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
        if prefix is not None and filename.lower().startswith(prefix.lower()):
            matches.append(driver.ShortName)
    return matches

def GetOutputDriverFor(filename):
    """Pick one output vector driver for ``filename``.

    Falls back to 'ESRI Shapefile' when the name has no extension. Raises
    Exception when there is an extension but no driver matches it. When
    several drivers match, a notice is printed and the first one is used.
    """
    candidates = GetOutputDriversFor(filename)
    ext = GetExtension(filename)
    if candidates:
        if len(candidates) > 1:
            print("Several drivers matching %s extension. Using %s" % (ext if ext else '', candidates[0]))
        return candidates[0]
    if ext:
        raise Exception("Cannot guess driver for %s" % filename)
    return 'ESRI Shapefile'


def crf(inimage, img_anno, use_2d=False, n_iter=20):
    """Refine a per-pixel class map with a fully connected CRF (pydensecrf).

    Args:
        inimage: the original image, indexed as (dim0, dim1, channel); its
            first two dimensions must match ``img_anno``.
        img_anno: class map, e.g. 0, 1, 2, 3 (one value per class). Every
            value, including 0, is a real class; there is no background.
        use_2d: use the DenseCRF2D API instead of the default generic
            DenseCRF with explicit feature maps (the author found the default
            works better).
        n_iter: number of mean-field inference iterations.

    Returns:
        An array shaped like inimage's first two dimensions, holding refined
        class VALUES. The original returned the indices from np.unique, which
        differ from the values whenever a class is absent (e.g. {0, 2} came
        back as {0, 1}). If ``img_anno`` has at most one distinct value it is
        returned unchanged.
    """
    # CRF labels must be contiguous 0..n_labels-1: work on np.unique's inverse
    # indices and map them back to the class values at the end.
    values, labels = np.unique(img_anno, return_inverse=True)
    labels = labels.ravel()
    n_labels = len(values)
    if n_labels <= 1:
        # A single class leaves nothing to refine.
        return img_anno

    img = inimage
    if use_2d:
        d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], n_labels)
        U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=False)
        d.setUnaryEnergy(U)
        d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL,
                              normalization=dcrf.NORMALIZE_SYMMETRIC)
        # The bilateral term takes a C-contiguous uint8 image.
        rgbim = np.ascontiguousarray(np.asarray(img).astype(int).astype(np.uint8))
        d.addPairwiseBilateral(sxy=(80, 80), srgb=(13, 13, 13), rgbim=rgbim,
                               compat=10,
                               kernel=dcrf.CONST_KERNEL,
                               normalization=dcrf.NORMALIZE_SYMMETRIC)
    else:
        d = dcrf.DenseCRF(img.shape[1] * img.shape[0], n_labels)
        # zero_unsure=False: label 0 is a real class, not "unknown".
        U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=False)
        d.setUnaryEnergy(U)

        # Colour-independent smoothness term.
        feats = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2])
        d.addPairwiseEnergy(feats, compat=3,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)

        # Colour-dependent (appearance) term.
        feats = create_pairwise_bilateral(sdims=(80, 80), schan=(13, 13, 13),
                                          img=img, chdim=2)
        d.addPairwiseEnergy(feats, compat=10,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)

    Q = d.inference(n_iter)
    # Most probable label index per pixel, mapped back to the class values.
    MAP = np.argmax(Q, axis=0)
    return values[MAP].reshape(img.shape[:2])

def remove_and_deal(img_array, hole, obj):
    """Clean a binary class mask: fill small holes, then drop small blobs.

    Args:
        img_array: 2-D mask; any non-zero value counts as foreground.
        hole: background holes smaller than this many pixels are filled.
        obj: foreground components smaller than this many pixels are removed.

    Returns:
        uint8 array of 0/1 with the same shape as ``img_array``.
    """
    binary = img_array.astype(bool)
    # remove_small_holes takes ``area_threshold``; its old ``min_size`` keyword
    # has been removed from scikit-image. For a 2-D mask any connectivity >= 2
    # means 8-neighbour connectivity, so the original value 8 behaved like 2.
    binary = morphology.remove_small_holes(binary, area_threshold=hole, connectivity=2)
    binary = morphology.remove_small_objects(binary, min_size=obj, connectivity=2)
    return binary.astype(np.uint8)

def cls_deal(class_path, n_classes=4, out_dir=None, hole=2000, obj=500):
    """Split a class raster into one cleaned 0/1 GeoTIFF per class.

    For each class value k in 0..n_classes-1: pixels equal to k become 1 and
    all others 0. The mask is cleaned with remove_and_deal(hole, obj) and
    written to ``<out_dir>/<k>.tif`` with the source georeferencing.

    Args:
        class_path: class raster with values 0..n_classes-1.
        n_classes: number of classes to export (the original hard-coded 4).
        out_dir: output folder. Defaults to the module-level ``temp_path``
            set by the __main__ block, as in the original.
        hole, obj: size thresholds passed on to remove_and_deal.

    Returns:
        List of the written file paths, in class order.
    """
    if out_dir is None:
        out_dir = temp_path
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(class_path)

    written = []
    for k in range(n_classes):
        # The original trick (k -> 100, zero values < 10, 100 -> 1) also left
        # any value >= 10 set; an exact comparison is what was intended.
        binary = (im_data == k).astype(np.uint8)
        binary = remove_and_deal(binary, hole, obj)
        out_path = os.path.join(out_dir, '%d.tif' % k)
        write_img(out_path, im_proj, im_geotrans, binary)
        written.append(out_path)
    return written


if __name__ == '__main__':
    config_file = 'config_order.txt'
    # Each non-empty line of the config holds one path, optionally followed
    # by a '#' comment. Order: sample image, sample label mask, image to
    # classify, result path (currently unused), working folder.
    dirs = []
    with open(config_file) as cfg:
        for line in cfg:
            entry = line.split('#', 1)[0].strip()
            if entry:
                dirs.append(entry)
    if len(dirs) < 5:
        sys.exit('%s must list 5 paths, found %d' % (config_file, len(dirs)))

    data_image, mask_path, task_image, result_path, temp_path = (
        p.replace('\\', '/') for p in dirs[:5])
    # The model pickle, tiles and rasters all go into temp_path.
    os.makedirs(temp_path, exist_ok=True)

    time1 = time.time()
    print('Start ...')

    class_list = getValues(data_image, mask_path)

    print('Train model ...')
    model_path = os.path.join(temp_path, 'model.pickle')
    num, class_final = svm_train(class_list, data_image, model_path)
    svm = get_model(model_path)

    slice_path = os.path.join(temp_path, 'slice_temp')
    os.makedirs(slice_path, exist_ok=True)

    print('Predict task area ...')
    partDivisionForBoundary(svm, num, task_image, 1000, slice_path)
    raster_path = os.path.join(temp_path, 'class_raster.tif')
    partStretch(task_image, 1000, raster_path, slice_path)

    # Refine the stitched prediction with a dense CRF over the whole image.
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(raster_path)
    im_proj, im_geotrans, im_width, im_height, im_data2 = read_img(task_image)
    im_data = im_data.transpose((1, 0))
    im_data2 = im_data2.transpose((2, 1, 0))
    crf_deal = crf(im_data2, im_data)
    crf_deal = crf_deal.transpose((1, 0))
    raster_path = os.path.join(temp_path, 'class_raster2.tif')
    write_img_(raster_path, im_proj, im_geotrans, crf_deal)

    time2 = time.time()
    print((time2 - time1) / 3600)  # elapsed time in hours

配置文件名字config_order.txt,内容如下:

C:\Users\Administrator\Desktop\data\test2\fixed_json\img.png   #样本图
C:\Users\Administrator\Desktop\data\test2\fixed_json\label.png  #样本图上选取的标签,上面已经生成了
C:\Users\Administrator\Desktop\data\test2\fixed_json\img.png   #需要预测的图,考虑到可能图很大,所以选取样本的图可以是从大图上裁剪下来的,而预测的图可以是别的,反正模型是根据样本图训练生成的,这个要注意
C:\Users\Administrator\Desktop\data\test2\temp/t.png   # 这个路径原本用于存放结果,目前代码中没有用到,但这一行必须保留(配置按行的顺序读取),需要的话自己改代码
C:\Users\Administrator\Desktop\data\test2\temp    #新建一个中间文件夹,模型、切片和结果都放在这里

运行结果:
结果
注:这里用的是随机森林方法;开头的 import 中也导入了 SVM(SVC),需要的话自行替换即可。

猜你喜欢

转载自blog.csdn.net/qq_20373723/article/details/113534829