Python implementation of the A* algorithm

This article implements the A* path-finding algorithm in Python.
The workflow and underlying principles of the algorithm are not explained in detail here.

Code file structure: the project consists of four files — `point.py`, `map.py`, `a_star.py` and `main.py` (the main program).

point.py

import sys


class Point:
    """
    A single grid cell, as used by the path search.

    Attributes:
        x, y:    cell coordinates
        cost:    search cost attached to the cell; starts at ``sys.maxsize``
        parent:  the point this cell was reached from; ``None`` until set
    """

    def __init__(self, x: int, y: int):
        self.x, self.y = x, y
        # "Infinitely" expensive and unlinked until the search assigns values.
        self.cost, self.parent = sys.maxsize, None


map.py

from typing import List, Optional, Tuple

from point import Point


class Map(object):
    """
    Rectangular grid map with impassable obstacle cells.

    :param width:  grid width (number of cells along x)
    :param height:  grid height (number of cells along y)
    :param obstacles:  (x, y) coordinates of obstacle cells; ``None`` means no obstacles
    """

    def __init__(self, width: int, height: int, obstacles: Optional[List[Tuple[int, int]]] = None):
        self.width = width
        self.height = height

        # ``None`` default instead of a mutable ``[]`` shared between all calls
        if obstacles is None:
            obstacles = []

        self.obstacles = [Point(x=osc[0], y=osc[1]) for osc in obstacles]
        # Coordinate set for O(1) membership tests (is_obstacle is called for
        # every grid cell and every neighbour the search visits). It is built
        # once here, so treat ``self.obstacles`` as read-only after construction.
        self._obstacle_coords = {(p.x, p.y) for p in self.obstacles}

    def is_obstacle(self, i: int, j: int) -> bool:
        """
        Return whether cell (i, j) is an obstacle.

        :param i:  x coordinate
        :param j:  y coordinate
        """
        return (i, j) in self._obstacle_coords


a_star.py
This module also contains the visualisation code: every search step is saved as an image, and at the end the images are merged into a video. Once the video has been generated, the intermediate images are deleted.

import os
import sys
import time
from typing import Tuple, List

from matplotlib.patches import Rectangle
import cv2
import glob

from point import Point
from map import Map


