Minería de datos Java - Implementación del algoritmo KNN

1. Conocimiento previo del algoritmo KNN

k-Vecino más cercano (kNN, k-NearestNeighbor) es seleccionar los k vecinos más cercanos desde el punto de datos de entrada en el conjunto de entrenamiento, y usar la categoría con la mayor cantidad de ocurrencias (regla de votación máxima) entre los k vecinos como la categoría del punto de datos.

La clasificación es una tarea muy importante en la minería de datos. El propósito de la clasificación es aprender una función de clasificación o un modelo de clasificación (también conocido como clasificador), que puede asignar elementos de datos en la base de datos a una determinada categoría en un tipo determinado. La clasificación se puede utilizar para la predicción. El propósito de la previsión es derivar automáticamente una descripción de la tendencia de los datos dados a partir de los registros de datos históricos, de modo que se puedan predecir los datos futuros. Un método de pronóstico de uso común en las estadísticas es la regresión. La clasificación en minería de datos y el método de regresión en estadística son un par de conceptos interrelacionados pero diferentes. Generalmente, el resultado de la clasificación es un valor de categoría discreta, mientras que el resultado de la regresión es un valor continuo.

Similitud : Dada una base de datos D={t1,t2,…,tn} y un conjunto de clases C={C1,C2,…,Cm}. Para cualquier tupla ti={ti1,ti2,…,tik}∈D, si existe un Cj∈C tal que sim(ti,Cj)≥sim(ti,Cp), existe Cp∈C, Cp≠Cj, Entonces ti se asigna a la clase Cj, donde sim(ti,Cj) se denomina similitud. En el cálculo real, la distancia se usa a menudo para representar, cuanto más cercana es la distancia, mayor es la similitud, y cuanto más lejana es la distancia, menor es la similitud.

Para calcular la similitud, primero necesitamos obtener el vector que representa cada clase. Hay muchos métodos de cálculo, por ejemplo, el vector que representa cada clase se puede representar calculando el centro de cada clase. Además, en el reconocimiento de patrones, se utiliza una imagen predefinida para representar cada clase, y la clasificación consiste en comparar las muestras a clasificar con la imagen predefinida.

2. La idea básica del algoritmo KNN

La idea del algoritmo KNN es relativamente simple. Asumiendo que cada clase contiene múltiples datos de entrenamiento, y cada dato de entrenamiento tiene una etiqueta de categoría única, la idea principal del algoritmo KNN es calcular la distancia entre cada dato de entrenamiento y la tupla a clasificar, y tomar la distancia más cercana a la tupla a clasificar Los k datos de entrenamiento de los k datos, cuya categoría de datos de entrenamiento es mayoritaria entre los k datos, la tupla a clasificar pertenece a qué categoría.

3. Ejemplos de algoritmo KNN y minería de reglas de asociación fuerte

Ejemplo de algoritmo KNN
inserte la descripción de la imagen aquí
inserte la descripción de la imagen aquí

4. Proceso de implementación del algoritmo KNN

Contenido del experimento
Hay 14 estudiantes en una clase cuya altura y grado han sido registrados. El nuevo estudiante Yi Chang mide 174 cm de alto y cuál es su grado. Utilice el algoritmo knn para el reconocimiento de clasificación, donde k=5.
inserte la descripción de la imagen aquí

