通过八数码难题来解释A*算法。
问题描述: 用一个3X3的方格阵来表示该问题的一个状态,每格放置1-8的一个数字,剩下一个空格(用0表示)。只能通过数字(或空格)的移动来改变方格阵的状态。
要求:根据给定初始布局(即初始状态)和目标布局(即目标状态),如何移动数字才能从初始布局到达目标布局,找到合法的走步序列。
A*算法是一种有序搜索算法,其特点在于对估价函数的定义上。 估价函数f(n) = g(n) + h(n)。f(n)表示从起始节点S到节点n的最小代价g(n)与从节点n到目标节点的最小代价h(n)之和。将h(n)定义为所有棋子距离目标位置的曼哈顿距离(与目标位置的水平距离和垂直距离之和)之和。 open表中具有最小f值的节点,就是下一步要扩展的节点。(注意:下面给出的代码中,setH实际统计的是与目标布局不一致的格子个数(包含空格),并不是曼哈顿距离,后面运行过程中的h值也是按这种方式算出来的。)
/**
* This file is part of algorithm of Grid9.
*/
package algorithm;
import java.util.ArrayList;
import java.util.Comparator;
/**
 * Directions in which the blank cell of the puzzle grid can be shifted.
 * The declaration order (up, right, down, left) is kept as-is.
 */
enum MoveAction {
    /** Shift the blank cell one row up. */
    up,
    /** Shift the blank cell one column to the right. */
    right,
    /** Shift the blank cell one row down. */
    down,
    /** Shift the blank cell one column to the left. */
    left;
}
/**
 * A* search over sliding-tile puzzle states (the blank cell is stored as 0).
 *
 * <p>Every expanded node is printed to standard output together with its g, h and f values.
 * The search keeps two tables: {@code open} (generated, not yet expanded) and
 * {@code closed} (already expanded).
 */
public class GridNine {
    /** Nodes that have been generated but not expanded yet. */
    private static ArrayList<TNode> open;
    /** Nodes that have already been expanded. */
    private static ArrayList<TNode> closed;
    /** Order in which the blank cell is moved when a node is expanded. */
    private static final MoveAction[] EXPANSION_ORDER = {
        MoveAction.up, MoveAction.right, MoveAction.down, MoveAction.left
    };

    /**
     * Runs the solver on a fixed 3x3 start/target pair.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int[][] start = {{2, 8, 0}, {1, 6, 3}, {7, 5, 4}};
        int[][] target = {{1, 2, 3}, {8, 0, 4}, {7, 6, 5}};
        AStar(start, target, 3, 3);
    }

    /**
     * Searches from {@code start} towards {@code target}, always expanding the open node
     * with the smallest f value, and prints each expanded node.
     * Prints "success" when the target layout is expanded and "failed" when the open
     * table runs empty first.
     *
     * @param start   initial layout, {@code rowSize} x {@code colSize}, blank stored as 0
     * @param target  goal layout with the same dimensions
     * @param rowSize number of rows in both layouts
     * @param colSize number of columns in both layouts
     */
    public static void AStar(int[][] start, int[][] target, int rowSize, int colSize) {
        TNode startNode = new TNode(start, rowSize, colSize);
        startNode.setG(0);
        startNode.setH(target);
        System.out.println("start state: \r\n");
        startNode.displayGrid();

        TNode targetNode = new TNode(target, rowSize, colSize);
        targetNode.setG(0);
        targetNode.setH(target);
        System.out.println("target state: \r\n");
        targetNode.displayGrid();

        open = new ArrayList<TNode>();
        closed = new ArrayList<TNode>();

        // Step 1: the start node is the only open node.
        open.add(startNode);
        int step = 0;

        // Step 2: keep going while there is something left to expand.
        while (!open.isEmpty()) {
            // Step 3: take the open node with the smallest f value and close it.
            // List.sort is stable, so nodes with equal f keep their insertion order.
            open.sort(new TNodeComparator());
            TNode currNode = open.remove(0);
            System.out.printf(" Step %d: \r\n", step++);
            currNode.displayGrid();
            closed.add(currNode);

            if (isSameState(currNode, targetNode)) {
                System.out.println("success");
                return;
            }

            // Expand currNode by moving the blank cell up, right, down, left.
            for (MoveAction action : EXPANSION_ORDER) {
                TNode nextNode = move(currNode, action, target);
                if (nextNode == null) {
                    continue; // move is forbidden or leaves the board
                }

                // Look the layout up by cell contents; this must not depend on
                // TNode object identity, because every successor is a new object.
                int openIdx = indexOfState(open, nextNode);
                if (openIdx >= 0) {
                    TNode old = open.get(openIdx);
                    if (nextNode.getG() < old.getG()) {
                        adoptCheaperPath(old, nextNode, currNode);
                    }
                    continue;
                }

                int closedIdx = indexOfState(closed, nextNode);
                if (closedIdx >= 0) {
                    TNode old = closed.get(closedIdx);
                    if (nextNode.getG() < old.getG()) {
                        adoptCheaperPath(old, nextNode, currNode);
                        // A cheaper path was found, so the node has to be expanded again.
                        closed.remove(closedIdx);
                        open.add(old);
                    }
                    continue;
                }

                // Layout seen for the first time.
                open.add(nextNode);
            }
        }
        System.out.println("failed");
    }

