機械学習は知っておく必要があります-KNNアルゴリズムの詳細な分析と実装

一緒に書く習慣をつけましょう!「ナゲッツデイリーニュープラン・4月アップデートチャレンジ」に参加して2日目です。クリックしてイベントの詳細をご覧ください

1.実験的アルゴリズム設計

  1. スイカのデータセットを読む
  2. 最初のクラスター中心としてk個のサンプルをランダムに選択します
  3. 各サンプルと各クラスター中心間の距離を計算し、各サンプルをそれに最も近いクラスター中心に割り当てます。この時点で、すべてのサンプルはk個のグループに分割されています。
  4. 各グループのサンプルの平均をグループの新しいクラスターセンターとして使用して、クラスターセンターを更新します
  5. クラスタの中心が安定するか、最大反復回数に達するまで、2番目と3番目の手順を繰り返します。

実験分析

スイカのデータセットでKNNを使用した分類

スイカのデータセットを簡単に分析すると、次の結果が得られます。

特徴 データ機能 役割
シリアルナンバー 離散 シリアルナンバー
密度 連続 特徴
糖度 連続 特徴
良いメロン 離散 ラベル

したがって、クラスター分析のために特徴密度と糖度が選択されました。

次に、K-meansクラスター分析のコアコード

  1. 必要なライブラリをインポートする

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    复制代码

    この実験では、データセットを読み取るためのメインツールとしてpandas 1を選択し、主要な数学演算を高速化するためにnumpy 2選択し、データ視覚化分析のためにmatplotlib3選択しました。

  2. K-Meansクラスタリングの定義KMeans

    1. __init__()分類器を初期化するために定義する

      class KNNClassifier:
          def __init__(self, x: pd.DataFrame):
              self.x = x
          ...
      复制代码

      ここでX、はデータセットの特徴を表します。

    2. 事前定義された距離関数distanceAll()

      def distanceAll(center, rest):
          distances = np.apply_along_axis(_distances, 1, rest, center)
          return distances.sum()
      
      def _distances(point: np.ndarray, centers: np.ndarray):
          distances = np.apply_along_axis(_distance, 1, centers, point)
          return distances
      
      def _distance(x, y):
          return np.sqrt(np.dot(x, x) - 2 * np.dot(x, y) + np.dot(y, y))
      复制代码

      ここでいくつかの最適化を行いました。具体的な最適化のポイントは次のとおりです。

      for-loopより速く実行することは避けてください

      最初の関数distanceAllでは、入力されるcenter合計restは多次元行列です。ここでは、center合計とrest相互の距離関数が実装され、ループは使用されないforため、実行速度が大幅に向上します。

      _distance(x, y) 計算結果を再利用する

      ユークリッド距離の一般的な計算式は次のとおりです。

      d = (( k = 1 m x k i x k j 2 ) d = \left( \sum_{k=1}^m \left | x_{ki} - x_{kj} \right |^2 \right)

      但是我在此处使用的公式为其展开形式

      d = k = 1 m ( x k i 2 2 × x k i × x k j + x k j 2 ) d = \sum_{k=1}^m\left( \red{ x_{ki}^2} - 2\times x_{ki}\times x_{kj}+ \red {x_{kj}^2} \right)

      此公式中红色部分在计算欧氏距离时会多次使用,因此,使用此公式可以充分利用numpy的缓存机制,减少不必要的重复运算量。

    3. 预定义 allocate() 核心方法为每个点找到最近的聚类中心

      def allocateAll(center, rest):
          # 2. 计算每个样本到各个聚类中心之间的距离,将每个样本分配给距离它最近的聚类中心
          allocates = np.apply_along_axis(_allocate, 1, rest, center)
          # sns.scatterplot(data=rest, x=0, y=1, hue=allocates)
          copied = rest.copy()
          copied["allocations"] = allocates
          groups = copied.groupby("allocations").groups
          # 绘图
          fig = plt.figure()
          ax = rest.plot.scatter(x=0, y=1, c=allocates, colormap='viridis', legend=True)
          center.iloc[list(groups.keys())].plot.scatter(x=0,
                                                        y=1,
                                                        c=list(groups.keys()),
                                                        marker="x",
                                                        colormap='viridis',
                                                        s=200,
                                                        ax=ax)
          plt.show()
          return groups
      
      def _allocate(point: np.ndarray, centers: np.ndarray):
          distances = np.apply_along_axis(_distance, 1, centers, point, "euclidean")
          nearest_center = np.argmin(distances)
          return nearest_center
      复制代码

      同时,在对每个点寻找中心进行聚类的过程中,还集合了绘图可视化方法。此处的可视化方法将绘制出之后聚类的过程。

    4. 定义 train() 在训练集上进行迭代训练

      class KMeans:
          ...
          def train(self, k):
              print(f" === k = {k} === ")
              batch = self.x.shape[0]
              features = self.x.shape[1]
              # 1. 随机选取 k 个样本作为初始的聚类中心
              index = np.random.randint(0, batch, size=k)
              centers: pd.DataFrame = self.x.iloc[index]  # 聚类中心
              # rest: pd.DataFrame = self.x.loc[~self.x.index.isin(index)]
              allocations = allocateAll(centers, self.x)
              for i in range(10):
                  last_centers = centers
                  centers = np.empty((k, 2))
                  for label, points in allocations.items():
                      center = self.x.iloc[points]
                      new_center = np.average(center, axis=0)
                      centers[label] = new_center
                  if np.isclose(last_centers, centers).all():
                      print(f"k = {k} 收敛,停止!")
                      return distanceAll(pd.DataFrame(centers), self.x)
                  allocations = allocateAll(pd.DataFrame(centers), self.x)
      复制代码

      在本段代码中,我指定每次训练最多进行10轮,一般来说,只需要迭代5次即可收敛到聚类中心。

      代码分为两部分,第一次的聚类中心在样本中随机选取,进行第一次聚类之后,再依据上一次的聚类结果,选择每一类的均值点作为中心进行循环迭代,当下一轮迭代的循环中心与上一轮相差不大时,终止迭代,返回此时的wss距离值。

三、实验数据及结果分析

在西瓜数据集上使用K均值聚类

  1. 导入所需库

    import matplotlib.pyplot as plt
    import pandas as pd
    
    from model import KMeans
    复制代码

    此处导入刚刚编写的KMeans以及绘图工具matplotlib进行wss曲线的绘制。

  2. 读取数据集并构建模型

    df = pd.read_csv("kmeansdata.csv")
    model = KMeans(df[["m", "h"]])
    复制代码

    此处读入西瓜数据集,并选定特征mh构建模型。

  3. KMeans 模型训练,可视化,WSS曲线分析

    wss = []
    for i in range(2, 10):
        wss.append(model.train(k=i))
    plt.plot(range(2, 10), wss)
    plt.savefig("result.png")
    复制代码

    此处我在2到15中选择k值,分别使用这些k值在KMeans模型上进行训练,并保存每一次训练之后返回的wss距离,最后对wss距离进行可视化分析。

    训练过程可视化 k=3

    首先,在数据集中随机选取三个样本作为聚类中心:

    image.png

    可以看出,选择的聚类中心偏下,然后进行第一次迭代:

    image.png

    在每一类中,选择其中心点作为下一次聚类中心,然后对每个点重新决定其类别,并进行下一次迭代:

    image.png

    可以看出,此时中心往中间偏移,分类更加合理。再进行一次迭代:

    image.png

    image.png

    此后迭代中心不再产生明显变化,代表聚类中心收敛,本轮聚类结束。

    WSS曲线可视化

    image.png

四、总结及心得体会

  1. 在简单的数据集(如西瓜数据集)上,聚类效果较好,在几次迭代内便可达到收敛。
  2. 根据对不同k值的可视化分析,可以发现,在k=3时达到"肘部",此时K为最优值,大于3的k值会因为类别过多而失去统计意义,k值太小会导致类别过少,使类内距离急剧上升。
  3. 使用C接口实现Python程序比使用Python-based-coding效率更高。
  4. 掌握了一些简单的数据可视化方法,学会使用一些简单的matplotlib库中有关pyplot的函数,利用简单的数据可视化方法将大量的数据转化成图片,极大地简化了我们对结果数据的分析和比对,能够更轻易的获得一些结果上的规律和结论。

5.実験のプロセス、方法、手段の改善のための提案

  1. データセットを視覚化する場合、視覚分析のために高次元の特徴の最初の2つの次元を大まかに選択すると、他の次元の特徴情報が失われます。ここでは、PCA 4などの次元削減方法を選択して、高次元を投影できます。 2次元平面上の次元フィーチャ。視覚分析を実行します。
  2. より複雑なデータセットを試すことができます。
  3. より多くの距離関数を検討することができます。

参考文献

おすすめ

転載: juejin.im/post/7083118254645837861