粒子群算法(1) - Python实现

  • 抽象来源:模仿自然界中的鸟群觅食行为。
  • 核心思想:在自然界鸟群觅食过程中,我们可以想象食物自身散发某种着香味(实际上可能不是,此处仅以香味为例代表鸟群可能获得的某种食物信息),该香味距离食物越近则越浓(以状态函数值进行描述)。并且,我们假设群体中每只鸟的飞行行为均且仅受到三方面因素的影响和贡献:1)每只鸟自身的飞行惯性 --- 自身惯性贡献;2)每只鸟的历史最优状态 --- 自身认知贡献;3)整个鸟群的历史最优状态 --- 群体经验贡献。注意,此处以与食物的距离(即香味浓度,也就是状态函数值)来判断飞行状态的优劣。根据此三类贡献,状态空间内每只鸟将逐渐调整自身的飞行速度(包括大小、方向),并向食物位置(即局部香味最浓的位置)汇聚。因此,在相关参数设置合理的前提下,粒子群算法的最终解应该对应于给定状态空间内的最值。
  • 迭代公式
    粒子群速度更新公式
    \begin{equation}
    V_i(t+1) = \omega V_i(t) + c_1r_1(pbest_i - X_i(t)) + c_2r_2(gbest-X_i(t))
    \end{equation}
    该式右端三项分别代表:自身惯性贡献、自身认知贡献以及群体经验贡献。其中,$\omega$代表惯性因子,$c_1$、$c_2$代表学习因子,$r_1$、$r_2$为$[0, 1]$之间的均匀随机数,$pbest_i$为第$i$个粒子已知的历史最优状态或位置,$gbest$整个粒子群已知的历史最优状态或位置。$V_i(t)$与$X_i(t)$分别代表$t$时刻粒子$i$的速度与位置。
    由于实际问题可能处于多维空间内,因此有:
    \begin{equation}
    \begin{cases}
    V_i = (v_{i1}, v_{i2}, ..., v_{iD})\\
    X_i = (x_{i1}, x_{i2}, ..., x_{iD})
    \end{cases}
    \end{equation}
    其中,$D$为空间维数。
    粒子群位置更新公式
    \begin{equation}
    X_i(t+1) = X_i(t) + V_i(t+1)
    \end{equation} 
  • Python代码实现
     1 import numpy as np
     2 import matplotlib.pyplot as plt
     3 import random
     4 
     5 
     6 # 定义“粒子”类
     7 class parti(object):
     8     def __init__(self, v, x):
     9         self.v = v                    # 粒子当前速度
    10         self.x = x                    # 粒子当前位置
    11         self.pbest = x                # 粒子历史最优位置
    12         
    13 class PSO(object):
    14     def __init__(self, interval, tab='min', partisNum=10, iterMax=1000, w=1, c1=2, c2=2):
    15         self.interval = interval                                            # 给定状态空间 - 即待求解空间
    16         self.tab = tab.strip()                                              # 求解最大值还是最小值的标签: 'min' - 最小值;'max' - 最大值
    17         self.iterMax = iterMax                                              # 迭代求解次数
    18         self.w = w                                                          # 惯性因子
    19         self.c1, self.c2 = c1, c2                                           # 学习因子
    20         self.v_max = (interval[1] - interval[0]) * 0.1                      # 设置最大迁移速度
    21         #####################################################################
    22         self.partis_list, self.gbest = self.initPartis(partisNum)                 # 完成粒子群的初始化,并提取群体历史最优位置
    23         self.x_seeds = np.array(list(parti_.x for parti_ in self.partis_list))    # 提取粒子群的种子状态 ###
    24         self.solve()                                                              # 完成主体的求解过程
    25         self.display()                                                            # 数据可视化展示
    26         
    27     def initPartis(self, partisNum):
    28         partis_list = list()
    29         for i in range(partisNum):
    30             v_seed = random.uniform(-self.v_max, self.v_max)
    31             x_seed = random.uniform(*self.interval)
    32             partis_list.append(parti(v_seed, x_seed))
    33         temp = 'find_' + self.tab
    34         if hasattr(self, temp):                                             # 采用反射方法提取对应的函数
    35             gbest = getattr(self, temp)(partis_list)
    36         else:
    37             exit('>>>tab标签传参有误:"min"|"max"<<<')
    38         return partis_list, gbest
    39         
    40     def solve(self):
    41         for i in range(self.iterMax):
    42             for parti_c in self.partis_list:
    43                 f1 = self.func(parti_c.x)
    44                 # 更新粒子速度,并限制在最大迁移速度之内
    45                 parti_c.v = self.w * parti_c.v + self.c1 * random.random() * (parti_c.pbest - parti_c.x) + self.c2 * random.random() * (self.gbest - parti_c.x)
    46                 if parti_c.v > self.v_max: parti_c.v = self.v_max
    47                 elif parti_c.v < -self.v_max: parti_c.v = -self.v_max
    48                 # 更新粒子位置,并限制在待解空间之内
    49                 if self.interval[0] <= parti_c.x + parti_c.v <=self.interval[1]:
    50                     parti_c.x = parti_c.x + parti_c.v 
    51                 else:
    52                     parti_c.x = parti_c.x - parti_c.v
    53                 f2 = self.func(parti_c.x)
    54                 getattr(self, 'deal_'+self.tab)(f1, f2, parti_c)             # 更新粒子历史最优位置与群体历史最优位置      
    55         
    56     def func(self, x):                                                       # 状态产生函数 - 即待求解函数
    57         value = np.sin(x**2) * (x**2 - 5*x)
    58         return value
    59         
    60     def find_min(self, partis_list):                                         # 按状态函数最小值找到粒子群初始化的历史最优位置
    61         parti = min(partis_list, key=lambda parti: self.func(parti.pbest))
    62         return parti.pbest
    63         
    64     def find_max(self, partis_list):
    65         parti = max(partis_list, key=lambda parti: self.func(parti.pbest))   # 按状态函数最大值找到粒子群初始化的历史最优位置
    66         return parti.pbest
    67         
    68     def deal_min(self, f1, f2, parti_):
    69         if f2 < f1:                          # 更新粒子历史最优位置
    70             parti_.pbest = parti_.x
    71         if f2 < self.func(self.gbest):
    72             self.gbest = parti_.x            # 更新群体历史最优位置
    73             
    74     def deal_max(self, f1, f2, parti_):
    75         if f2 > f1:                          # 更新粒子历史最优位置
    76             parti_.pbest = parti_.x
    77         if f2 > self.func(self.gbest):
    78             self.gbest = parti_.x            # 更新群体历史最优位置
    79             
    80     def display(self):
    81         print('solution: {}'.format(self.gbest))
    82         plt.figure(figsize=(8, 4))
    83         x = np.linspace(self.interval[0], self.interval[1], 300)
    84         y = self.func(x)
    85         plt.plot(x, y, 'g-', label='function')
    86         plt.plot(self.x_seeds, self.func(self.x_seeds), 'b.', label='seeds')
    87         plt.plot(self.gbest, self.func(self.gbest), 'r*', label='solution')
    88         plt.xlabel('x')
    89         plt.ylabel('f(x)')
    90         plt.title('solution = {}'.format(self.gbest))
    91         plt.legend()
    92         plt.savefig('PSO.png', dpi=500)
    93         plt.show()
    94         plt.close()
    95 
    96         
    97 if __name__ == '__main__':
    98     PSO([-9, 5], 'max')
    View Code

     笔者所用示例函数为:
    \begin{equation}
    f(x) = (x^2 - 5x)sin(x^2)
    \end{equation}

  • 结果展示
  • 参考:https://wenku.baidu.com/view/0fdb3dff87c24028905fc321.html

猜你喜欢

转载自www.cnblogs.com/xxhbdk/p/9229944.html
今日推荐