Python, sklearn, svm, clasificación de datos de detección remota, ejemplos de código

@python, sklearn, svm, clasificación de datos de detección remota, ejemplos de código

Ejemplo de código de clasificación de datos de detección remota de Python_sklearn_svm

(1) Breve introducción del principio svm

Support Vector Machine (SVM) es un término general para una serie de algoritmos de aprendizaje supervisados ​​que incluyen Clasificación, Regresión y Detección de valores atípicos. Para la clasificación, SVM se usó originalmente para resolver el problema de clasificación binaria, y el problema de clasificación múltiple se puede resolver construyendo múltiples clasificadores SVM. SVM tiene dos características principales: 1. Buscar el límite de clasificación óptimo, es decir, resolver el hiperplano de separación que puede dividir correctamente el conjunto de datos de entrenamiento y tiene el intervalo geométrico más grande, que es la idea básica de SVM; 2. La transformación de dimensionalidad basada en la función del núcleo, es decir A través de la transformación de características de la función del núcleo, el conjunto de datos original inseparable lineal se transforma mediante una transformación que aumenta la dimensión para hacerlo linealmente separable. Por lo tanto, el proceso central de SVM es la función del núcleo y la selección de parámetros.

(2) análisis del entorno de implementación de svm

Establezca el formato compatible con el código de salida chino y las funciones de biblioteca referenciadas, las funciones de biblioteca utilizadas para la evaluación de precisión y la optimización de parámetros svm.
Muestra algunos a continuación 内联代码片.

 -*- coding: utf-8 -*-
#用于精度评价
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
#numpy引用
import numpy as np
#记录运行时间
import datetime
#文件路径操作
import os
#svm and best parameter select using grid search method
from sklearn import svm
from sklearn.model_selection import GridSearchCV
#scale the data to 0-1  用于数据归一化
from sklearn import preprocessing

(3) Optimización de los parámetros de la función svm

Hay dos métodos comunes para la realización de la optimización de parámetros SVM, uno es el método de búsqueda de cuadrícula (en este artículo) y el otro es usar la herramienta libsvm para lograr la validación cruzada (más adelante escrito, interesado puede dejar un mensaje)

def grid_find(train_data_x,train_data_y):
     # 10 is often helpful. Using a basis of 2, a finer.tuning can be achieved but at a much higher cost.
     # logspace(a,b,N),base默认=10,把10的a次方到10的b次方区间分成N份。
	C_range = np.logspace(-5, 9, 8, base=2) 
    # 如:C_range = 1/64,1/8,1/2,2,8,32,128,512
    gamma_range = np.logspace(-15, 3, 10, base=2)
    # 选择linear线性核函数和rbf核函数
    parameters = {'kernel': ('linear', 'rbf'), 'C': C_range, 'gamma': gamma_range}
    svr = svm.SVC()
     # n_jobs表示并行运算量,可加快程序运行结果。
     # 此处选择5折交叉验证,10折交叉验证也是常用的。
    clf = GridSearchCV(svr, parameters, cv=5, n_jobs=4)  
    # 进行模型训练
    clf.fit(train_data_x, train_data_y)
    print('最优c,g参数为:{0}'.format(clf.best_params_))
    # 返回最优模型结果
    svm_model = clf.best_estimator_
    return svm_model

Más sobre el método de búsqueda de cuadrícula:

(4) Función de lectura de datos de escritura (lea los archivos de entrenamiento y prueba en formato txt)

La primera es la función de leer los datos de entrenamiento y los datos de prueba en formato txt.
La captura de pantalla de los datos es la siguiente: entre ellas, las primeras 6 columnas de datos representan los valores grises de 6 bandas extraídas del área de interés de la imagen de detección remota (roi), y la última columna representa la etiqueta de la categoría de datos.
Entre ellas, las primeras 6 columnas de datos representan los valores grises de las 6 bandas extraídas del área de interés de la imagen de detección remota (roi), y la última columna representa la etiqueta de la categoría de datos.
El código es el siguiente, solo ingrese la ruta del archivo:

def open_txt_film(filepath):
    # open the film
    if os.path.exists(filepath):
        with open(filepath, mode='r') as f:
            train_data_str = np.loadtxt(f, delimiter=' ')
            print('训练(以及测试)数据的行列数为{}'.format(train_data_str.shape))
            return train_data_str
    else:
        print('输入txt文件路径错误,请重新输入文件路径')

