感知机对偶形式

感知机:

在这里插入图片描述
符号函数:
在这里插入图片描述
选择误分类点到超平面的总距离作为损失函数:
距离:
在这里插入图片描述
误分类点:
在这里插入图片描述
误分类点距离
在这里插入图片描述
总距离
在这里插入图片描述

感知机对偶形式

在这里插入图片描述
在这里插入图片描述
过程
在这里插入图片描述
例题:
在这里插入图片描述
在这里插入图片描述
动态可视化代码:

# 以半动画的方式展示感知识机对偶问题的操作的合理性
# 给定初始点, 初始直线
import time

import matplotlib
import numpy as np
from matplotlib.colors import ListedColormap

matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.family']='sans-serif'
matplotlib.rcParams['axes.unicode_minus'] = False
import matplotlib.pyplot as plt


cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

# 可视化展示
point_coordinates = np.array([2., 2.])  # 关键点坐标

line_a, line_b = 7, 9  # 初始线方程
line_c = - line_b * line_a  # 注意负号

bottom, up = -5, 10  # 视窗

# 关键点对应的“基线”的方程的分类,以平面展示
xx_plane = np.linspace(bottom - 0.5, up + 0.5, 300)
yy_plane = np.linspace(bottom - 0.5, up + 0.5, 300)
xx_plane, yy_plane = np.meshgrid(xx_plane, yy_plane)
class_plane = np.array([1 if np.dot(xi, point_coordinates) + 1 > 0 else 0 for xi in zip(
    xx_plane.ravel(), yy_plane.ravel())]).reshape(xx_plane.shape)

# 优化方程对应的分类,以整数点展示
dots = np.array([np.array([ii, jj]) for ii in range(bottom, up + 1) for jj in range(bottom, up + 1)
                 if ii != point_coordinates[0] or jj != point_coordinates[1]])
class_dots = np.array([1 if np.dot(xi, [line_a, line_b]) + line_c > 0 else 0 for xi in dots])


# 找到直线在视窗内的两个顶点  直线与视窗四线的交点的中间的两个
def window_cross(line_a, line_b, line_c):
    if line_a == 0 and line_b ==0:
        return (0, 0), (0, 0)
    elif line_a == 0:
        return (bottom, -line_c / line_b), (up,  -line_c / line_b)
    elif line_b == 0:
        return (-line_c / line_a, bottom), (-line_c / line_a, up)
    else:
        c1 = bottom, - 1 / line_b * (line_c + line_a * bottom)
        c2 = up, - 1 / line_b * (line_c + line_a * up)
        c3 = -1 / line_a * (line_c + line_b * bottom), bottom
        c4 = -1 / line_a * (line_c + line_b * up), up
        cross_points = [c1, c2, c3, c4]
        cross_points.sort()
        return cross_points[1:3]


# 初始状态展示
plt.figure(figsize=(8, 6))
plt.title('起始状态 同向')
plt.pcolormesh(xx_plane, yy_plane, class_plane, cmap=cmap_light)
plt.scatter(dots[:, 0], dots[:, 1], c=class_dots, cmap=cmap_bold)
color_point = 'b' if point_coordinates[0] * line_a + point_coordinates[1] * line_b + line_c > 0 else 'r'
plt.scatter(point_coordinates[0], point_coordinates[1], c=color_point, marker="v", s=100)
# plt.grid()
plt.plot([bottom - 0.5, up + 0.5], [0, 0], c='k')
plt.plot([0, 0], [bottom - 0.5, up + 0.5], c='k')
cross = window_cross(line_a, line_b, line_c)
plt.plot([cross[0][0], cross[1][0]], [cross[0][1], cross[1][1]], c='g')


# 迭代过程的动态展示
plt.close('all')
plt.figure(figsize=(8, 6))
plt.ion()

for ii in range(200):
    plt.cla()
    plt.title(f'epoch={
      
      ii+1}: ({
      
      line_a}) * x + ({
      
      line_b}) * y + ({
      
      line_c}) = 0', fontsize=20)
    if ii < 8:
        time.sleep(0.4)
    else:
        time.sleep(0.02)
    plt.pcolormesh(xx_plane, yy_plane, class_plane, cmap=cmap_light)
    plt.scatter(dots[:, 0], dots[:, 1], c=class_dots, cmap=cmap_bold)
    color_point = 'b' if point_coordinates[0] * line_a + point_coordinates[1] * line_b + line_c > 0 else 'r'
    plt.scatter(point_coordinates[0], point_coordinates[1], c=color_point, marker="v", s=100)
    plt.grid()
    plt.plot([bottom - 0.5, up + 0.5], [0, 0], c='k')
    plt.plot([0, 0], [bottom - 0.5, up + 0.5], c='k')
    cross = window_cross(line_a, line_b, line_c)
    plt.plot([cross[0][0], cross[1][0]], [cross[0][1], cross[1][1]], c='g')
    # 暂停
    if ii < 8:
        time.sleep(0.3)
    else:
        time.sleep(0.02)
    plt.pause(0.01)
    # 迭代更新
    line_a, line_b, line_c = line_a+point_coordinates[0], line_b+point_coordinates[0], line_c + 1
    class_dots = np.array([1 if np.dot(xi, [line_a, line_b]) + line_c > 0 else 0 for xi in dots])

# 关闭交互模式
plt.ioff()

# 图形显示
plt.show()

猜你喜欢

转载自blog.csdn.net/weixin_51756104/article/details/121225378
今日推荐