4.2 NEAT 监督学习 Supervised learning_哔哩哔哩_bilibili
接着我们来说说 neat-python 网页上的一个使用例子, 用 neat 来进化出一个神经网络预测 XOR 判断(一样的输出False,不一样输出True)
- 输入 True, True, 输出 False
- 输入 False, True, 输出 True
- 输入 True, False, 输出 True
- 输入 False, False 输出 False
在例子当中, 我们用这样的形式来代替要学习的 input 和 output:
xor_inputs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
xor_outputs = [ (0.0,), (1.0,), (1.0,), (0.0,)]
那怎么样来评价每个个体的适应度 (fitness), 或者说他的预测得分高低呢. 我们就对每个个体评分. 如果4个 xor 判断都预测对了就得4分, 用平方差来计算错的. 下面的 function 中就是根据每个 genome (DNA), 生成一个神经网络, 用这个神经网络预测, 再对这个 genome 打分, 并写入成它的 fitness:
def eval_genomes(genomes, config):
for genome_id, genome in genomes: # for each individual
genome.fitness = 4.0 # 4 xor evaluations
net = neat.nn.FeedForwardNetwork.create(genome, config)
for xi, xo in zip(xor_inputs, xor_outputs):
output = net.activate(xi)
genome.fitness -= (output[0] - xo[0]) ** 2
每一个 neat 的程序里有需要有这样的评分标准. 接着我们创建一个 config 的文件, 用来给定所有运行参数. 这个 config 文件要分开存储, 而且文件里要有一下几个方面的参数预设. 对于每个方面具体的预设值请参考我在 github 中的config-forward这个文件. 对于每个方面的解释, 不太明白的话, 请参考这里
现在我们就能使用这些预设的参数来生成一个 config
的值了 (上面的 eval_genomes
也用到了这个 config
).
local_dir = os.path.dirname(__file__)
config_file = os.path.join(local_dir, 'config-feedforward') # 参数文件
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
neat.DefaultSpeciesSet, neat.DefaultStagnation,
config_file)
有了这个 config
, 我们就能拿它来生成我们整个 population
, 使用这个初始的 p
来训练 300 次, 注意在 config-forward
中我们设置了一个参数 fitness_threshold = 3.9
, 就是说, 只要有任何一个 fitness 达到了 3.9 (最大4), 我们就停止迭代更新 population
. 所以有可能不到 300 次就学好了. 学好之后, 我们输出表现最好的 winner
.
p = neat.Population(config)
winner = p.run(eval_genomes, 300) # 输入计算 fitness 的方式和 generation 的次数
最主要的过程就完啦, 简单吧. 在这个例子脚本中的其他代码都是现实结果的代码, 大家随便看看就知道了.
print('\nOutput:')
winner_net = neat.nn.FeedForwardNetwork.create(winner, config)
for xi, xo in zip(xor_inputs, xor_outputs):
output = winner_net.activate(xi)
print("input {!r}, expected output {!r}, got {!r}".format(xi, xo, output))
我们通过这个来输出最后的 winner
神经网络预测结果, 不出意外, 你应该预测很准. 最后通过 visualize.py
文件的可视化功能, 我们就能生成几个图片, 使用浏览器打开 speciation.svg
看看不同种群的变化趋势, avg_fitness.svg
看看 fitness 的变化曲线, Digraph.gv.svg
看这个生成的神经网络长怎样.
关于最下面的那个神经网络图, 需要说明一下, 如果是实线, 如 B->1, B->2, 说明这个链接是 Enabled 的. 如果是虚线(点线), 如 B->A XOR B 就说明这个链接是 Disabled 的. 红色的线代表 weight <= 0, 绿色的线代表 weight > 0. 线的宽度和 weight 的大小有关.