人工智能----A*(星)算法之八数码问题


前言

本文将介绍A*算法在八数码问题里的应用,包括:

  1. 将牌为估价函数的A*算法;
  2. 曼哈顿算法;
  3. 宽度优先搜索算法
    这里将着重介绍将牌为估价函数的A*算法,2和3将直接给出完整可运行的代码。

一、什么是将牌?

观察下面两个矩阵:

初始状态

[[2, 8, 3],
[1, 6, 4],
[7, 0, 5]]

目标状态

[[1, 2, 3],
[8, 0, 4],
[7, 6, 5]]
对于初始状态的每一个元素,都去目标状态的相同位置进行对比,看两元素是否相同,不同则计数加1,对比全部结束后,这个计数就是将牌。即将牌就是未到达指定位置元素的个数。

二、编程步骤

1.状态空间图的定义


#本实验的状态空间图是一个3*3的九宫格,里面含有1-8的数字,
#空格用0来表示,由此,状态空间图可以使用数组来表示(这里用numpy表示)
import numpy as np 
start_data = np.array([[2, 8, 3], [1, 6, 4], [7, 0, 5]])#初始状态end_data = np.array([[1, 2, 3], [8, 0, 4], [7, 6, 5]])#目标状态

2.各种操作的定义

(2)	各种操作的定义
#实际中的九宫格移动是涉及到1-8八个数字的上下左右的移动的,
#但我们可以将之巧妙地转化成数字0与待移动数子的位置交换,
#因此只需要定义数子0的位置移动即可。移动数字前首先判断是否可以移动。
#这里举例右移操作:
def swap(num_data,direction):
    x,y = find_zero(num_data)#空格(0)的位置
    num = np.copy(num_data)#当前状态
    if direction == 'right':
        if y == 2:
            print('不能右移')
            return num
        num[x][y] = num[x][y+1]
        num[x][y+1] = 0
        return num

3.A*算法框架的定义

import queue 
#步骤1.定义Node类 
#步骤2.定义opened表和closed表,创建初始节点,并且加入opened表中
def method_a_function(): 
while len(opened.queue) != 0: 
# 步骤3.取队首元素节点 
# 步骤4.判断当前节点的data是否为目标值,是则返回正确值
# 步骤5.将取出的节点加入closed表中 
# 步骤6.对当前节点data中的空格0执行左、右、上、下移动的四个操作,创建当前节点的一切后继子节点,并对opened表进行更新
 for action in ['left', 'right', 'up', 'down']: 
# 创建子节点 
# 判断是否在closed表中,如果不在close表中,将其加入opened表(并且考虑到与opened表中已有元素重复的更新情况) 
#步骤7. 以节点f_loss值,为opened表进行排序(f_loss从小到大)
#步骤8.将初始状态到目标状态所经过的所有节点进行输出
其中method_a_function函数定义:
 def creat_sort_update():
    con = 0#轮数
    extern = 0#拿来进行扩展的节点数
    generate = 0#扩展出来的节点数
    while len(opened.queue) != 0:
        #步骤3.取队首元素节点
        node = opened.get()
        #步骤4.判断当前节点的data是否为目标值,是则返回正确值
        if(node.data == end_data).all():
            print(f'总共耗费{
      
      con}轮')
            print(f'扩展子节点数为{
      
      extern}')
            print(f'生成子节点数为{
      
      generate}')
            return node
        #步骤5.将取出的节点加入closed表中
        closed[data_to_int(node.data)] = 1
        extern += 1
        #步骤6.对当前节点data中的空格0执行各种操作,创建所有后继子节点
        for action in ['left', 'right', 'up', 'down']:
            #创建子节点
            child_node = Node(swap(node.data, action), node.step + 1, node)
            generate += 1
            index = data_to_int(child_node.data)
            #判断是否在closed表中,如果不在closed表中,将其加入opened表(并且考虑与opened表中已有元素重复的更新情况)
            if index not in closed:
                refresh_opened(child_node)
        #步骤7.根据节点的f_loss值,为opened表进行排序(f_close值从小到大)
        sorted_by_floss()
        con += 1
    print("该问题无解")