Ideas para experimentos
(1) Defina la clase de estudiante Student, defina atributos como el nombre, la altura y la calificación en la clase de estudiante, y use la anotación @Data en la que se basa lombok para inyectar los métodos get y set de la clase Student. Defina el conjunto de datos inicial, defina 14 clases de estudiantes de entidades y agregue las 14 clases de estudiantes de entidades al conjunto de datos inicial dataList.
(2) Llamar al método initData() para inicializar el conjunto de datos, definir una clase de Estudiante stuV0 e instanciar su nombre y altura como entrada, llamar al método Knn() para obtener el objeto de la clase de Estudiante estudiante con calificaciones y llevar a cabo el objeto salida del estudiante.
(3) Dentro del cuerpo del método Knn(), los primeros 5 elementos del conjunto de datos se agregan inicialmente a la colección categoryList. La colección categoryList se usa para almacenar los k estudiantes más cercanos a stuV0, y solo los primeros 5 elementos de los datos conjunto se almacenan inicialmente. Recorra el conjunto de datos dataList, calcule la distancia v0Tod entre stuV0 y cada elemento a partir del elemento 6 del conjunto de datos y llame al método getCalculate() para calcular la distancia entre stuV0 y el objeto de estudiante stuU en la colección categoryList, si stuU es la altura de stuV0 Si la distancia uToV0 es mayor que v0Tod, elimine stuU de la lista de categorías y agregue el elemento del conjunto de datos a la colección de la lista de categorías.
(4) Dentro del cuerpo del método getCalculate(), defina la variable maxHeight para almacenar la distancia más lejana entre stuV0 y el conjunto de categorías categoryList, y defina el objeto de clase Student resultStu para almacenar los estudiantes que se devolverán, es decir, los más lejanos. distancia entre stuV0 y el conjunto de categorías categoryList estudiante. Recorra la colección categoryList, si la distancia entre stuU y stuV0 es mayor que maxHeight, asigne v0ToU a maxHeight, asigne stuU a resultStu y finalmente devuelva el objeto de clase Student resultStu.
(5) Llame al método getCategoryStudent() para encontrar el rango del estudiante con la mayor proporción de la misma calificación en la lista de categorías y, finalmente, cree una instancia del atributo de rango de stuV0 con rank y devuelva stuV0.
(6) En el cuerpo del método getCategoryStudent(), recorra la lista de categorías para encontrar las calificaciones de los estudiantes con la mayor proporción de la misma calificación. De hecho, es para encontrar a los estudiantes con la calificación más alta, la calificación media y la nota corta, qué categoría tiene el mayor número de estudiantes y devolver la categoría con el mayor número de estudiantes.

Darse cuenta del código fuente

学生类
package com.data.mining.entity;

import lombok.Data;

@Data
public class Student {
    
    
    private String name;
    private int height;
    private String rank;

    public Student(){
    
    }

    public Student(String n, int h){
    
    
        name = n;
        height = h;
    }

    public Student(String n, int h, String r){
    
    
        name = n;
        height = h;
        rank = r;
    }
}

KNN算法实现代码
package com.data.mining.main;

import com.data.mining.entity.Student;

import java.util.ArrayList;
import java.util.List;

public class Knn {
    
    
    //定义初始数据集
    public static List<Student> dataList = new ArrayList<>();

    public static void main(String[] args) {
    
    
        initData();
        Student stuV0 = new Student("易昌", 174);
        Student student = Knn(stuV0);
        System.out.println(student.toString());
    }

    /**
     * 找出同等级占比最多的学生等级
     * @param categoryList
     * @return
     */
    public static String getCategoryStudent(List<Student> categoryList){
    
    
        int tallCount = 0;
        int midCount = 0;
        int smallCount = 0;
        for (Student stuU : categoryList) {
    
    
            if (stuU.getRank().equals("高")) tallCount++;
            else if (stuU.getRank().equals("中等")) midCount++;
            else smallCount++;
        }
        int max = 0;
        max = tallCount > midCount ? tallCount : midCount;
        max = smallCount > max ? smallCount : max;
        if (smallCount == max) return "矮";
        else if (tallCount == max) return "高";
        else return "中等";
    }

