import matplotlib.pyplot as plt
import matplotlib as mpl
# Discrete colour map for a label image: one colour per integer label,
# in order 0=white, 1=blue, 2=cyan, 3=lime, 4=yellow.
# NOTE(review): assumes a1[num1, :, :] holds integer labels in
# 0..len(colors)-1 — confirm against where a1 is built.
colors = ['white', 'blue', 'cyan', 'Lime', 'yellow']
cmap = mpl.colors.ListedColormap(colors)
# One bin per colour, centred on each integer label. N colours need N + 1
# half-integer edges (-0.5, 0.5, ..., N - 0.5); integer edges such as
# [0, 1, 2, 3, 4] would give only four bins for five colours.
bounds = [label - 0.5 for label in range(len(colors) + 1)]
# Pin label -> colour explicitly. Without a norm, imshow rescales to the
# slice's own min/max, so a slice missing some labels gets shifted colours.
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
# 'nearest' keeps resampling from blending neighbouring labels into the
# colour of a label that is not actually present in the data.
plt.imshow(a1[num1, :, :], cmap=cmap, norm=norm, interpolation='nearest')