sorted_by_floss函数的定义:
(本质是冒泡排序法)
def sorted_by_floss():
    tmp_opened = opened.queue.copy()
    length = len(tmp_opened)
    #排序,从小到大,当一样的时候按照step的大小排序
    for i in range(length):
        for j in range(length):
            if tmp_opened[i].f_loss < tmp_opened[j].f_loss:
                tmp = tmp_opened[i]
                tmp_opened[i] = tmp_opened[j]
                tmp_opened[j] = tmp
            if tmp_opened[i].f_loss == tmp_opened[j].f_loss:
                if tmp_opened[i].step > tmp_opened[j].step:
                    tmp = tmp_opened[i]
                    tmp_opened[i] = tmp_opened[j]
                    tmp_opened[j] = tmp
    opened.queue = tmp_opened

refresh_oprened函数定义:
def refresh_opened(now_node):
    tmp_opened = opened.queue.copy()
    for i in range(len(tmp_opened)):
        data = tmp_opened[i]
        now_data = now_node.data
        if(data == now_data).all():
            data_f_loss = tmp_opened[i].f_loss
            now_data_f_loss = now_node.f_loss
            if data_f_loss <= now_data_f_loss:
                return False
            else:
                print('')
                tmp_opened[i] = now_node
                opened.queue = tmp_opened#更新之后的opened表还原
                return True
    #能走到这一步说明扩展的子节点是全新的
    tmp_opened.append(now_node)
    opened.queue = tmp_opened #更新之后的opened表还原
    return True

三、3种算法的完整代码

1.将牌算法

#!/usr/bin/env python
# coding: utf-8

# In[41]:


import numpy as np


# In[42]:


#用数组表示状态空间
start_data = np.array([[2,8,3],[1,6,4],[7,0,5]])#初始状态
end_data = np.array([[1,2,3],[8,0,4],[7,6,5]])#目标状态


# In[43]:


#定义各种操作
def find_zero(num):
    tmp_x, tmp_y = np.where(num == 0)
    return tmp_x[0], tmp_y[0]
def swap(num_data,direction):
    x,y = find_zero(num_data)#空格(0)的位置
    num = np.copy(num_data)#当前状态
    if direction == 'left':
        if y == 0:
            #print('不能左移')
            return num
        num[x][y] = num[x][y-1]
        num[x][y-1] = 0
        return num
    if direction == 'right':
        if y == 2:
            #print('不能右移')
            return num
        num[x][y] = num[x][y+1]
        num[x][y+1] = 0
        return num
    if direction == 'up':
        if x == 0:
            #print('不能上移')
            return num
        num[x][y] = num[x-1][y]
        num[x-1][y] = 0
        return num
    if direction == 'down':
        if x == 2:
            #print('不能下移')
            return num
        num[x][y] = num[x+1][y]
        num[x+1][y] = 0
        return num
    else:
        print('输入的指令不属于{left, right, up, down}')


# In[44]:


#A*算法定义
import queue
#定义Node类
class Node:
    f_loss = -1#启发值
    step = 0   #初始状态到当前状态的距离(步数)
    parent = None#父节点
    #步骤1.用状态和步数构造节点对象
    def __init__(self, data, step, parent):
        self.data = data
        self.step = step
        self.parent = parent
        #计算f(n)的值
        self.f_loss = cal_wcost(data) + step
        #计算w(n)的值
def cal_wcost(num):
    #计算w(n)的值,以及放错元素的个数
    #param num:要比较的数组的值
    #return:返回w(n)的值

    #先用老师提供的将牌计算w(n)的方法
    con = 0
    for i in range(3):
        for j in range(3):
            tmp_num = num[i][j]
            compare_num = end_data[i][j]
            if tmp_num != 0:
                con += tmp_num != compare_num
    return con


