This article implements the A* path-finding algorithm in Python.
The principles and step-by-step flow of the algorithm are not explained in detail here.
The code is split into the following files:
point.py
import sys
class Point(object):
    """
    A single grid cell, used as a node by the path search.

    Attributes:
        x, y: integer grid coordinates of the cell.
        cost: search cost recorded by the path finder; it starts at sys.maxsize,
            which effectively means "not costed yet".
        parent: predecessor Point on the search path; None until the path
            finder assigns one.
    """

    def __init__(self, x: int, y: int):
        self.x, self.y = x, y
        self.parent = None
        self.cost = sys.maxsize
map.py
from typing import List, Optional, Tuple

from point import Point
class Map(object):
    """
    A rectangular grid of width x height cells, some of which are blocked.

    Cells are addressed by integer coordinates (x, y).
    """

    def __init__(self, width: int, height: int,
                 obstacles: Optional[List[Tuple[int, int]]] = None):
        """
        :param width: grid extent along x
        :param height: grid extent along y
        :param obstacles: (x, y) coordinates of blocked cells; None or an empty
            list means the map has no obstacles
        """
        self.width = width
        self.height = height
        # None replaces the former mutable [] default, so no list object is
        # shared between calls; each Map gets its own list of obstacle Points.
        self.obstacles = [Point(x=osc[0], y=osc[1]) for osc in (obstacles or [])]

    def is_obstacle(self, i: int, j: int) -> bool:
        """
        Return True if the cell (i, j) is one of this map's obstacles.
        """
        return any(p.x == i and p.y == j for p in self.obstacles)
a_star.py
It also contains the visualisation code: each search step is saved as an image, and the images are finally merged into a video. The intermediate images are deleted once the video has been generated.
import os
import sys
import time
from typing import Tuple, List
from matplotlib.patches import Rectangle
import cv2
import glob
from point import Point
from map import Map
class AStar(object):
    """
    A* search on a 4-connected grid with unit step cost, visualised frame by frame.

    Each Point's ``cost`` attribute stores g(n), the accumulated path cost from
    the origin along its ``parent`` chain. Open points are expanded in order of
    f(n) = g(n) + h(n), where h(n) is the Manhattan distance to the target.
    Every expansion is saved as an image frame. When the target is reached, the
    path is drawn, the frames of this run are merged into an .avi video, and
    those frames are deleted.
    """

    def __init__(self, map: Map, origin: Tuple[int, int], target: Tuple[int, int],
                 output_dir: str = './outputs'):
        """
        initialise
        :param map: map to search; must provide width, height and is_obstacle()
        :param origin: starting point coordinates (x, y)
        :param target: ending point coordinates (x, y)
        :param output_dir: folder for the intermediate frames and the final video
        """
        self.map = map
        self.origin = Point(x=origin[0], y=origin[1])
        self.target = Point(x=target[0], y=target[1])
        self.open_points = []
        self.close_points = []
        self.output_dir = output_dir
        # frames written by this run, in save order; only these are merged and deleted
        self._frames = []

    def _basic_cost(self, point: Point):
        """
        g(n): accumulated path cost from the origin, as recorded on the point.

        This is the real cost along the parent chain, not the straight-line
        (Manhattan) distance from the origin, which ignores detours around
        obstacles and would make the search return non-shortest paths.
        """
        return point.cost

    def _heuristic_cost(self, point: Point):
        """
        h(n): estimated cost to target (Manhattan distance). It never
        overestimates for 4-connected moves of unit cost.
        """
        return abs(point.x - self.target.x) + abs(point.y - self.target.y)

    def _total_cost(self, point: Point):
        """
        f(n) = g(n) + h(n)
        """
        return self._basic_cost(point) + self._heuristic_cost(point)

    def _is_valid_point(self, x: int, y: int):
        """
        True if (x, y) lies inside the map and is not an obstacle.
        """
        if x < 0 or y < 0:
            return False
        if x >= self.map.width or y >= self.map.height:
            return False
        return not self.map.is_obstacle(x, y)

    @staticmethod
    def _find_in_point_list(x: int, y: int, points: List[Point]):
        """
        Return the point in ``points`` with coordinates (x, y), or None.
        """
        return next((p for p in points if p.x == x and p.y == y), None)

    def _in_point_list(self, point: Point, points: List[Point]):
        return self._find_in_point_list(point.x, point.y, points) is not None

    def _in_open_list(self, point: Point):
        return self._in_point_list(point, self.open_points)

    def _in_close_list(self, point: Point):
        return self._in_point_list(point, self.close_points)

    def run(self, ax, plt):
        """
        run algorithm and visualise

        Prints a message and returns None if the target cannot be reached; in
        that case the saved frames are left in the output folder.
        :param ax: matplotlib.axes._subplots.AxesSubplot
        :param plt: matplotlib.pyplot
        """
        tms = time.time()
        # savefig does not create missing folders
        os.makedirs(self.output_dir, exist_ok=True)
        self.origin.cost = 0
        self.open_points.append(self.origin)
        while True:
            idx = self._select_from_open_list()
            if idx < 0:
                print("No path found, algorithm failed!")
                return
            point = self.open_points[idx]
            rectangle = Rectangle(xy=(point.x, point.y), width=1, height=1, color='cyan')
            ax.add_patch(rectangle)
            self._save_image(plt)
            if point.x == self.target.x and point.y == self.target.y:
                return self._build_path(point=point, tms=tms, ax=ax, plt=plt)
            del self.open_points[idx]
            self.close_points.append(point)
            # neighbours: left, down, right, up
            for dx, dy in ((-1, 0), (0, -1), (1, 0), (0, 1)):
                self._process_point(x=point.x + dx, y=point.y + dy, parent=point)

    def _save_image(self, plt):
        """
        save the current figure as the next numbered frame in the output folder
        """
        # A zero-padded counter gives unique names (timestamps could collide within
        # one millisecond) that also sort in save order.
        file_name = os.path.join(self.output_dir, '{:06d}.png'.format(len(self._frames)))
        plt.savefig(file_name)
        self._frames.append(file_name)

    def _process_point(self, x: int, y: int, parent: Point):
        """
        process a neighbour of the point being expanded
        :param x: x coordinate
        :param y: y coordinate
        :param parent: the point being expanded
        """
        # do nothing for invalid or already expanded points
        if not self._is_valid_point(x, y):
            return
        if self._find_in_point_list(x, y, self.close_points) is not None:
            return
        new_cost = parent.cost + 1  # unit cost per grid step
        point = self._find_in_point_list(x, y, self.open_points)
        if point is None:
            point = Point(x, y)
            point.parent = parent
            point.cost = new_cost
            self.open_points.append(point)
        elif new_cost < point.cost:
            # cheaper route to a point that is already waiting in the open list
            point.parent = parent
            point.cost = new_cost
        # printed after the cost is assigned, so the real value is shown
        print("process point [{}, {}], cost: {}".format(point.x, point.y, point.cost))

    def _select_from_open_list(self) -> int:
        """
        select the point with the least total cost f = g + h from the open list
        :return: index of the selected point in the open list, or -1 if it is empty
        """
        if not self.open_points:
            return -1
        # min() returns the first of several equal minima
        return min(range(len(self.open_points)),
                   key=lambda i: self._total_cost(self.open_points[i]))

    def _build_path(self, point: Point, tms: float, ax, plt):
        """
        build and draw the whole path after the target is reached, then make the video
        :param point: ending point
        :param tms: start time
        :param ax: matplotlib.axes._subplots.AxesSubplot
        :param plt: matplotlib.pyplot
        """
        # follow parent links back to the origin, whose parent is None
        path = []
        while point is not None:
            path.append(point)
            point = point.parent
        path.reverse()
        # visualise the path cell by cell, one frame per cell
        for p in path:
            rec = Rectangle(xy=(p.x, p.y), width=1, height=1, color='green')
            ax.add_patch(rec)
            plt.draw()
            self._save_image(plt)
        self._merge_video()
        tme = time.time()
        print("Algorithm finishes in {} s".format(int(tme - tms)))

    def _merge_video(self):
        """
        merge the frames saved by this run (in save order) into an .avi video,
        then delete those frames
        """
        if not self._frames:
            return
        video_path = os.path.join(self.output_dir, '{}.avi'.format(round(time.time())))
        video = None
        try:
            for file_name in self._frames:
                image = cv2.imread(filename=file_name)
                if image is None:  # unreadable frame: skip it rather than crash
                    continue
                if video is None:
                    # the writer needs the frame size, taken from the first readable frame
                    height, width = image.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*'DIVX')
                    video = cv2.VideoWriter(video_path, fourcc, 5, (width, height))
                video.write(image)
        finally:
            if video is not None:
                video.release()
        # delete only the frames this run created
        for file_name in self._frames:
            os.remove(file_name)
        self._frames = []