    /**
     * 计算出stuV0距离categoryList集合中最远的学生对象
     * @param stuV0
     * @param categoryList
     * @return
     */
    public static Student getCalculate(Student stuV0, List<Student> categoryList) {
    
    
        int maxHeight = 0; //存放stuV0与类别集合categoryList的最远距离
        Student resultStu = new Student(); //存放要返回的学生,即stuV0与类别集合categoryList距离最远的学生
        for (Student stuU : categoryList) {
    
    
            int v0ToU = Math.abs(stuV0.getHeight() - stuU.getHeight()); //stuV0与stuU的距离
            if (v0ToU > maxHeight){
    
     //stuV0与stuU的距离大于maxHeight,则对maxHeight和resultStu进行更新
                maxHeight = v0ToU;
                resultStu = stuU;
            }
        }
        return resultStu;
    }

    /**
     * 对输入学生类进行Knn算法实例化该学生的等级后,将该学生返回
     * @param stuV0
     * @return
     */
    public static Student Knn(Student stuV0){
    
    
        List<Student> categoryList = new ArrayList<>(); //存放距离stuV0最近的k个学生,最初存放数据集的前5项
        for (int i = 0; i < dataList.size(); i++) {
    
    
            if (i < 5) categoryList.add(dataList.get(i));
            else {
    
    
                //stuV0距离剩下数据集中某项的距离
                int v0Tod = Math.abs(stuV0.getHeight() - dataList.get(i).getHeight());
                Student stuU =  getCalculate(stuV0, categoryList); //存放stuV0距离类别集合中最远的学生
                int uToV0 = Math.abs(stuU.getHeight() - stuV0.getHeight());
                if (uToV0 > v0Tod){
    
    
                    categoryList.remove(stuU); //在集合列表中去除stuU
                    categoryList.add(dataList.get(i));
                }
            }
        }
        System.out.println(categoryList.toString());
        String rank = getCategoryStudent(categoryList);
        stuV0.setRank(rank);

        return stuV0;
    }


    /**
     * 初始化数据
     */
    public static void initData(){
    
    
        Student s1 = new Student("李丽", 150, "矮");
        Student s2 = new Student("吉米", 192, "高");
        Student s3 = new Student("马大华", 170, "中等");
        Student s4 = new Student("王晓华", 173, "中等");
        Student s5 = new Student("刘敏", 160, "矮");
        Student s6 = new Student("张强", 175, "中等");
        Student s7 = new Student("李秦", 160, "矮");
        Student s8 = new Student("王壮", 190, "高");
        Student s9 = new Student("刘冰", 168, "中等");
        Student s10 = new Student("张喆", 178, "中等");
        Student s11 = new Student("杨毅", 170, "中等");
        Student s12 = new Student("徐田", 168, "中等");
        Student s13 = new Student("高杰", 165, "矮");
        Student s14 = new Student("张晓", 178, "中等");

        dataList.add(s1);
        dataList.add(s2);
        dataList.add(s3);
        dataList.add(s4);
        dataList.add(s5);
        dataList.add(s6);
        dataList.add(s7);
        dataList.add(s8);
        dataList.add(s9);
        dataList.add(s10);
        dataList.add(s11);
        dataList.add(s12);
        dataList.add(s13);
        dataList.add(s14);
    }
}

Resultados experimentales
inserte la descripción de la imagen aquí
Además, además de generar la calificación de altura del estudiante requerida por la pregunta, el autor también genera el grupo donde se encuentra el estudiante de entrada para comparar la pregunta y asegurarse de que el resultado sea correcto.
inserte la descripción de la imagen aquí

5. Resumen experimental

El autor no garantiza que los resultados de este experimento sean correctos, solo proporciona una idea para implementar el algoritmo KNN utilizando el lenguaje Java. Debido a que el experimento no dio una respuesta, el autor ingresó los datos experimentales con la respuesta en Internet en el programa, y ​​el resultado del programa es consistente con la respuesta, por lo que el problema no debería ser grande. Si hay algo que no está bien escrito, ¡por favor dame algunos consejos!
La página de inicio del autor también tiene un resumen de otros algoritmos de minería de datos, ¡bienvenido a patrocinar!

Supongo que te gusta

Origin blog.csdn.net/qq_54162207/article/details/128366181
Recomendado
Clasificación