[Python-代码实现]统计学习方法之感知机模型

版权声明:点个赞,来个评论(夸我),随便转~ https://blog.csdn.net/qq_28827635/article/details/84760401

内容简介

  1. 感知机模型 - 手写 Coding
  2. 使用手写模型进行鸢尾花分类
  3. 使用 sklearn 中的感知机进行鸢尾花分类

感知机模型 - 手写 Coding

class Model:
"""感知机模型"""
    def __init__(self, data):
        """选取初值 w, b, η"""
        self.w = np.zeros(len(data[0]) - 1, dtype=np.float32)  # 参数 w 应与 x 等量
        self.b = 0
        self.η = 0.1

    def sign(self, x):
        """感知机模型"""
        y = np.dot(self.w, x) + self.b
        return 1 if y >= 0 else -1

    def fit(self, x_train, y_train):
        """模型训练"""
        while True:
            for d, x in enumerate(x_train):  # 取出一条数据
                y = y_train[d]  # 取出对应数据的 target
                if y * self.sign(x) <= 0:  # 分类不正确进行参数更迭
                    self.w = self.w + np.dot(self.η * y, x)
                    self.b = self.b + self.η * y
                    break  # 发生更迭即存在分类错误,从头再来
            else:  # 没有发生更迭即全部分类正确,停止训练
                break

        return self.w, self.b

使用手写模型进行鸢尾花分类

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt


def main():
    # 一、加载数据
    iris = load_iris()

    # 二、提取输入与输出数据
    # 为输入特征创建 Frame,并使用特征名称作为列标题(注意不是列索引)
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    # 添加输出列 target
    df['target'] = iris.target
    # 给 Frame 添加列索引(只有加了索引才可以使用索引)
    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'target']
    # 打印输出值分布情况
    print(df.target.value_counts())

    # 三、绘出数据并观察分布情况
    # 通过 frame 能够看出数据是 50 间隔分布,因此可以以 50 间隔分别取出
    plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
    plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
    plt.xlabel('sepal length')
    plt.ylabel('sepal width')
    plt.legend()
    plt.show()

    # 四、特征提取与目标值提取
    # 使用 iloc 选取前 100 条数据的第 0, 1, -1 列,并转换为 array
    data = np.array(df.iloc[:100, [0, 1, -1]])
    # 将 第 0, 1 列数据赋值给 x,将 第 -1 列数据赋值给 y
    x_train, y_train = data[:, :-1], data[:, -1]
    # 将 y 值进行 1, -1分类
    y_train = np.array([i if i == 1 else -1 for i in y_train])

    # 五、感知机模型训练
    perceptron = Model(data)
    w, b = perceptron.fit(x_train, y_train)

    # 六、绘出判定边界
    # 分离超平面为 w[0]x_1 + w[1]x_2 + b = 0
    x_1 = np.linspace(4, 7, 10)
    x_2 = -(w[0] * x_1 + b) / w[1]
    plt.plot(x_1, x_2)
    plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
    plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
    plt.xlabel('sepal length')
    plt.ylabel('sepal width')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()

使用 sklearn 中的感知机进行鸢尾花分类

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import Perceptron
import matplotlib.pyplot as plt


def main():
    # 一、加载数据
    iris = load_iris()

    # 二、提取输入与输出数据
    # 为输入特征创建 Frame,并使用特征名称作为列标题(注意不是列索引)
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    # 添加输出列 target
    df['target'] = iris.target
    # 给 Frame 添加列索引(只有加了索引才可以使用索引)
    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'target']

    # 三、特征提取与目标值提取
    # 使用 iloc 选取前 100 条数据的第 0, 1, -1 列,并转换为 array
    data = np.array(df.iloc[:100, [0, 1, -1]])
    # 将 第 0, 1 列数据赋值给 x,将 第 -1 列数据赋值给 y
    x_train, y_train = data[:, :-1], data[:, -1]
    # 将 y 值进行 1, -1分类
    y_train = np.array([i if i == 1 else -1 for i in y_train])

    # 四、使用SKlearn感知机进行模型训练
    clf = Perceptron()
    clf.fit(x_train, y_train)
    w = clf.coef_[0]  # w
    b = clf.intercept_  # b

    # 五、绘出判定边界
    # 分离超平面为 w[0]x_1 + w[1]x_2 + b = 0
    x_1 = np.linspace(4, 7, 10)
    x_2 = -(w[0] * x_1 + b) / w[1]
    plt.plot(x_1, x_2)
    plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
    plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
    plt.xlabel('sepal length')
    plt.ylabel('sepal width')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()

希望对你有所帮助,点个赞哇大兄dei!
个人博客:http://xingtu.info
GitHub:https://github.com/BreezeDawn/MachineLearning

猜你喜欢

转载自blog.csdn.net/qq_28827635/article/details/84760401