class AStar(object):
    """
    A* path search on a 4-connected grid, with step-by-step visualisation.

    Every expanded cell is drawn (cyan) and saved as a frame under
    ``./outputs``. Once the target is reached, the path is drawn (green), all
    saved frames are merged into an ``.avi`` video and the frames are deleted.

    ``Point.cost`` holds the g-cost: the length of the best path found so far
    from the origin to that point, where every horizontal/vertical step costs 1.
    Points are expanded in order of ``g + h``, where ``h`` is the Manhattan
    distance to the target. That heuristic never overestimates for unit
    4-neighbour moves, so the path returned is a shortest one.
    """

    # folder holding the intermediate frames and the generated video
    _OUTPUT_DIR = './outputs'

    def __init__(self, map: Map, origin: Tuple[int, int], target: Tuple[int, int]):
        """
        initialise

        :param map:  map (grid size and obstacles)
        :param origin:  starting point coordinates
        :param target:  ending point coordinates
        """

        self.map = map

        self.origin = Point(x=origin[0], y=origin[1])
        self.target = Point(x=target[0], y=target[1])

        self.open_points = []
        self.close_points = []

        # frames saved so far; keeps frame file names unique and ordered
        self._frame_count = 0

    def _basic_cost(self, point: Point):
        """
        g-cost: length of the best known path from the origin to ``point``

        This is the accumulated path cost, not the Manhattan distance from the
        origin. The Manhattan distance ignores obstacles, and using it would
        turn the search into a non-optimal best-first search.
        """
        return point.cost

    def _heuristic_cost(self, point: Point):
        """
        estimated cost to target (Manhattan distance)
        """
        return abs(point.x - self.target.x) + abs(point.y - self.target.y)

    def _total_cost(self, point: Point):
        """
        total cost f = g + h used to order the open list
        """
        return self._basic_cost(point) + self._heuristic_cost(point)

    def _is_valid_point(self, x: int, y: int):
        """
        whether (x, y) lies inside the map and is not an obstacle
        """
        if x < 0 or y < 0:
            return False
        if x >= self.map.width or y >= self.map.height:
            return False
        if self.map.is_obstacle(x, y):
            return False

        return True

    def _in_point_list(self, point: Point, points: List[Point]):
        """
        whether a point with the same coordinates is in ``points``
        """
        return any(point.x == p.x and point.y == p.y for p in points)

    def _in_open_list(self, point: Point):
        return self._in_point_list(point, self.open_points)

    def _in_close_list(self, point: Point):
        return self._in_point_list(point, self.close_points)

    def _find_in_open_list(self, x: int, y: int):
        """
        return the open-list point at (x, y), or None if there is none
        """
        for p in self.open_points:
            if p.x == x and p.y == y:
                return p
        return None

    def run(self, ax, plt):
        """
        run algorithm and visualise

        Prints a message and returns None when no path exists; otherwise draws
        the path and writes the video.

        :param ax:  matplotlib.axes._subplots.AxesSubplot
        :param plt:  matplotlib.pyplot
        """

        tms = time.time()

        # start from a clean state so run() can safely be called again
        self.open_points = []
        self.close_points = []

        self.origin.cost = 0
        self.origin.parent = None
        self.open_points.append(self.origin)
        while True:
            idx = self._select_from_open_list()
            if idx < 0:
                print("No path found, algorithm failed!")
                return
            point = self.open_points[idx]

            rectangle = Rectangle(xy=(point.x, point.y), width=1, height=1, color='cyan')
            ax.add_patch(rectangle)
            self._save_image(plt)

            if point.x == self.target.x and point.y == self.target.y:
                return self._build_path(point=point, tms=tms, ax=ax, plt=plt)

            del self.open_points[idx]
            self.close_points.append(point)

            # the four axis-aligned neighbours
            for dx, dy in ((-1, 0), (0, -1), (1, 0), (0, 1)):
                self._process_point(x=point.x + dx, y=point.y + dy, parent=point)

    def _save_image(self, plt):
        """
        save the current figure as the next frame in the outputs folder

        File names are ``<milliseconds>_<frame number>.png``. The frame number
        keeps names unique even when two frames are saved within the same
        millisecond (otherwise one would overwrite the other), and the names
        sort chronologically.
        """

        os.makedirs(self._OUTPUT_DIR, exist_ok=True)
        millisecond = int(round(time.time() * 1000))
        file_name = os.path.join(self._OUTPUT_DIR, '{}_{:06d}.png'.format(millisecond, self._frame_count))
        self._frame_count += 1
        plt.savefig(file_name)

    def _process_point(self, x: int, y: int, parent: Point):
        """
        process a neighbour of the expanded point ``parent``

        New points are added to the open list. A point already in the open
        list is re-parented if the route through ``parent`` is shorter.

        :param x:  x coordinate
        :param y:  y coordinate
        :param parent:  current point's parent point
        """

        # do nothing for invalid point
        if not self._is_valid_point(x, y):
            return

        # do nothing for visited (closed) point
        point = Point(x, y)
        if self._in_close_list(point):
            return

        # every horizontal/vertical step costs 1
        new_cost = parent.cost + 1

        existing = self._find_in_open_list(x, y)
        if existing is None:
            point.parent = parent
            point.cost = new_cost
            self.open_points.append(point)
        else:
            point = existing
            if new_cost < point.cost:
                # a shorter route to an already-discovered point: relax it
                point.parent = parent
                point.cost = new_cost

        # logged after assignment so the printed cost is the real g-cost
        print("process point [{}, {}],  cost: {}".format(point.x, point.y, point.cost))

    def _select_from_open_list(self) -> int:
        """
        select the point with least total cost from the open list

        Ties go to the point that entered the open list first.

        :return idx_select:  the index of the selected point in the open list, -1 if it is empty
        """

        idx_select = -1
        min_cost = sys.maxsize
        for idx, point in enumerate(self.open_points):
            cost = self._total_cost(point)
            if cost < min_cost:
                min_cost = cost
                idx_select = idx

        return idx_select

    def _build_path(self, point: Point, tms: float, ax, plt):
        """
        build the whole path after algorithm terminates, draw it and write the video

        :param point:  ending point
        :param tms:  start time
        :param ax:  matplotlib.axes._subplots.AxesSubplot
        :param plt:  matplotlib.pyplot
        """

        # walk parent links back to the origin (whose parent is None), then
        # reverse so the path runs origin -> target
        path = []
        while point is not None:
            path.append(point)
            point = point.parent
        path.reverse()

        # visualise
        for p in path:
            rec = Rectangle(xy=(p.x, p.y), width=1, height=1, color='green')
            ax.add_patch(rec)
            plt.draw()
            self._save_image(plt)

        self._merge_video()

        tme = time.time()
        print("Algorithm finishes in {} s".format(int(tme - tms)))

    def _merge_video(self):
        """
        merge the saved frames into an .avi video, then delete the frames
        """

        # glob order depends on the file system; sort so frames play in the
        # order they were saved
        file_names = sorted(glob.glob(os.path.join(self._OUTPUT_DIR, '*.png')))
        if not file_names:
            # nothing to merge (and no frame to take the video size from)
            return

        images = [cv2.imread(filename=file_name) for file_name in file_names]
        height, width = images[0].shape[:2]
        size = (width, height)

        # generate video
        tm = time.time()
        video_path = os.path.join(self._OUTPUT_DIR, f'{round(tm)}.avi')
        fourcc = cv2.VideoWriter_fourcc(*'DIVX')
        video = cv2.VideoWriter(video_path, fourcc, 5, size)
        try:
            for image in images:
                video.write(image)
        finally:
            video.release()

        # delete original image files
        for file in file_names:
            os.remove(file)