# In[46]:


#步骤2.定义opened表和closed表,创建初始节点,并且加入opened表中
def data_to_int(num):
    value = 0
    for i in num:
        for j in i:
            value = value * 10 + j
    return value

opened = queue.Queue()#opened表
start_node = Node(start_data, 0, None)
opened.put(start_node)
closed = {
    
    }#close表
def refresh_opened(now_node):
    tmp_opened = opened.queue.copy()
    for i in range(len(tmp_opened)):
        data = tmp_opened[i]
        now_data = now_node.data
        if(data == now_data).all():
            data_f_loss = tmp_opened[i].f_loss
            now_data_f_loss = now_node.f_loss
            if data_f_loss <= now_data_f_loss:
                return False
            else:
                print('')
                tmp_opened[i] = now_node
                opened.queue = tmp_opened#更新之后的opened表还原
                return True
    #能走到这一步说明扩展的子节点是全新的
    tmp_opened.append(now_node)
    opened.queue = tmp_opened #更新之后的opened表还原
    return True
def sorted_by_floss():
    tmp_opened = opened.queue.copy()
    length = len(tmp_opened)
    #排序,从小到大,当一样的时候按照step的大小排序
    for i in range(length):
        for j in range(length):
            if tmp_opened[i].f_loss < tmp_opened[j].f_loss:
                tmp = tmp_opened[i]
                tmp_opened[i] = tmp_opened[j]
                tmp_opened[j] = tmp
            if tmp_opened[i].f_loss == tmp_opened[j].f_loss:
                if tmp_opened[i].step > tmp_opened[j].step:
                    tmp = tmp_opened[i]
                    tmp_opened[i] = tmp_opened[j]
                    tmp_opened[j] = tmp
    opened.queue = tmp_opened

def creat_sort_update():
    con = 0
    extern = 0
    generate = 0
    while len(opened.queue) != 0:
        #步骤3.取队首元素节点
        node = opened.get()
        extern += 1
        #步骤4.判断当前节点的data是否为目标值,是则返回正确值
        if(node.data == end_data).all():
            print(f'总共耗费{
      
      con}轮')
            print(f'扩展子节点数为{
      
      extern}')
            print(f'生成子节点数为{
      
      generate}')
            return node
        #步骤5.将取出的节点加入closed表中 
        closed[data_to_int(node.data)] = 1
        #步骤6.对当前节点data中的空格0执行各种操作,创建所有后继子节点
        for action in ['left', 'right', 'up', 'down']:
            #创建子节点
            child_node = Node(swap(node.data, action), node.step + 1, node)
            generate += 1
            index = data_to_int(child_node.data)
            #判断是否在closed表中,如果不在closed表中,将其加入opened表(并且考虑与opened表中已有元素重复的更新情况)
            if index not in closed:
                refresh_opened(child_node)
        #步骤7.根据节点的f_loss值,为opened表进行排序(f_close值从小到大)
        sorted_by_floss()
        con += 1
    print("该问题无解")
        #opened表排序函数见上
#步骤8.将初始状态到目标状态所经过的所有节点进行输出
result_node = creat_sort_update()
import prettytable as pt
#获取路径中所有节点的函数,依次获取目标节点的父节点,形成一条正确顺序的路径,然后使用循环将这条路径输出
def output_result(node):
    all_node = [node]
    for i in range(node.step):
        father_node = node.parent
        all_node.append(father_node)
        node = father_node
    return reversed(all_node)

node_list = list(output_result(result_node))
tb = pt.PrettyTable()
tb.field_names = ['step', 'data', 'f_loss']
for node in node_list:
    num = node.data
    tb.add_row([node.step, num, node.f_loss])
    if node != node_list[-1]:
        tb.add_row(['---', '--------', '---'])
print(tb)

2.曼哈顿算法

#!/usr/bin/env python
# coding: utf-8
import numpy as np

