#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import math
import random
import numpy as np

# The discrete actions available at every round of the toy game.
AVAILABLE_CHOICES = [1, -1, 2, -2]
AVAILABLE_CHOICE_NUMBER = len(AVAILABLE_CHOICES)
# A game lasts exactly this many rounds.
MAX_ROUND_NUMBER = 10


class State(object):
    """Game state attached to a Monte-Carlo tree-search node.

    Records the current cumulative score, the current round index and the
    full sequence of choices taken from the start. Knows whether the game
    has reached its terminal round and how to step to a random successor
    state drawn from the action set.
    """

    def __init__(self):
        self.current_value = 0.0
        # For the first root node the index is 0; the game starts from round 1.
        self.current_round_index = 0
        self.cumulative_choices = []

    def get_current_value(self):
        return self.current_value

    def set_current_value(self, value):
        self.current_value = value

    def get_current_round_index(self):
        return self.current_round_index

    def set_current_round_index(self, turn):
        self.current_round_index = turn

    def get_cumulative_choices(self):
        return self.cumulative_choices

    def set_cumulative_choices(self, choices):
        self.cumulative_choices = choices

    def is_terminal(self):
        # The round index runs from 1 up to MAX_ROUND_NUMBER.
        return self.current_round_index == MAX_ROUND_NUMBER

    def compute_reward(self):
        # Reward peaks at 0 when the cumulative value is exactly 1.
        return -abs(1 - self.current_value)

    def get_next_state_with_random_choice(self):
        """Return a new State reached by applying one uniformly random action."""
        random_choice = random.choice([choice for choice in AVAILABLE_CHOICES])
        next_state = State()
        next_state.set_current_value(self.current_value + random_choice)
        next_state.set_current_round_index(self.current_round_index + 1)
        next_state.set_cumulative_choices(self.cumulative_choices + [random_choice])
        return next_state

    def __eq__(self, other):
        # Value equality keyed on the action history. Without this, the
        # membership test in expand() (`new_state in tried_sub_node_states`)
        # falls back to identity and never detects a duplicate action.
        return (isinstance(other, State)
                and self.cumulative_choices == other.cumulative_choices)

    def __hash__(self):
        # Kept consistent with __eq__ (equal states hash equal).
        return hash(tuple(self.cumulative_choices))

    def __repr__(self):
        return "State: {}, value: {}, round: {}, choices: {}".format(
            hash(self), self.current_value, self.current_round_index,
            self.cumulative_choices)


class Node(object):
    """Node of the Monte-Carlo search tree.

    Holds the parent/children links, the visit count and quality value used
    by the UCB formula, and the game State this node represents.
    """

    def __init__(self):
        self.parent = None
        self.children = []
        self.visit_times = 0
        self.quality_value = 0.0
        self.state = None

    def set_state(self, state):
        self.state = state

    def get_state(self):
        return self.state

    def get_parent(self):
        return self.parent

    def set_parent(self, parent):
        self.parent = parent

    def get_children(self):
        return self.children

    def get_visit_times(self):
        return self.visit_times

    def set_visit_times(self, times):
        self.visit_times = times

    def visit_times_add_one(self):
        self.visit_times += 1

    def get_quality_value(self):
        return self.quality_value

    def set_quality_value(self, value):
        self.quality_value = value

    def quality_value_add_n(self, n):
        self.quality_value += n

    def is_all_expand(self):
        # Fully expanded once every available action has a child.
        return len(self.children) == AVAILABLE_CHOICE_NUMBER

    def add_child(self, sub_node):
        sub_node.set_parent(self)
        self.children.append(sub_node)

    def __repr__(self):
        return "Node: {}, Q/N: {}/{}, state: {}".format(
            hash(self), self.quality_value, self.visit_times, self.state)


def tree_policy(node):
    """Selection and Expansion phases of MCTS.

    Starting from `node` (e.g. the root), descend via the UCB-best child
    while the current node is fully expanded; as soon as a node with
    untried actions is reached, expand it and return the new child.
    A terminal node is returned unchanged.

    Tie/selection policy: unexpanded children are chosen by expand()'s
    random draw; fully expanded nodes defer to best_child()'s UCB score.
    """
    # Descend until we hit a terminal state.
    while node.get_state().is_terminal() == False:
        if node.is_all_expand():
            node = best_child(node, True)
        else:
            # Return the newly expanded sub node.
            sub_node = expand(node)
            return sub_node
    # Return the leaf (terminal) node.
    return node
def default_policy(node):
    """Simulation phase of MCTS.

    From the given node's state, play uniformly random actions until the
    game reaches its terminal round, then return the reward of that final
    state. The input node is expected to be a freshly expanded one.
    """
    # Roll out with random actions until the game is over.
    state = node.get_state()
    while not state.is_terminal():
        state = state.get_next_state_with_random_choice()
    # Reward of the terminal state reached by this rollout.
    return state.compute_reward()