    /**
     * Returns the position of the first node in {@code nodes} whose layout equals
     * that of {@code wanted}, or -1 when there is none.
     */
    private static int indexOfState(ArrayList<TNode> nodes, TNode wanted) {
        for (int i = 0; i < nodes.size(); i++) {
            if (isSameState(nodes.get(i), wanted)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Re-parents {@code known} onto the cheaper path represented by {@code cheaper}.
     * The forbidden move is copied as well because it describes the move that would
     * lead back to the parent, and the parent has just changed.
     */
    private static void adoptCheaperPath(TNode known, TNode cheaper, TNode newParent) {
        known.setParent(newParent);
        known.setG(cheaper.getG());
        known.setForbiddenMove(cheaper.getForbiddenMove());
    }

    /**
     * Builds the successor reached by moving the blank cell of {@code node} in the
     * given direction.
     *
     * @param node   node being expanded
     * @param action direction in which the blank cell is moved
     * @param target goal layout, used to compute the successor's h value
     * @return the successor node, or {@code null} when {@code action} is the node's
     *         forbidden move or would push the blank cell off the board
     */
    static TNode move(TNode node, MoveAction action, int[][] target) {
        // A successor must not step straight back to its parent's layout.
        if (node.getForbiddenMove() == action) {
            return null;
        }

        int emptyRow = node.getEmptyRow();
        int emptyCol = node.getEmptyCol();
        int newRow = emptyRow;
        int newCol = emptyCol;
        MoveAction successorForbiddenMove;
        switch (action) {
            case up:
                newRow = emptyRow - 1;
                successorForbiddenMove = MoveAction.down;
                break;
            case right:
                newCol = emptyCol + 1;
                successorForbiddenMove = MoveAction.left;
                break;
            case down:
                newRow = emptyRow + 1;
                successorForbiddenMove = MoveAction.up;
                break;
            case left:
                newCol = emptyCol - 1;
                successorForbiddenMove = MoveAction.right;
                break;
            default:
                return null; // not reachable with the four declared directions
        }

        // Despite its name, exceedBoundary returns true when the cell is ON the board.
        if (!exceedBoundary(newRow, newCol, node.getRowSize(), node.getColSize())) {
            return null;
        }

        // Swap the blank cell with the neighbour it moves onto.
        int[][] nextGrid = node.getGridClone();
        nextGrid[emptyRow][emptyCol] = nextGrid[newRow][newCol];
        nextGrid[newRow][newCol] = 0;

        TNode newNode = new TNode(nextGrid, node.getRowSize(), node.getColSize());
        newNode.setG(node.getG() + 1);
        newNode.setH(target);
        newNode.setParent(node);
        newNode.setForbiddenMove(successorForbiddenMove);
        return newNode;
    }

    /**
     * Bounds check for a cell position.
     *
     * <p>NOTE: the name is misleading. The return value is {@code true} when the
     * position lies inside the board and {@code false} when it is outside; the name
     * is kept so existing callers keep working.
     *
     * @param row    row index to test
     * @param col    column index to test
     * @param maxRow number of rows on the board
     * @param maxCol number of columns on the board
     * @return {@code true} if {@code 0 <= row < maxRow} and {@code 0 <= col < maxCol}
     */
    static boolean exceedBoundary(int row, int col, int maxRow, int maxCol) {
        return row >= 0 && col >= 0 && row < maxRow && col < maxCol;
    }

    /**
     * Tells whether two nodes hold the same layout, comparing every cell.
     *
     * @param node1 first node
     * @param node2 second node
     * @return {@code true} only if both grids have the same dimensions and identical
     *         values in every cell
     */
    static boolean isSameState(TNode node1, TNode node2) {
        if (node1.getRowSize() != node2.getRowSize()
                || node1.getColSize() != node2.getColSize()) {
            return false;
        }
        int[][] first = node1.getGrid();
        int[][] second = node2.getGrid();
        for (int i = 0; i < node1.getRowSize(); i++) {
            for (int j = 0; j < node1.getColSize(); j++) {
                if (first[i][j] != second[i][j]) {
                    return false;
                }
            }
        }
        return true;
    }
}
/**
 * One search node: a puzzle layout plus the bookkeeping used by the search
 * (g, h, parent and the move that must not be taken next).
 *
 * <p>{@link #equals(Object)} and {@link #hashCode()} are based on the layout only, so
 * {@code List.contains} / {@code List.indexOf} find a node by its cell contents.
 * {@link #compareTo(TNode)} orders by f value only and is therefore not consistent
 * with {@code equals}.
 */
class TNode implements Comparable<TNode> {
    private int[][] grid;          // cell values; the blank cell is stored as 0
    private int g = 0;             // cost from the start node to this node (depth in the search tree)
    private int h = 0;             // heuristic value computed by setH
    private TNode parent;          // node this one was generated from
    private int rowSize, colSize = 0;
    private int emptyRow, emptyCol = 0;
    private MoveAction forbiddenMove; // move that would lead straight back to the parent

    /**
     * Copies the top-left {@code row} x {@code col} cells of {@code arr} and records
     * where the blank cell (value 0) is. If several cells are 0 the last one in
     * row-major order wins; if none is, the blank position stays at (0, 0).
     *
     * @param arr source layout
     * @param row number of rows
     * @param col number of columns
     */
    public TNode(int[][] arr, int row, int col) {
        rowSize = row;
        colSize = col;
        grid = new int[rowSize][colSize];
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                grid[i][j] = arr[i][j];
                if (0 == arr[i][j]) {
                    emptyRow = i;
                    emptyCol = j;
                }
            }
        }
    }

    /**
     * Orders nodes by ascending f value (g + h).
     *
     * @param o node to compare with
     * @return negative, zero or positive as this node's f is less than, equal to or
     *         greater than the other node's f
     */
    @Override
    public int compareTo(TNode o) {
        return Integer.compare(this.getFValue(), o.getFValue());
    }

    /**
     * Two nodes are equal when they have the same dimensions and the same value in
     * every cell; g, h, parent and forbidden move are ignored.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof TNode)) {
            return false;
        }
        TNode other = (TNode) obj;
        return rowSize == other.rowSize
                && colSize == other.colSize
                && java.util.Arrays.deepEquals(grid, other.grid);
    }

    /**
     * Hash of the cell contents, consistent with {@link #equals(Object)}.
     * The value changes if the array returned by {@link #getGrid()} is modified.
     */
    @Override
    public int hashCode() {
        return java.util.Arrays.deepHashCode(grid);
    }

    /** Returns the internal grid array itself (not a copy). */
    public int[][] getGrid() {
        return grid;
    }

    /** Returns an independent copy of the grid. */
    public int[][] getGridClone() {
        int[][] clone = new int[rowSize][colSize];
        for (int i = 0; i < rowSize; i++) {
            for (int j = 0; j < colSize; j++) {
                clone[i][j] = grid[i][j];
            }
        }
        return clone;
    }

    public int getRowSize() {
        return rowSize;
    }

    public int getColSize() {
        return colSize;
    }

    public int getEmptyRow() {
        return emptyRow;
    }

    public void setEmptyRow(int emptyRow) {
        this.emptyRow = emptyRow;
    }

    public int getEmptyCol() {
        return emptyCol;
    }

    public void setEmptyCol(int emptyCol) {
        this.emptyCol = emptyCol;
    }

    public int getG() {
        return g;
    }

    public void setG(int g) {
        this.g = g;
    }

    public int getH() {
        return h;
    }

    /** Returns f = g + h. */
    public int getFValue() {
        return this.getG() + this.getH();
    }

    /**
     * Computes and stores h as the number of cells whose value differs from the same
     * cell of {@code target}. The blank cell (0) is counted like any other cell.
     *
     * <p>NOTE(review): the introductory text of this file describes h as a sum of
     * Manhattan distances, but this method counts mismatched cells; the printed
     * sample run matches the counting behaviour. Because the blank is included, the
     * value can exceed the real number of remaining moves (a layout one move away
     * from the target yields 2).
     *
     * @param target the target layout, at least {@code rowSize} x {@code colSize}
     */
    public void setH(int[][] target) {
        int distance = 0;
        for (int i = 0; i < rowSize; i++) {
            for (int j = 0; j < colSize; j++) {
                if (grid[i][j] != target[i][j]) {
                    distance++;
                }
            }
        }
        h = distance;
    }

    public TNode getParent() {
        return parent;
    }

    public void setParent(TNode parent) {
        this.parent = parent;
    }

    public MoveAction getForbiddenMove() {
        return forbiddenMove;
    }

    public void setForbiddenMove(MoveAction forbiddenMove) {
        this.forbiddenMove = forbiddenMove;
    }

    /** Prints the grid row by row (tab separated), followed by a line with g, h and f. */
    public void displayGrid() {
        for (int i = 0; i < rowSize; i++) {
            for (int j = 0; j < colSize; j++) {
                System.out.printf("%d\t", grid[i][j]);
            }
            System.out.println();
        }
        System.out.printf("\tg: %d, h: %d, f: %d\r\n", g, h, this.getFValue());
    }
}
/**
 * Comparator that ranks nodes by ascending f value, so sorting a list with it
 * puts the node with the smallest f first.
 */
class TNodeComparator implements Comparator<TNode> {
    /**
     * @param o1 first node
     * @param o2 second node
     * @return -1, 0 or 1 as the f value of {@code o1} is less than, equal to or
     *         greater than the f value of {@code o2}
     */
    @Override
    public int compare(TNode o1, TNode o2) {
        final int leftF = o1.getFValue();
        final int rightF = o2.getFValue();
        if (leftF < rightF) {
            return -1;
        }
        if (leftF > rightF) {
            return 1;
        }
        return 0;
    }
}
测试:
给出起始状态:
2 8 0
1 6 3
7 5 4
目标状态:
1 2 3
8 0 4
7 6 5
运行过程如下:
Step 0:
2 8 0
1 6 3
7 5 4
g: 0, h: 8, f: 8
Step 1:
2 8 3
1 6 0
7 5 4
g: 1, h: 7, f: 8
Step 2:
2 8 3
1 6 4
7 5 0
g: 2, h: 6, f: 8
Step 3:
2 8 3
1 0 6
7 5 4
g: 2, h: 6, f: 8
Step 4:
2 8 3
1 6 4
7 0 5
g: 3, h: 5, f: 8
Step 5:
2 8 3
1 0 4
7 6 5
g: 4, h: 3, f: 7
Step 6:
2 0 8
1 6 3
7 5 4
g: 1, h: 8, f: 9
Step 7:
2 0 3
1 8 4
7 6 5
g: 5, h: 4, f: 9
Step 8:
2 8 3
0 1 4
7 6 5
g: 5, h: 4, f: 9
Step 9:
2 6 8
1 0 3
7 5 4
g: 2, h: 7, f: 9
Step 10:
0 2 8
1 6 3
7 5 4
g: 2, h: 7, f: 9
Step 11:
0 2 3
1 8 4
7 6 5
g: 6, h: 3, f: 9
Step 12:
1 2 8
0 6 3
7 5 4
g: 3, h: 6, f: 9
Step 13:
1 2 3
0 8 4
7 6 5
g: 7, h: 2, f: 9
Step 14:
1 2 3
8 0 4
7 6 5
g: 8, h: 0, f: 8
success