#用数组表示状态空间
start_data = np.array([[2,8,3],[1,6,4],[7,0,5]])#初始状态
end_data = np.array([[1,2,3],[8,0,4],[7,6,5]])#目标状态
#定义各种操作
def find_zero(num):
    tmp_x, tmp_y = np.where(num == 0)
    return tmp_x[0], tmp_y[0]
def swap(num_data,direction):
    x,y = find_zero(num_data)#空格(0)的位置
    num = np.copy(num_data)#当前状态
    if direction == 'left':
        if y == 0:
            #print('不能左移')
            return num
        num[x][y] = num[x][y-1]
        num[x][y-1] = 0
        return num
    if direction == 'right':
        if y == 2:
            #print('不能右移')
            return num
        num[x][y] = num[x][y+1]
        num[x][y+1] = 0
        return num
    if direction == 'up':
        if x == 0:
            #print('不能上移')
            return num
        num[x][y] = num[x-1][y]
        num[x-1][y] = 0
        return num
    if direction == 'down':
        if x == 2:
            #print('不能下移')
            return num
        num[x][y] = num[x+1][y]
        num[x+1][y] = 0
        return num
    else:
        print('输入的指令不属于{left, right, up, down}')
#A*算法定义
import queue
#定义Node类
class Node:
    f_loss = -1#启发值
    step = 0   #初始状态到当前状态的距离(步数)
    parent = None#父节点
    #步骤1.用状态和步数构造节点对象
    def __init__(self, data, step, parent):
        self.data = data
        self.step = step
        self.parent = parent
        #计算f(n)的值
        self.f_loss = cal_wcost(data) + step
        #计算w(n)的值

def find_other(end_data, tmp_num):
    for i in range(3):
        for j in range(3):
            if end_data[i][j] == tmp_num:
                return i, j

def cal_wcost(num):
    #计算w(n)的值,以及放错元素的个数
    #param num:要比较的数组的值
    #return:返回w(n)的值
    con = 0
    for i in range(3):
        for j in range(3):
            tmp_num = num[i][j]
            if tmp_num != 0:
                tmp_x, tmp_y = find_other(end_data, tmp_num)
                con += abs(i - tmp_x) + abs(j - tmp_y)
    return con
#步骤2.定义opened表和closed表,创建初始节点,并且加入opened表中
def data_to_int(num):
    value = 0
    for i in num:
        for j in i:
            value = value * 10 + j
    return value

opened = queue.Queue()#opened表
start_node = Node(start_data, 0, None)
opened.put(start_node)
closed = {
    
    }#close表
def refresh_opened(now_node):
    tmp_opened = opened.queue.copy()
    for i in range(len(tmp_opened)):
        data = tmp_opened[i]
        now_data = now_node.data
        if(data == now_data).all():
            data_f_loss = tmp_opened[i].f_loss
            now_data_f_loss = now_node.f_loss
            if data_f_loss <= now_data_f_loss:
                return False
            else:
                print('')
                tmp_opened[i] = now_node
                opened.queue = tmp_opened#更新之后的opened表还原
                return True
    #能走到这一步说明扩展的子节点是全新的
    tmp_opened.append(now_node)
    opened.queue = tmp_opened #更新之后的opened表还原
    return True
def sorted_by_floss():
    tmp_opened = opened.queue.copy()
    length = len(tmp_opened)
    #排序,从小到大,当一样的时候按照step的大小排序
    for i in range(length):
        for j in range(length):
            if tmp_opened[i].f_loss < tmp_opened[j].f_loss:
                tmp = tmp_opened[i]
                tmp_opened[i] = tmp_opened[j]
                tmp_opened[j] = tmp
            if tmp_opened[i].f_loss == tmp_opened[j].f_loss:
                if tmp_opened[i].step > tmp_opened[j].step:
                    tmp = tmp_opened[i]
                    tmp_opened[i] = tmp_opened[j]
                    tmp_opened[j] = tmp
    opened.queue = tmp_opened

