Java programmers take on machine learning, starting with a clustering algorithm

This article is aimed at readers with programming experience. It is a "Hello world!" of machine learning: there is no theory here, so readers who care about theoretical rigor may want to look elsewhere.

Foreword

    Artificial intelligence is undoubtedly one of the hottest technical topics of recent years. AI technology, represented by machine learning, has gradually penetrated every aspect of our lives, and anything touched by machine learning instantly seems more impressive. As programmers riding this wave of technology, we are so close to machine learning, and yet

        "Only in this mountain, the depths of clouds do not know where".

    Why use Java/Kotlin?

    It is undeniable that Python is the mainstream language of machine learning, but in my own machine learning projects, Python has proved suitable for algorithm research, while its stability and ecosystem struggle to support large-scale applications. With the growing popularity of Spark, dl4j, and a series of other Java components, it is foreseeable that Java will become a mainstream platform for large-scale machine learning applications.

    It follows that applying machine learning will be one of the core competencies of Java programmers in the future. But as programmers, how do we get started? Here we set aside the complicated concepts of machine learning and begin with its most representative family of algorithms: clustering.

    Yes, I "cheated" you in in the name of Java, but I believe that people with good Java foundation have no problem reading the following Kotlin code, and the following code can also be translated into Java code, which happens to be a very meaningful practice. The reason why the sample code in this article uses Kotlin is that Kotlin can express my method more concisely, and the compatibility with Java is quite perfect.

 

The only background knowledge you need

       Machine learning has countless branches and concrete methods. Clustering, and K-means clustering in particular, is undoubtedly one of the most representative unsupervised learning methods. It is as simple as many ordinary statistical algorithms, yet it offers training, prediction, and other capabilities that make it very close to deep learning in use, which makes it an excellent algorithm for getting started with machine learning.

       Here is an easy-to-understand explanation of K-means clustering in the author's own words:

An automatic classification algorithm: it divides a set of objects with similar numerical properties into K categories and, through continuous iteration, makes the data within each category as similar as possible while keeping the categories as distinguishable from one another as possible.
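To make the iteration concrete, here is a purely illustrative Kotlin sketch of the k-means loop. It is not the library algorithm we call later in this article, and the helper names (kMeansSketch, sqDist) are my own:

// Illustrative k-means loop: assign each point to its nearest center,
// recompute each center as the mean of its points, repeat.
fun kMeansSketch(points: List<DoubleArray>, k: Int, iterations: Int = 10): List<DoubleArray> {
    // start from k randomly chosen points as initial centers
    var centers = points.shuffled().take(k).map { it.copyOf() }
    repeat(iterations) {
        // assignment step: group points by the index of their nearest center
        val groups = points.groupBy { p ->
            centers.indices.minBy { i -> sqDist(p, centers[i]) }!!
        }
        // update step: move each center to the mean of its assigned points
        centers = centers.mapIndexed { i, c ->
            val g = groups[i] ?: return@mapIndexed c
            DoubleArray(c.size) { d -> g.map { it[d] }.average() }
        }
    }
    return centers
}

// squared Euclidean distance between two points
fun sqDist(a: DoubleArray, b: DoubleArray): Double =
    a.indices.map { (a[it] - b[it]) * (a[it] - b[it]) }.sum()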

       The greatest truths are the simplest. With such a simple clustering algorithm, we can:

  1. Replace manual work with faster automatic classification of massive user data;
  2. Discover potential patterns from the automatic clustering results, for example: fathers who buy diapers often buy themselves a few bottles of beer as well;
  3. Classify or predict new data more quickly using the clustering results, for example: using clusters built from historical data as a model to quickly predict someone's disease risk from physical examination data;
  4. Accelerate search over high-dimensional data, for example: clustering an image library by the deep features of the pictures, so that the most similar products can be found among hundreds of millions of images via hierarchical search (similar to Baidu image search or Taobao's Pailitao).

Borrowing the clustering algorithm comparison chart from the Apache Commons Math documentation helps to understand what clustering does:

[Figure: comparison of clustering algorithms]

Different colors represent different clusters in the figure, showing the clustering results on a variety of two-dimensional datasets.

 

Hands-on practice

Original requirement:

A company's portal website is divided into the following columns:

video
literature
comics
animation
cars
navigation
magazines
mail
medical
securities
news
wallet
business

    The operations team has compiled this quarter's visit data for 20,000 users, and hopes to use it to build profiles of the site's users, then launch targeted marketing activities and deliver advertisements precisely.

Description: the data file is a CSV file separated by ",". The first column is the user id, and the following 13 columns are the user's visit counts for each column of the site.

 

Analysis steps:

  1. Preprocess the data for analysis
  2. Cluster the processed data
  3. Interpret the clusters as user profiles
  4. Propose targeted marketing activities based on those profiles
  5. Deliver the targeted campaigns to each user

Code practice:

1. Create a project using Maven

mvn archetype:generate \
          -DinteractiveMode=false \
          -DarchetypeGroupId=org.jetbrains.kotlin \
          -DarchetypeArtifactId=kotlin-archetype-jvm \
          -DarchetypeVersion=1.3.70 \
          -DgroupId=org.ctstudio \
          -DartifactId=customer-cluster \
          -Dversion=1.0

After the command finishes, import the Maven project into your favorite IDE.

2. Add dependencies

    We use commons-csv to parse the data and the clustering algorithm provided by commons-math3, and along the way Kotlin's JDK 8 extensions. In practice, use whatever CSV component you prefer; most frameworks that support machine learning, such as Spark and Mahout, include a k-means clustering algorithm, and once you have the basic usage down, swapping one in as needed is easy.

<!-- Kotlin's JDK 8 extensions, mainly to simplify the file-handling code -->
<dependency>
    <groupId>org.jetbrains.kotlin</groupId>
    <artifactId>kotlin-stdlib-jdk8</artifactId>
    <version>${kotlin.version}</version>
</dependency>

<!-- Used to import and export CSV-format data files -->
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-csv</artifactId>
    <version>1.6</version>
</dependency>

<!-- Mainly used for its clustering algorithms -->
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.6.1</version>
</dependency>
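A note on ${kotlin.version}: the archetype-generated pom should already define this Maven property; if yours does not (this is an assumption about your generated pom, so check it first), add it to match the archetype version used above:

<properties>
    <kotlin.version>1.3.70</kotlin.version>
</properties>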

 

3. Download the data

 

Download the two data files (the user-visit CSV described above and categories.csv containing the column titles) for the code to use, for example into the root directory of the project created above:

 

 

4. Write the code

Read the data and structure it into a list of user PVs:

// User PV entity class, implementing Clusterable so the clustering algorithm can use it
// (requires org.apache.commons.math3.ml.clustering.Clusterable).
// id is the user id from the first column, pv is a double[] of the user's visits to each
// column, and clusterId is the assigned cluster, used when saving the results.
class UserPV(var id: Int, private val pv: DoubleArray, var clusterId: Int = 0) : Clusterable {
    override fun getPoint(): DoubleArray {
        return pv
    }

    override fun toString(): String {
        return "{id:$id,point:${point.toList()}}"
    }
}

// Use commons-csv to read the data file into a List<UserPV>
// (requires java.io.FileReader, org.apache.commons.csv.CSVFormat and CSVParser)
fun loadData(filePath: String): List<UserPV> {
    val fmt = CSVFormat.EXCEL
    FileReader(filePath).use { reader ->
        return CSVParser.parse(reader, fmt).records.map {
            val uid = it.first().toIntOrNull() ?: 0
            val pv = DoubleArray(13) { i ->
                it[i + 1]?.toDoubleOrNull() ?: 0.0
            }
            UserPV(uid, pv)
        }
    }
}
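Wiring this up takes one line; a minimal sketch, assuming the downloaded data file is named pv.csv (a hypothetical name, substitute your actual file):

// "pv.csv" is a hypothetical file name for the downloaded visit data
val originData = loadData("pv.csv")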

Preprocess the data: remove invalid records, fix outliers, and normalize the visit counts:

// Filter or fix abnormal data; a real business may need more filtering and processing.
// Drop invalid user ids
val filteredData = originData.filter { it.id > 0 }
// Clamp negative visit counts to 0
filteredData.forEach { it.point.forEachIndexed { i, d -> if (d < 0.0) it.point[i] = 0.0 } }
// Normalize the PV data
normalize(filteredData)

Normalization code:

// (requires kotlin.math.max and kotlin.math.min)
fun <T : Clusterable> normalize(points: List<T>, dimension: Int = points.first().point.size) {
    val maxAry = DoubleArray(dimension) { Double.NEGATIVE_INFINITY }
    val minAry = DoubleArray(dimension) { Double.POSITIVE_INFINITY }
    // Collect the per-dimension maximum and minimum over all points
    points.forEach {
        maxAry.assignEach { index, item -> max(item, it.point[index]) }
        minAry.assignEach { index, item -> min(item, it.point[index]) }
    }
    // Kotlin operator overloading: element-wise operations over double[]
    val denominator = maxAry - minAry
    points.forEach {
        // The logic here: (x - min) / (max - min)
        it.point.assignEach { i, item -> (item - minAry[i]) / denominator[i] }
    }
}
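The normalize function relies on two small DoubleArray extensions that the article does not show. Here is a minimal sketch consistent with how they are used above (my own reconstruction, not the author's original helpers):

// In-place update: replace each element using its index and current value.
// Assumed helper, inferred from its usage in normalize() above.
inline fun DoubleArray.assignEach(transform: (index: Int, value: Double) -> Double) {
    for (i in indices) this[i] = transform(i, this[i])
}

// Element-wise subtraction, enabling the `maxAry - minAry` expression above.
operator fun DoubleArray.minus(other: DoubleArray): DoubleArray =
    DoubleArray(size) { i -> this[i] - other[i] }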

So-called normalization means transforming all the data into the 0~1 range via (value - min)/(max - min), so that a column with a huge number of visits cannot dominate the clustering result. (If a column's max equals its min, the denominator becomes zero, so real code should guard against division by zero.)

Call the clustering algorithm on the data:

  // Create the clusterer instance; "5" is the desired number of clusters.
  // In practice, more parameters (including k) must be tuned, re-clustered and
  // re-evaluated repeatedly to reach the best clustering result.
  val kMeans = KMeansPlusPlusClusterer<UserPV>(5)
  // Cluster the preprocessed data
  val clusters = kMeans.cluster(filteredData)
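For tuning beyond k, KMeansPlusPlusClusterer in commons-math3 has richer constructors; for instance, you can cap the iterations and pass the distance measure explicitly. (K-means++ seeding is randomized, so results can vary between runs unless you also supply a fixed RandomGenerator.)

import org.apache.commons.math3.ml.distance.EuclideanDistance

// k = 5 clusters, at most 1000 iterations, explicit Euclidean distance
val tunedKMeans = KMeansPlusPlusClusterer<UserPV>(5, 1000, EuclideanDistance())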

Often we do not know in advance how many categories fit the data best. In that case we need an evaluation algorithm to assess the clustering quality for different numbers of centers.

Calinski-Harabasz is a very commonly used evaluation algorithm. Its basic idea is that the more compact each cluster and the larger the distance between clusters, the higher the score. Unfortunately there is no open-source Java version yet; fortunately, the code I submitted to Apache Commons Math has been accepted into commons-math4, so stay tuned. Here I use the Kotlin version I wrote directly, and you can also implement it yourself:

// Create the clustering algorithm
val kMeans = KMeansPlusPlusClusterer<UserPV>(5)
// Cluster the dataset
val clusters = kMeans.cluster(filteredData)
// Create the Calinski-Harabasz evaluator
val evaluator = CalinskiHarabasz<UserPV>()
// Score the clustering result we just produced
val score = evaluator.score(clusters)
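The author's CalinskiHarabasz class is not listed in the article. If you want to implement the score yourself, here is a sketch under the standard definition (my own reconstruction, not the author's code): between-cluster dispersion over within-cluster dispersion, scaled by degrees of freedom.

import org.apache.commons.math3.ml.clustering.CentroidCluster
import org.apache.commons.math3.ml.clustering.Clusterable

// Calinski-Harabasz index: (B / (k - 1)) / (W / (n - k)).
// B = between-cluster sum of squares (cluster size times squared distance
//     from the cluster centroid to the overall mean),
// W = within-cluster sum of squares. Higher scores are better.
fun <T : Clusterable> calinskiHarabaszScore(clusters: List<CentroidCluster<T>>): Double {
    val k = clusters.size
    val n = clusters.sumBy { it.points.size }
    val dim = clusters.first().center.point.size

    // Overall mean of all points
    val mean = DoubleArray(dim)
    clusters.forEach { c ->
        c.points.forEach { p ->
            for (i in 0 until dim) mean[i] += p.point[i]
        }
    }
    for (i in 0 until dim) mean[i] /= n

    var between = 0.0
    var within = 0.0
    clusters.forEach { c ->
        val center = c.center.point
        between += c.points.size * squaredDist(center, mean)
        c.points.forEach { p -> within += squaredDist(p.point, center) }
    }
    return (between / (k - 1)) / (within / (n - k))
}

// Squared Euclidean distance
fun squaredDist(a: DoubleArray, b: DoubleArray): Double {
    var s = 0.0
    for (i in a.indices) {
        val d = a[i] - b[i]
        s += d * d
    }
    return s
}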

With the clustering and scoring code in place, we can dynamically select the most suitable k value, that is, the number of cluster centers:

    val evaluator = CalinskiHarabasz<UserPV>()
    var maxScore = 0.0
    var bestKMeans: KMeansPlusPlusClusterer<UserPV>? = null
    var bestClusters: List<CentroidCluster<UserPV>>? = null
    for (k in 2..10) {
        val kMeans = KMeansPlusPlusClusterer<UserPV>(k)
        val clusters = kMeans.cluster(filteredData)
        val score = evaluator.score(clusters)
        // Keep the clustering with the highest score
        if (score > maxScore) {
            maxScore = score
            bestKMeans = kMeans
            bestClusters = clusters
        }
        println("k=$k,score=$score")
    }

    // Print the best number of cluster centers
    println("Best k is ${bestKMeans!!.k}")

By comparing the scores for several k values, we conclude that dividing the users into three categories fits best. Now we can save the clustering results for analysis and interpretation:

// Save the cluster center data
fun saveCenters(
    clusters: List<CentroidCluster<UserPV>>,
    fileCategories: String,
    fileCenters: String
) {
    // Read the column titles from categories.csv
    val categories = readCategories(fileCategories)
    // Save the column titles together with the cluster centers
    writeCSV(fileCenters) { printer ->
        printer.print("")
        printer.printRecord(categories)
        for (cluster in clusters) {
            // Number of users in this cluster
            printer.print(cluster.points.size)
            // Mean (normalized) visit counts of this cluster
            printer.printRecord(cluster.center.point.toList())
        }
    }
}

...

saveCenters(clusters, "categories.csv", "centers.csv")
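saveCenters uses two helpers, readCategories and writeCSV, which the article never lists. Minimal sketches under my assumptions (categories.csv holds the 13 column titles as a single CSV record; adjust if your file differs):

import java.io.FileReader
import java.io.FileWriter
import org.apache.commons.csv.CSVFormat
import org.apache.commons.csv.CSVParser
import org.apache.commons.csv.CSVPrinter

// Read the column titles from the first record of categories.csv.
// Assumes the titles are on one line, comma-separated.
fun readCategories(filePath: String): List<String> {
    FileReader(filePath).use { reader ->
        return CSVParser.parse(reader, CSVFormat.EXCEL).records.first().toList()
    }
}

// Open a CSV file for writing and hand the printer to the caller.
fun writeCSV(filePath: String, block: (CSVPrinter) -> Unit) {
    CSVPrinter(FileWriter(filePath), CSVFormat.EXCEL).use(block)
}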

Each user's cluster assignment usually also needs to be saved, as the basis for providing personalized services to every user later:

// Save the user id -> cluster mapping to a CSV file
fun saveClusters(
    clusters: List<CentroidCluster<UserPV>>,
    fileClusters: String
) {
    writeCSV(fileClusters) { printer ->
        var clusterId = 0
        clusters.flatMap {
            clusterId++
            it.points.onEach { p -> p.clusterId = clusterId }
        }.sortedBy { it.id }.forEach { printer.printRecord(it.id, it.clusterId) }
    }
}
...
saveClusters(clusters, "clusters.csv")

Note that saving to CSV here is only for demonstration; depending on the actual business, you may need to write the user id-cluster mapping into a database instead.

5. Interpret the clustering results

Opening centers.csv in Excel, we can mark the maximum value in each column (representing the normalized average visit count of each user category) with a background color, as the defining characteristic of that category:

It is not difficult to see from the table above that our users fall into three categories:

  1. 7010 people who like video, literature, and animation
  2. 8151 people who follow cars, navigation, magazines, and mail
  3. 4839 people who care about medical, securities, news, wallet, and business content

    If we combine this with the users' other registration information, we can draw even clearer portraits, for example by adding age and gender: college students who love movies and animation, professionals who follow cars and fashion, homemakers who care about health and wealth management...

 

Summary

    If you have read this far, you will find that getting started with machine learning is not that difficult: the code runs along briskly and does not require many frameworks or components. If your data is truly large, say over a hundred million records, you can also look forward to the mini-batch k-means clustering algorithm I am contributing to Apache Commons Math (to be released with commons-math4); compared with switching to frameworks such as Spark, the algorithm alone brings a dramatic performance improvement. Of course, when your data grows too large for a single machine, those distributed frameworks are still essential.

    To learn machine learning well, mastering the theory is essential. A journey of a thousand miles begins with a single step; let's begin with clustering algorithms. Beyond this article, you should also look up some of the theory behind clustering to deepen your understanding.

    Next time, I may explain, in an easy-to-understand way, some prerequisite knowledge for going deeper (not actually that deep) into machine learning, such as how to extend from one-dimensional space to multi-dimensional spaces, variance, and Euclidean distance. Of course, I am a practitioner rather than a theorist, and all of this knowledge leads up to a real AI project case study :-)
