关于A星算法的代码实现在各大厂人工智能岗位面试中均出现过,需要面试者有较高的手撕代码能力。
这篇文章即在python中从零复现A星寻路算法,并利用matplot将计算所得路线可视化画出,适合新手练习代码能力,也适合准备面试的朋友复习。
本文不注重讲解A星算法的具体内容,着重于通过代码的实现!
*参考资料:https://www.laurentluce.com/posts/solving-mazes-using-python-simple-recursivity-and-a-search/
Ⅱ.开始代码寿司
为了简化,这里不考虑路途权重问题,每个方块只考虑是否可达(即是否为墙)。
我们这里定义两个类,一个为cell,用于记录每个cell的信息,如坐标、是否可达及gfh等信息(若需要可以记录权重);另一个为AStar类,用于实现算法。
一、Cell类很简单,第一个cell类定义如下:
class cell(object): def __init__(self, x, y, reachable:bool) -> None: self.x = x self.y = y self.isReachable = reachable self.g = 0 self.h = 0 self.f = self.g + self.h self.parent = None self.weight = 0 def __lt__(self, other): return self.f < other.f
其中weights作为可扩展项,此算法测试中不考虑每个cell权重大小来更新,只考虑该cell是否可达来更新。
二、AStar类第二个AStar类需要实现的功能很多,我们先一一列出来,然后再逐一讲解填充。
class AStar(object): def __init__(self) -> None: ''' 初始化一些数据结构:堆、集合帮助实现算法 ''' pass def init_grid(self, width, height, cells, walls): ''' 初始化地图,导入walls ''' pass def get_cell(self, x, y): ''' 因为输入cells信息时为一维信息,这里需要通过width和height检索到相应位置的cell ''' pass def caculate_one_way(self, start, end): ''' 在地图确定不变的情况下,每次传入不同的起点和终点,计算返回路径 ''' pass def caculate_heuristic(self, cell): ''' 计算启发式距离h值 ''' pass def get_adjacent_cell(self, cell): ''' 返回cell周围的cell,这里的周围指八个方向 ''' pass def get_updated(self, adj, cell): ''' 用于每次更新cell信息 ''' pass def save_path(self): ''' 保存计算路径 ''' pass def solve(self): ''' 代码核心,实现逻辑 ''' pass
接下来逐一填充讲解
①初始化数据结构了解A星算法的都知道其实现需要一些数据结构的帮忙,这一版本的实现需要用到堆、列表与集合
import heapqclass AStar(object): def __init__(self) -> None: ''' 初始化一些数据结构:堆、集合帮助实现算法 ''' self.closed = set() self.open = [] heapq.heapify(self.open) self.cells = []
②初始化地图导入宽度、高度及cells
def init_grid(self, width, height, walls): ''' 初始化地图,导入cells ''' self.grid_width = width self.grid_height = height for i in range(self.grid_height): for j in range(self.grid_width): if (i,j) in walls: reachable = false else: reachable = true self.cells.append(cell(i, j, reachable))
③重定位cellsdef get_cell(self, x, y): ''' 因为输入cells信息时为一维信息,这里需要通过width和height检索到相应位置的cell ''' return self.cells[ x*self.grid_width + y ]
④设置起终点def caculate_one_way(self, start, end): ''' 在地图确定不变的情况下,每次传入不同的起点和终点 ''' self.start = self.get_cell(*start) self.end = self.get_cell(*end)
⑤计算启发式距离def caculate_heuristic(self, cell): ''' 计算启发式距离h值,这里采用曼哈顿距离 ''' return 10 * ( abs(self.end.x - cell.x) + abs(self.end.y - cell.y) )
⑥计算临近点 def get_adjacent_cell(self, cell): ''' 返回cell周围的cell,这里的周围指八个方向 ''' adj_cells = [] for dx, dy in [ (1, 0), (0, 1), (-1, 0), (0, -1), (1, -1), (-1, 1), (-1, -1), (1, 1) ]: x2 = cell.x + dx y2 = cell.y + dy if x2>0 and x2
def get_updated(self, adj, cell): ''' 用于每次更新cell信息 ''' adj.g = cell.g + 10 adj.parent = cell adj.h = self.caculate_heuristic(adj) adj.f = self.g + self.h
⑧保存路径def save_path(self): ''' 保存计算路径 ''' cell = self.end path = [(cell.x, cell.y)] while cell.parent is not self.start: cell = cell.parent path.append((cell.x, cell.y)) path.append((self.start.x, self.start.y)) path.reverse() return path
⑨逻辑实现借助上图,我们很容易写出python对应的逻辑代码:
def solve(self): ''' 代码核心,实现逻辑 ''' heapq.heappush(self.openlist, (self.start.f, self.start)) while len(self.openlist): f, cell = heapq.heappop(self.openlist) self.closed.add(cell) if cell is self.end: return self.save_path() adj_cells = self.get_adjacent_cell(cell) for adj_cell in adj_cells: if adj_cell.isReachable and adj_cell not in self.closed: if ( adj_cell.f, adj_cell ) in self.openlist: if adj_cell.g > cell.g + 10: self.get_updated(adj_cell, cell) else: self.get_updated(adj_cell, cell) heapq.heappush(self.openlist, ( adj_cell.f, adj_cell )) raise RuntimeError("A* failed to find a solution")
三、可视化借助matplotlib将结果可视化
def draw_result(result_path, walls, start, end): plt.plot([v[0] for v in result_path], [v[1] for v in result_path]) plt.plot([v[0] for v in result_path], [v[1] for v in result_path], 'o', color='lightblue') plt.plot([start[0], end[0]], [start[1], end[1]], 'o', color='red') plt.plot([barrier[0] for barrier in walls ], [barrier[1] for barrier in walls], 's', color='m') plt.xlim(-1, 8) plt.ylim(-1, 8) plt.show()
四、测试若路径不可行会报错
raise RuntimeError("A* failed to find a solution")
下面测试可行路径
if __name__ == '__main__': a = AStar() walls = ((2, 5), (2, 6), (3, 6), (4, 6), (5, 6), (5, 5), (5, 4), (5, 3), (5, 2), (4, 2), (3, 2), (7, 1), (6, 4), (1, 5), (7, 6)) a.init_grid(8, 8, walls) a.caculate_one_way((0, 0), (7, 7)) path = a.solve() print(path) draw_result(path,walls,(0, 0), (7,7))
结果如下,紫色代表墙体,蓝色为计算路径
Ⅲ.完整代码
from __future__ import print_functionimport heapqimport matplotlib.pyplot as pltimport unittestclass cell(object): def __init__(self, x, y, reachable:bool) -> None: self.x = x self.y = y self.isReachable = reachable self.g = 0 self.h = 0 self.f = self.g + self.h self.parent = None self.weight = 0 def __lt__(self, other): return self.f < other.f class AStar(object): def __init__(self) -> None: ''' 初始化一些数据结构:堆、集合帮助实现算法 ''' self.closed = set() self.openlist = [] heapq.heapify(self.openlist) self.cells = [] def init_grid(self, width, height, walls): ''' 初始化地图,导入cells ''' self.grid_width = width self.grid_height = height for i in range(self.grid_height): for j in range(self.grid_width): if (i,j) in walls: reachable = False else: reachable = True self.cells.append(cell(i, j, reachable)) def get_cell(self, x, y): ''' 因为输入cells信息时为一维信息,这里需要通过width和height检索到相应位置的cell ''' return self.cells[ x*self.grid_width + y ] def caculate_one_way(self, start, end): ''' 在地图确定不变的情况下,每次传入不同的起点和终点 ''' self.start = self.get_cell(*start) self.end = self.get_cell(*end) def caculate_heuristic(self, cell): ''' 计算启发式距离h值,这里采用曼哈顿距离 ''' return 10 * ( abs(self.end.x - cell.x) + abs(self.end.y - cell.y) ) def get_adjacent_cell(self, cell): ''' 返回cell周围的cell,这里的周围指八个方向 ''' adj_cells = [] for dx, dy in [ (1, 0), (0, 1), (-1, 0), (0, -1), (1, -1), (-1, 1), (-1, -1), (1, 1) ]: x2 = cell.x + dx y2 = cell.y + dy if x2>0 and x2
END
祝大家学习面试顺利,OFFER多多。