4.2、NEAT 监督学习 Supervised learning

4.2 NEAT 监督学习 Supervised learning_哔哩哔哩_bilibili

NEAT 监督学习 | 莫烦Python


接着我们来说说 neat-python 网页上的一个使用例子, 用 neat 来进化出一个神经网络预测 XOR 判断(一样的输出False,不一样输出True)

  • 输入 True, True, 输出 False
  • 输入 False, True, 输出 True
  • 输入 True, False, 输出 True
  • 输入 False, False 输出 False

在例子当中, 我们用这样的形式来代替要学习的 input 和 output:

xor_inputs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
xor_outputs = [   (0.0,),     (1.0,),     (1.0,),     (0.0,)]

那怎么样来评价每个个体的适应度 (fitness), 或者说他的预测得分高低呢. 我们就对每个个体评分. 如果4个 xor 判断都预测对了就得4分, 用平方差来计算错的. 下面的 function 中就是根据每个 genome (DNA), 生成一个神经网络, 用这个神经网络预测, 再对这个 genome 打分, 并写入成它的 fitness:

def eval_genomes(genomes, config):
    for genome_id, genome in genomes:   # for each individual
        genome.fitness = 4.0        # 4 xor evaluations
        net = neat.nn.FeedForwardNetwork.create(genome, config)
        for xi, xo in zip(xor_inputs, xor_outputs):
            output = net.activate(xi)
            genome.fitness -= (output[0] - xo[0]) ** 2

每一个 neat 的程序里有需要有这样的评分标准. 接着我们创建一个 config 的文件, 用来给定所有运行参数. 这个 config 文件要分开存储, 而且文件里要有一下几个方面的参数预设. 对于每个方面具体的预设值请参考我在 github 中的config-forward这个文件. 对于每个方面的解释, 不太明白的话, 请参考这里

现在我们就能使用这些预设的参数来生成一个 config 的值了 (上面的 eval_genomes 也用到了这个 config).

local_dir = os.path.dirname(__file__)
config_file = os.path.join(local_dir, 'config-feedforward')     # 参数文件
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                     neat.DefaultSpeciesSet, neat.DefaultStagnation,
                     config_file)

有了这个 config, 我们就能拿它来生成我们整个 population, 使用这个初始的 p 来训练 300 次, 注意在 config-forward 中我们设置了一个参数 fitness_threshold = 3.9, 就是说, 只要有任何一个 fitness 达到了 3.9 (最大4), 我们就停止迭代更新 population. 所以有可能不到 300 次就学好了. 学好之后, 我们输出表现最好的 winner.

p = neat.Population(config)
winner = p.run(eval_genomes, 300)   # 输入计算 fitness 的方式和 generation 的次数

最主要的过程就完啦, 简单吧. 在这个例子脚本中的其他代码都是现实结果的代码, 大家随便看看就知道了.

print('\nOutput:')
winner_net = neat.nn.FeedForwardNetwork.create(winner, config)
for xi, xo in zip(xor_inputs, xor_outputs):
    output = winner_net.activate(xi)
    print("input {!r}, expected output {!r}, got {!r}".format(xi, xo, output))

我们通过这个来输出最后的 winner 神经网络预测结果, 不出意外, 你应该预测很准. 最后通过 visualize.py 文件的可视化功能, 我们就能生成几个图片, 使用浏览器打开 speciation.svg 看看不同种群的变化趋势, avg_fitness.svg 看看 fitness 的变化曲线, Digraph.gv.svg 看这个生成的神经网络长怎样.

 

 

 关于最下面的那个神经网络图, 需要说明一下, 如果是实线, 如 B->1, B->2, 说明这个链接是 Enabled 的. 如果是虚线(点线), 如 B->A XOR B 就说明这个链接是 Disabled 的. 红色的线代表 weight <= 0, 绿色的线代表 weight > 0. 线的宽度和 weight 的大小有关.

猜你喜欢

转载自blog.csdn.net/weixin_43135178/article/details/130769880