main.py (main program)
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from map import Map
from a_star import AStar
""" map settings """
width, height = 10, 15
origin, target = (0, 0), (width - 1, height - 1)
# Three vertical walls:
# - at x = width/4 and x = 3*width/4, covering y in [0, 2/3 of height)
# - at x = width/2, covering y in [1/3 of height, height)
obstacles = (
    [(round(width * (1 / 4)), j) for j in range(round(height * (2 / 3)))]
    + [(round(width * (1 / 2)), j) for j in range(round(height * (1 / 3)), height)]
    + [(round(width * (3 / 4)), j) for j in range(round(height * (2 / 3)))]
)
map_ = Map(width=width, height=height, obstacles=obstacles)
""" visual settings """
plt.figure(figsize=(5, 5))
ax = plt.gca()
ax.set_xlim([0, map_.width])
ax.set_ylim([0, map_.height])
# obstacles are filled gray; free cells are white with a gray outline
for i in range(map_.width):
    for j in range(map_.height):
        if map_.is_obstacle(i, j):
            rectangle = Rectangle(xy=(i, j), width=1, height=1, color='gray')
        else:
            rectangle = Rectangle(xy=(i, j), width=1, height=1, edgecolor='gray', facecolor='white')
        ax.add_patch(rectangle)
ax.add_patch(Rectangle(xy=origin, width=1, height=1, facecolor='blue'))
ax.add_patch(Rectangle(xy=target, width=1, height=1, facecolor='red'))
plt.axis('equal')  # set equal scaling
plt.axis('off')  # turn off axis lines and labels
plt.tight_layout()
""" algorithm """
# reuse origin/target so the search always matches the cells highlighted above
a_star = AStar(map=map_, origin=origin, target=target)
a_star.run(ax, plt)