(5) Preparación de la función de predicción del modelo svm

Modelo de entrada y datos de prueba, evaluación de precisión de salida (incluyendo matriz de confusión, precisión de dibujo, etc.).

def model_process(svm_model, test_data_x, test_data_y):
    p_lable = svm_model.predict(test_data_x)
    # 精确度为 生产者精度  召回率为 用户精度
    print('总体精度为 : {}'.format(accuracy_score(test_data_y, p_lable)))
    print('混淆矩阵为 :\n {}'.format(confusion_matrix(test_data_y, p_lable)))
    print('kappa系数为 :\n {}'.format(cohen_kappa_score(test_data_y, p_lable)))
    matric = confusion_matrix(test_data_y, p_lable)
    # output the accuracy of each category。由于类别标签是从1开始的,因此明确数据中最大值,即可知道有多少类
    for category in range(np.max(test_data_y)):
        # add 0.0 to keep the float type of output
        precise = (matric[category, category] + 0.0) / np.sum(matric[category, :])
        recall = (matric[category, category] + 0.0) / np.sum(matric[:, category])
        f1_score = 2 * (precise * recall) / (recall + precise)
        print(
            '类别{}的生产者、制图(recall)精度为{:.4}  用户(precision)精度为{:.4}  F1 score 为{:.4} '.format(category + 1, precise, recall, f1_score))                                 

(6) Escritura de la función principal

La función principal es la principal responsable de: leer datos, preprocesar datos y optimización de parámetros, capacitación y predicción de modelos.
Para diferentes conjuntos de datos, cada vez que use, solo necesita modificar la ruta de los datos de entrenamiento y prueba.

def main():
    # read the train data from txt film
    train_file_path = r'E:\CSDN\data1\train.txt'
    train_data = open_txt_film(train_file_path)
    # read the predict data from txt film
    test_file_path = r'E:\CSDN\data1\test.txt'
    test_data = open_txt_film(test_file_path)
     # data normalization for svm training and testing dataset
    scaler = preprocessing.MinMaxScaler().fit(train_data[:, :-1])
    train_data[:, :-1] = scaler.transform(train_data[:, :-1])
    # keep the same scale of the train data
    test_data[:, :-1] = scaler.transform(test_data[:, :-1])

    # conversion the type of data,and the label's dimension to 1-d
    train_data_y = train_data[:, -1:].astype('int')
    train_data_y = train_data_y.reshape(len(train_data_y))
    train_data_x = train_data[:, :-1]
    # 取出测试数据灰度值和标签值,并将2维标签转为1维
    test_data_x = test_data[:, :-1]
    test_data_y = test_data[:, -1:].astype('int')
    test_data_y = test_data_y.reshape(len(test_data_y))
    model = grid_find(train_data_x,train_data_y)
    # 模型预测
    model_process(model, test_data_x, test_data_y)

(7) Llamar a la función principal

Se han agregado algunas líneas de código aquí para registrar el tiempo de ejecución del programa.

if __name__ == "__main__":
    # remember the beginning time of the program
    start_time = datetime.datetime.now()
    print("start...%s" % start_time)

    main()

    # record the running time of program with the unit of minutes
    end_time = datetime.datetime.now()
    last_time = (end_time - start_time).seconds / 60
    print("The program is last %s" % last_time + " minutes")
    # print("The program is last {} seconds".format(last_time))

(8) Descargar la dirección de datos de entrenamiento y ejemplos de datos de prueba

Los datos están en el repositorio github del autor, un total de dos archivos (train.txt y test.txt).
[Enlace de descarga]: (https://github.com/sunzhihu123/sunzhihu123.github.io)
Haga clic en descargar debajo del almacén, como se muestra en la figura:
descargar github

== Las ventajas de este artículo: solo hay dos entradas, una es la ruta de los datos de entrenamiento y la otra es la ruta de los datos de prueba, fácil de comenzar; y tomando como ejemplo los datos de imágenes de detección remota. Además, github cargará el código fuente en su conjunto ~



Publicado un artículo original · elogiado 7 · visitas 357

Supongo que te gusta

Origin blog.csdn.net/qq_36803951/article/details/105590046
Recomendado
Clasificación