def expand(node):
    """Expand `node` with one new child whose action differs from all
    existing children, and return that child.

    Bug fix: the original membership test (`new_state in tried_sub_node_states`)
    compared State objects, which use identity equality by default, so it
    never detected a duplicate and the same action could be expanded twice.
    We compare the recorded choice sequences instead, which works regardless
    of State's equality semantics.

    NOTE(review): the caller must ensure the node is not already fully
    expanded, otherwise the re-draw loop cannot terminate.
    """
    tried_choice_lists = [
        sub_node.get_state().get_cumulative_choices()
        for sub_node in node.get_children()]
    new_state = node.get_state().get_next_state_with_random_choice()
    # Re-draw until the action history differs from every tried child.
    while new_state.get_cumulative_choices() in tried_choice_lists:
        new_state = node.get_state().get_next_state_with_random_choice()
    sub_node = Node()
    sub_node.set_state(new_state)
    node.add_child(sub_node)
    return sub_node
def best_child(node, is_exploration):
    """Return the child of `node` with the highest UCB score.

    UCB = Q/N + C * sqrt(2 * ln(N_parent) / N). During search
    (`is_exploration=True`) C = 1/sqrt(2); for the final move selection
    C = 0, i.e. pure exploitation. Ties keep the first-seen child; returns
    None when `node` has no children.
    """
    # Hoisted out of the loop: C does not depend on the child.
    # Closes the old TODO: float("-inf") instead of -sys.maxsize.
    C = 1 / math.sqrt(2.0) if is_exploration else 0.0
    best_score = float("-inf")
    best_sub_node = None
    # Visit every child to find the best one.
    for sub_node in node.get_children():
        # Exploitation term: average reward of the child.
        left = sub_node.get_quality_value() / sub_node.get_visit_times()
        # Exploration term: grows for rarely visited children.
        right = 2.0 * math.log(node.get_visit_times()) / sub_node.get_visit_times()
        score = left + C * math.sqrt(right)
        if score > best_score:
            best_sub_node = sub_node
            best_score = score
    return best_sub_node
def backup(node, reward):
    """Backpropagation phase of MCTS.

    Walks from the expanded node up to the root, incrementing each node's
    visit count and accumulating the rollout reward into its quality value.
    """
    current = node
    # Climb parent links until we step past the root.
    while current is not None:
        current.visit_times_add_one()
        current.quality_value_add_n(reward)
        current = current.parent
def monte_carlo_tree_search(node):
    """Run the full MCTS loop from `node` and return the best next node.

    Each iteration performs the four classic phases: Selection + Expansion
    (tree_policy), Simulation (default_policy) and Backpropagation (backup).
    Once the computation budget is exhausted, the child with the highest
    pure-exploitation score (C = 0) is returned as the chosen move.
    """
    computation_budget = 1000
    # Spend the whole budget growing and evaluating the tree.
    for _ in range(computation_budget):
        # 1. Select the most promising node to expand.
        expand_node = tree_policy(node)
        # 2. Simulate a random playout from it and collect the reward.
        reward = default_policy(expand_node)
        # 3. Propagate the reward back up to the root.
        backup(expand_node, reward)
    # 4. Pick the child that exploits best (no exploration bonus).
    return best_child(node, False)
def main():
    """Play a full game, running a fresh MCTS from each chosen node."""
    # Create the initial state and wrap it in the root node.
    init_state = State()
    init_node = Node()
    init_node.set_state(init_state)
    current_node = init_node
    # Renamed from `round`: the original shadowed the builtin round().
    rounds = 10
    for i in range(rounds):
        print("Play round: {}".format(i + 1))
        current_node = monte_carlo_tree_search(current_node)
        print("Choose node: {}".format(current_node))


if __name__ == "__main__":
    main()
Result (sample program output):
Play round:1
Choose node: Node:-9223371902452567720, Q/N:-453.0/846, state: State:-9223371902452565539, value:-1.0,round:1, choices:[-1]
Play round:2
Choose node: Node:134402247183, Q/N:-424.0/1839, state: State:-9223371902452528622, value:1.0,round:2, choices:[-1,2]
Play round:3
Choose node: Node:-9223371902452527258, Q/N:-290.0/2802, state: State:134402248547, value:0.0,round:3, choices:[-1,2,-1]
Play round:4
Choose node: Node:-9223371902452525999, Q/N:-238.0/3787, state: State:134402249806, value:2.0,round:4, choices:[-1,2,-1,2]
Play round:5
Choose node: Node:134402253509, Q/N:-203.0/4776, state: State:-9223371902452522303, value:0.0,round:5, choices:[-1,2,-1,2,-2]
Play round:6
Choose node: Node:134402253523, Q/N:-150.0/5757, state: State:-9223371902452522289, value:1.0,round:6, choices:[-1,2,-1,2,-2,1]
Play round:7
Choose node: Node:-9223371902452521093, Q/N:-70.0/6718, state: State:-9223371902452521100, value:0.0,round:7, choices:[-1,2,-1,2,-2,1,-1]
Play round:8
Choose node: Node:-9223371902452520262, Q/N:-47.0/7709, state: State:134402255550, value:1.0,round:8, choices:[-1,2,-1,2,-2,1,-1,1]
Play round:9
Choose node: Node:134402256974, Q/N:-21.0/8695, state: State:-9223371902452518838, value:2.0,round:9, choices:[-1,2,-1,2,-2,1,-1,1,1]
Play round:10
Choose node: Node:134402257044, Q/N:0.0/9679, state: State:-9223371902452518768, value:1.0,round:10, choices:[-1,2,-1,2,-2,1,-1,1,1,-1]
Process finished with exit code 0