def creat_sort_update():
    con = 0
    extern = 0
    generate = 0
    while len(opened.queue) != 0:
        #步骤3.取队首元素节点
        node = opened.get()
        extern += 1
        #步骤4.判断当前节点的data是否为目标值,是则返回正确值
        if(node.data == end_data).all():
            print(f'总共耗费{
      
      con}轮')
            print(f'扩展子节点数为{
      
      extern}')
            print(f'生成子节点数为{
      
      generate}')
            return node
        #步骤5.将取出的节点加入closed表中 
        closed[data_to_int(node.data)] = 1
        #步骤6.对当前节点data中的空格0执行各种操作,创建所有后继子节点
        for action in ['left', 'right', 'up', 'down']:
            #创建子节点
            child_node = Node(swap(node.data, action), node.step + 1, node)
            generate += 1
            index = data_to_int(child_node.data)
            #判断是否在closed表中,如果不在closed表中,将其加入opened表(并且考虑与opened表中已有元素重复的更新情况)
            if index not in closed:
                refresh_opened(child_node)
        #步骤7.根据节点的f_loss值,为opened表进行排序(f_close值从小到大)
        sorted_by_floss()
        con += 1
    print("该问题无解")
        #opened表排序函数见上
#步骤8.将初始状态到目标状态所经过的所有节点进行输出
result_node = creat_sort_update()
import prettytable as pt
#获取路径中所有节点的函数,依次获取目标节点的父节点,形成一条正确顺序的路径,然后使用循环将这条路径输出
def output_result(node):
    all_node = [node]
    for i in range(node.step):
        father_node = node.parent
        all_node.append(father_node)
        node = father_node
    return reversed(all_node)

node_list = list(output_result(result_node))
tb = pt.PrettyTable()
tb.field_names = ['step', 'data', 'f_loss']
for node in node_list:
    num = node.data
    tb.add_row([node.step, num, node.f_loss])
    if node != node_list[-1]:
        tb.add_row(['---', '--------', '---'])
print(tb)

3.广度优先搜索

#!/usr/bin/env python
# coding: utf-8

import numpy as np
#用数组表示状态空间
start_data = np.array([[2,8,3],[1,6,4],[7,0,5]])#初始状态
end_data = np.array([[1,2,3],[8,0,4],[7,6,5]])#目标状态


# In[56]:


#定义各种操作
def find_zero(num):
    tmp_x, tmp_y = np.where(num == 0)
    return tmp_x[0], tmp_y[0]
def swap(num_data,direction):
    x,y = find_zero(num_data)#空格(0)的位置
    num = np.copy(num_data)#当前状态
    if direction == 'left':
        if y == 0:
            #print('不能左移')
            return num
        num[x][y] = num[x][y-1]
        num[x][y-1] = 0
        return num
    if direction == 'right':
        if y == 2:
            #print('不能右移')
            return num
        num[x][y] = num[x][y+1]
        num[x][y+1] = 0
        return num
    if direction == 'up':
        if x == 0:
            #print('不能上移')
            return num
        num[x][y] = num[x-1][y]
        num[x-1][y] = 0
        return num
    if direction == 'down':
        if x == 2:
            #print('不能下移')
            return num
        num[x][y] = num[x+1][y]
        num[x+1][y] = 0
        return num
    else:
        print('输入的指令不属于{left, right, up, down}')


# In[57]:


#A*算法定义
import queue
#定义Node类
class Node:
    f_loss = -1#启发值
    step = 0   #初始状态到当前状态的距离(步数)
    parent = None#父节点
    #步骤1.用状态和步数构造节点对象
    
    def __init__(self, data, step, parent):
        self.data = data
        self.step = step
        self.parent = parent
        #计算f(n)的值
        self.f_loss = cal_wcost(data) + step
        #计算w(n)的值