main.py (main program)

from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle

from map import Map
from a_star import AStar


""" map settings """

width, height = 10, 15

# Search endpoints: these same values are drawn below and passed to AStar,
# so editing them here moves both the markers and the search.
origin, target = (0, 0), (width - 1, height - 1)
# Three vertical walls: at x ~ width/4 and x ~ 3*width/4 covering
# y in [0, 2*height/3), and at x ~ width/2 covering y in [height/3, height).
obstacles = [(round(width * (1 / 4)), j) for j in range(round(height * (2 / 3)))] + [
    (round(width * (1 / 2)), j) for j in range(round(height * (1 / 3)), height)] + [
    (round(width * (3 / 4)), j) for j in range(round(height * (2 / 3)))]

map_ = Map(width=width, height=height, obstacles=obstacles)


""" visual settings """

plt.figure(figsize=(5, 5))
ax = plt.gca()
ax.set_xlim([0, map_.width])
ax.set_ylim([0, map_.height])

# draw the grid: obstacles filled gray, free cells white with a gray border
for i in range(map_.width):
    for j in range(map_.height):
        if map_.is_obstacle(i, j):
            rectangle = Rectangle(xy=(i, j), width=1, height=1, color='gray')
        else:
            rectangle = Rectangle(xy=(i, j), width=1, height=1, edgecolor='gray', facecolor='white')
        ax.add_patch(rectangle)

# origin in blue, target in red
rectangle = Rectangle(xy=origin, width=1, height=1, facecolor='blue')
ax.add_patch(rectangle)
rectangle = Rectangle(xy=target, width=1, height=1, facecolor='red')
ax.add_patch(rectangle)

plt.axis('equal')  # set equal scaling
plt.axis('off')  # turn off axis lines and labels
plt.tight_layout()


""" algorithm """

a_star = AStar(map=map_, origin=origin, target=target)
a_star.run(ax, plt)

Source: blog.csdn.net/Zhang_0702_China/article/details/129893903