# 巡回セールスマン問題 (Traveling Salesman Problem) - MCTS

import random
import math
import matplotlib.pyplot as plt
import networkx as nx

class City:
    """A named city located at the planar point (x, y).

    Equality is object identity (no ``__eq__`` is defined), so two City
    objects with the same name and coordinates are still distinct.
    """

    def __init__(self, name, x, y):
        # Label of the city; other code in this file uses it as a dict key.
        self.name = name
        # Planar coordinates of the city.
        self.x, self.y = x, y

class TravelingSalesman:
    """A traveling-salesman instance: a list of cities plus their distance table."""

    def __init__(self, cities):
        # Cities are expected to expose ``name``, ``x`` and ``y`` attributes.
        self.cities = cities
        self.distances = self.calculate_distances()

    def calculate_distances(self):
        """Return the Euclidean distance table keyed by city name.

        ``distances[a][b]`` is the straight-line distance between the cities
        named ``a`` and ``b``. The diagonal ``distances[a][a]`` is 0.0, so
        looking up a degenerate (one-city) closed tour does not raise
        ``KeyError``.

        NOTE(review): city names are assumed to be unique; if a name repeats,
        the later city's row/column silently overwrites the earlier one.
        """
        # math.hypot avoids the intermediate overflow/underflow of squaring.
        return {
            src.name: {
                dst.name: math.hypot(src.x - dst.x, src.y - dst.y)
                for dst in self.cities
            }
            for src in self.cities
        }

class Node:
    """One vertex of the MCTS search tree.

    ``state`` is the city sequence this vertex stands for. ``visits`` and
    ``value`` are running statistics that start at zero and are updated by
    the search (``backpropagate`` in this file adds to both).
    """

    def __init__(self, state, parent=None):
        self.state = state  # sequence of cities on the path to this node
        self.parent = parent  # None marks the root
        self.children = list()  # child Nodes, in creation order
        self.visits = 0  # how many simulations have passed through here
        self.value = 0  # running total of the simulation results

def select(node, C=1.0):
    """Descend from *node* to a leaf with the UCB1 rule and return that leaf.

    ``child.value`` holds the *sum of tour lengths* recorded for a child
    (``simulate`` returns a distance and ``backpropagate`` adds it). The TSP
    minimises distance, so the exploitation term is the negated mean
    length: shorter average tours score higher. ``C`` scales the
    exploration bonus ``sqrt(2 ln N_parent / n_child)``.

    A child that has never been visited is chosen before any visited
    sibling (its score is +inf); this also avoids dividing by zero. Ties go
    to the first child in ``children`` order.
    """

    def score(parent, child):
        # Unvisited children first; also prevents division by zero.
        if child.visits == 0:
            return math.inf
        # Negated mean tour length: lower distance is better for the TSP.
        exploitation = -child.value / child.visits
        # max(..., 1) guards log(0) if a parent has no recorded visits.
        exploration = math.sqrt(
            2.0 * math.log(max(parent.visits, 1)) / child.visits
        )
        return exploitation + C * exploration

    while node.children:
        parent = node
        node = max(parent.children, key=lambda child: score(parent, child))
    return node

def expand(node, available_cities):
    """Attach one new child to *node* and return it.

    A city is drawn uniformly at random from ``available_cities`` and then
    removed from that list in place, so the caller's list shrinks by one.
    The child's state is a new list: ``node.state`` followed by the drawn
    city; ``node.state`` itself is not modified.

    Raises IndexError (from ``random.choice``) when ``available_cities`` is
    empty.
    """
    picked = random.choice(available_cities)
    child_state = node.state + [picked]
    available_cities.remove(picked)  # consume the city from the shared pool
    child = Node(child_state, parent=node)
    node.children.append(child)
    return child

def simulate(node, traveling_salesman):
    """Run one random rollout from *node* and return the closed-tour length.

    The city sequence in ``node.state`` is kept as the fixed prefix of the
    tour. The cities not yet on that prefix (every key of
    ``traveling_salesman.distances`` that is not in it) are appended in a
    random order. The result is the length of this complete tour, including
    the closing leg back to the first city.

    A tour with fewer than two cities has length 0.0. ``node.state`` is not
    modified.
    """
    dist = traveling_salesman.distances
    visited = set(node.state)
    remaining = [city for city in dist if city not in visited]
    random.shuffle(remaining)  # uniform random completion = rollout policy
    tour = list(node.state) + remaining
    if len(tour) < 2:
        return 0.0
    # Pair each city with its successor, wrapping the last back to the first.
    return sum(dist[a][b] for a, b in zip(tour, tour[1:] + tour[:1]))

def backpropagate(node, value):
    """Record one simulation result on *node* and every ancestor up to the root.

    Each node on the way up gets ``visits`` incremented by one and ``value``
    increased by *value*. The walk follows ``parent`` links and stops at the
    first falsy parent (the root's ``None``).
    """
    current = node
    while current:
        current.visits, current.value = current.visits + 1, current.value + value
        current = current.parent

def visualize_tree(root):
    """Draw the search tree rooted at *root* with networkx and matplotlib.

    Every tree node becomes a graph vertex labelled with its city sequence
    (``node.state`` joined by ", "). A directed edge runs from each parent
    to each of its children. Vertices are numbered in breadth-first order
    starting at 0 for the root. The drawing is displayed with ``plt.show()``.
    """
    from collections import deque

    G = nx.DiGraph()
    node_ids = {}  # Node object -> integer vertex id

    def vertex_id(node):
        # Assign ids lazily, in order of first visit.
        if node not in node_ids:
            node_ids[node] = len(node_ids)
        return node_ids[node]

    # deque gives O(1) pops from the left, unlike list.pop(0).
    queue = deque([(None, root)])
    while queue:
        parent, node = queue.popleft()
        node_id = vertex_id(node)
        G.add_node(node_id, label=", ".join(map(str, node.state)))
        if parent is not None:
            G.add_edge(vertex_id(parent), node_id)
        queue.extend((node, child) for child in node.children)

    pos = nx.spring_layout(G, seed=42)
    labels = nx.get_node_attributes(G, 'label')
    nx.draw(G, pos, labels=labels, with_labels=True, node_size=2000, node_color='lightblue', font_size=10)
    plt.show()

if __name__ == "__main__":
    cities = [
        City("A", 0, 0),
        City("B", 1, 2),
        City("C", 3, 1),
        City("D", 2, 4),
        City("E", 4, 3),
    ]
    iterations = 10
    initial_state = ["A"]  # every tour starts from city A

    root = Node(initial_state)
    traveling_salesman = TravelingSalesman(cities)
    city_names = [city.name for city in cities]

    for _ in range(iterations):
        node_to_expand = select(root)
        # Candidates are the cities not yet on the selected path. A fresh
        # list each iteration keeps expand()'s in-place removal local,
        # instead of draining one global pool shared by the whole tree.
        available_cities = [
            name for name in city_names if name not in node_to_expand.state
        ]
        if not available_cities:
            # The selected leaf is already a complete tour. Re-running
            # select() on an unchanged tree returns the same leaf, so the
            # remaining iterations could not make progress: stop here.
            break
        expanded_node = expand(node_to_expand, available_cities)
        value = simulate(expanded_node, traveling_salesman)
        backpropagate(expanded_node, value)
        visualize_tree(root)  # show the search tree after each expansion

# 転載 (reprinted from): blog.csdn.net/weixin_40860393/article/details/133326082