def find_other(end_data, tmp_num):
    for i in range(3):
        for j in range(3):
            if end_data[i][j] == tmp_num:
                return i, j


# In[58]:


#步骤2.定义opened表和closed表,创建初始节点,并且加入opened表中
def data_to_int(num):
    value = 0
    for i in num:
        for j in i:
            value = value * 10 + j
    return value

opened = queue.Queue()#opened表
start_node = Node(start_data, 0, None)
opened.put(start_node)
closed = {
    
    }#close表
def refresh_opened(now_node):
    tmp_opened = opened.queue.copy()
    for i in range(len(tmp_opened)):
        data = tmp_opened[i]
        now_data = now_node.data
        if(data == now_data).all():
            return False
    #能走到这一步说明扩展的子节点是全新的
    tmp_opened.append(now_node)
    opened.queue = tmp_opened #更新之后的opened表还原
    return True

def creat_sort_update():
    con = 0
    generate = 0
    extern = 0
    while len(opened.queue) != 0:
        #步骤3.取队首元素节点
        node = opened.get()
        extern += 1
        #步骤4.判断当前节点的data是否为目标值,是则返回正确值
        if(node.data == end_data).all():
            print(f'总共耗费{
      
      con}轮')
            print(f'扩展子节点数为{
      
      extern}')
            print(f'生成子节点数为{
      
      generate}')
            return node
        #步骤5.将取出的节点加入closed表中 
        closed[data_to_int(node.data)] = 1
        #步骤6.对当前节点data中的空格0执行各种操作,创建所有后继子节点
        for action in ['left', 'right', 'up', 'down']:
            #创建子节点
            child_node = Node(swap(node.data, action), node.step + 1, node)
            generate += 1
            index = data_to_int(child_node.data)
            #判断是否在closed表中,如果不在closed表中,将其加入opened表(并且考虑与opened表中已有元素重复的更新情况)
            if index not in closed:
                refresh_opened(child_node)
        con += 1
    print("该问题无解")
        #opened表排序函数见上
#步骤8.将初始状态到目标状态所经过的所有节点进行输出
result_node = creat_sort_update()
import prettytable as pt
#获取路径中所有节点的函数,依次获取目标节点的父节点,形成一条正确顺序的路径,然后使用循环将这条路径输出
def output_result(node):
    all_node = [node]
    for i in range(node.step):
        father_node = node.parent
        all_node.append(father_node)
        node = father_node
    return reversed(all_node)

node_list = list(output_result(result_node))
tb = pt.PrettyTable()
tb.field_names = ['step', 'data', 'f_loss']
for node in node_list:
    num = node.data
    tb.add_row([node.step, num, node.f_loss])
    if node != node_list[-1]:
        tb.add_row(['---', '--------', '---'])
print(tb)

四、实验总结

  1. 编程过程中Python的报错大多还是缩进问题,但不是那种多一个或者少一个空格的缩进错误,而是for循环的嵌套错误,return语句的位置错误,这要求我在编程时要格外注意逻辑问题,因为程序很可能陷入永久循环或一步提前退出循环
  2. 估价函数在程序效率问题上起着决定性的作用,没有估价函数的宽度优先搜索带来的冗余搜索数量是惊人的,而不同估价函数之间的效率仍然存在差异,要求我们进一步研究新的估价函数以提升效率。
  3. 估价函数的优劣,本质上是看该算法是否能让每次考察的节点的矩阵都是当下最接近目标矩阵的,例如将牌算法中的排序并没有做到这一点,因此该算法相比于曼哈算法稍显粗糙。按此思路,若能证明某一个估价函数让每次考察的节点的矩阵都是当下最接近目标矩阵的节点,则说明估价函数算法带来的效率提升已经到达了顶峰。之后的效率提升,可能要依靠调整节点每次扩展时的扩展顺序也就是数字0的移动顺序了。

猜你喜欢

转载自blog.csdn.net/qq_50313560/article/details/124811409