摘要
本文是研究生课程《图像处理》的期末作业，旨在了解超像素算法的基本原理并入门。文章主要介绍了超像素的评测标准和经典算法 SLIC，讨论了 SLIC 算法的不足之处，并介绍了 SLIC 的两个有效改进算法：SEEDS 和 ETPS。
文章全文见我的码云（Gitee）仓库。
Python 3 代码
下面是一个运算速度较慢、但便于理清 SLIC 算法流程的参考实现：
import math
from skimage import io, color
import numpy as np
from tqdm import trange
from tqdm import tqdm
class Cluster(object):
    """One SLIC cluster (superpixel seed).

    Holds the seed position ``(h, w)`` (row, column), its LAB colour
    ``(l, a, b)``, the list of pixels currently assigned to it and a serial
    number ``no`` taken from the class-wide counter ``cluster_index``.
    """

    # Class-wide counter handing out serial numbers to new clusters.
    cluster_index = 1

    def __init__(self, h, w, l=0, a=0, b=0):
        self.update(h, w, l, a, b)
        # (h, w) tuples of the pixels assigned to this cluster.
        self.pixels = []
        # Take the current counter value, then advance the shared counter.
        self.no = self.cluster_index
        Cluster.cluster_index += 1

    def update(self, h, w, l, a, b):
        """Move the cluster to (h, w) and give it the colour (l, a, b)."""
        self.h, self.w = h, w
        self.l, self.a, self.b = l, a, b

    def __str__(self):
        return "{},{}:{} {} {} ".format(self.h, self.w, self.l, self.a, self.b)

    def __repr__(self):
        return str(self)
class SLICProcessor(object):
    """Slow, loop-based reference implementation of SLIC superpixels.

    :param filename: path of the input image (converted to CIELAB on load)
    :param K: requested number of superpixels
    :param M: compactness weight ``m`` of the SLIC paper
    """

    @staticmethod
    def open_image(path):
        """
        Read an image file and convert it to CIELAB.

        Return:
            3D array, row col [LAB]
        """
        rgb = io.imread(path)
        # rgb2lab needs exactly three colour channels: expand greyscale
        # images and drop an alpha channel (common in PNG files).
        if rgb.ndim == 2:
            rgb = color.gray2rgb(rgb)
        elif rgb.shape[2] == 4:
            rgb = color.rgba2rgb(rgb)
        lab_arr = color.rgb2lab(rgb)
        return lab_arr

    @staticmethod
    def save_lab_image(path, lab_arr):
        """
        Convert the array to RGB, then save the image.

        :param path: output file path
        :param lab_arr: H x W x 3 CIELAB array
        """
        rgb_arr = color.lab2rgb(lab_arr)
        # lab2rgb yields floats in [0, 1]; write an explicit 8-bit image
        # instead of relying on imsave's lossy float conversion.
        rgb_u8 = (np.clip(rgb_arr, 0.0, 1.0) * 255).round().astype(np.uint8)
        io.imsave(path, rgb_u8)

    def make_cluster(self, h, w):
        """Create a Cluster at (h, w) carrying that pixel's LAB colour."""
        h = int(h)
        w = int(w)
        return Cluster(h, w,
                       self.data[h][w][0],
                       self.data[h][w][1],
                       self.data[h][w][2])

    def __init__(self, filename, K, M):
        self.K = K
        self.M = M
        self.data = self.open_image(filename)
        self.image_height = self.data.shape[0]
        self.image_width = self.data.shape[1]
        self.N = self.image_height * self.image_width
        # Grid interval S. Clamp to >= 1: with K > N it would be 0 and
        # init_clusters would never advance.
        self.S = max(1, int(math.sqrt(self.N / self.K)))
        self.clusters = []
        # (h, w) -> Cluster that currently owns the pixel.
        self.label = {}
        # Smallest distance seen for each pixel during the current assignment.
        self.dis = np.full((self.image_height, self.image_width), np.inf)

    def init_clusters(self):
        """Place seeds on a regular grid with spacing S, offset by S // 2."""
        h = self.S // 2
        w = self.S // 2
        while h < self.image_height:
            while w < self.image_width:
                self.clusters.append(self.make_cluster(h, w))
                w += self.S
            w = self.S // 2
            h += self.S

    def get_gradient(self, h, w):
        """Return the squared colour gradient at an in-image pixel (h, w).

        G = ||I(h+1, w) - I(h-1, w)||^2 + ||I(h, w+1) - I(h, w-1)||^2
        (eq. 2 of the SLIC paper); neighbour coordinates are clamped to the
        image. The result is always >= 0.
        """
        h_lo, h_hi = max(h - 1, 0), min(h + 1, self.image_height - 1)
        w_lo, w_hi = max(w - 1, 0), min(w + 1, self.image_width - 1)
        d_vert = self.data[h_hi][w] - self.data[h_lo][w]
        d_horz = self.data[h][w_hi] - self.data[h][w_lo]
        return float(np.dot(d_vert, d_vert) + np.dot(d_horz, d_horz))

    def move_clusters(self):
        """Move every seed to the lowest-gradient pixel of its 3x3 neighbourhood."""
        for cluster in self.clusters:
            # Scan around the ORIGINAL seed position; cluster.h/w must not be
            # changed until the whole neighbourhood has been examined.
            h0, w0 = cluster.h, cluster.w
            best_h, best_w = h0, w0
            best_gradient = self.get_gradient(h0, w0)
            for dh in range(-1, 2):
                for dw in range(-1, 2):
                    _h, _w = h0 + dh, w0 + dw
                    if not (0 <= _h < self.image_height and 0 <= _w < self.image_width):
                        continue
                    new_gradient = self.get_gradient(_h, _w)
                    if new_gradient < best_gradient:
                        best_h, best_w, best_gradient = _h, _w, new_gradient
            cluster.update(best_h, best_w,
                           self.data[best_h][best_w][0],
                           self.data[best_h][best_w][1],
                           self.data[best_h][best_w][2])

    def assignment(self):
        """Assign each pixel within 2S of a seed to the closest seed.

        The distance is D = sqrt((Dc / M)^2 + (Ds / S)^2), where Dc is the LAB
        distance and Ds the spatial distance.
        """
        # Seeds moved in the last update step, so last iteration's distances
        # are stale; start every assignment from "infinitely far".
        self.dis.fill(np.inf)
        for cluster in tqdm(self.clusters):
            h_start = max(cluster.h - 2 * self.S, 0)
            h_stop = min(cluster.h + 2 * self.S, self.image_height)
            w_start = max(cluster.w - 2 * self.S, 0)
            w_stop = min(cluster.w + 2 * self.S, self.image_width)
            for h in range(h_start, h_stop):
                for w in range(w_start, w_stop):
                    L, A, B = self.data[h][w]
                    Dc = math.sqrt((L - cluster.l) ** 2 +
                                   (A - cluster.a) ** 2 +
                                   (B - cluster.b) ** 2)
                    Ds = math.sqrt((h - cluster.h) ** 2 + (w - cluster.w) ** 2)
                    D = math.sqrt((Dc / self.M) ** 2 + (Ds / self.S) ** 2)
                    if D < self.dis[h][w]:
                        previous = self.label.get((h, w))
                        if previous is not None:
                            previous.pixels.remove((h, w))
                        self.label[(h, w)] = cluster
                        cluster.pixels.append((h, w))
                        self.dis[h][w] = D

    def update_cluster(self):
        """Move each cluster to the centroid of its pixels and their mean colour."""
        for cluster in self.clusters:
            if not cluster.pixels:
                # No pixels assigned: keep the cluster where it is instead of
                # dividing by zero.
                continue
            coords = np.array(cluster.pixels)
            _h, _w = (int(v) for v in coords.mean(axis=0))
            # The new centre colour is the average over the cluster's pixels
            # (the SLIC update step), not the colour of the centroid pixel.
            mean_lab = self.data[coords[:, 0], coords[:, 1]].mean(axis=0)
            cluster.update(_h, _w, mean_lab[0], mean_lab[1], mean_lab[2])

    def save_current_image(self, name):
        """Save the image with each superpixel painted in its cluster colour.

        Cluster centres are marked with LAB (0, 0, 0), i.e. black.
        """
        image_arr = np.copy(self.data)
        for cluster in self.clusters:
            for p in cluster.pixels:
                image_arr[p[0]][p[1]][0] = cluster.l
                image_arr[p[0]][p[1]][1] = cluster.a
                image_arr[p[0]][p[1]][2] = cluster.b
            image_arr[cluster.h][cluster.w][0] = 0
            image_arr[cluster.h][cluster.w][1] = 0
            image_arr[cluster.h][cluster.w][2] = 0
        self.save_lab_image(name, image_arr)

    def iterate_10times(self):
        """Seed, perturb the seeds, then run 10 assignment/update rounds.

        The intermediate result of every round is saved to a PNG file.
        """
        self.init_clusters()
        self.move_clusters()
        for i in trange(10):
            self.assignment()
            self.update_cluster()
            name = 'lenna_M{m}_K{k}_loop{loop}.png'.format(loop=i, m=self.M, k=self.K)
            self.save_current_image(name)
if __name__ == '__main__':
    # Demo: request about 500 superpixels with compactness m = 40.
    p = SLICProcessor(filename='kemomimi.png', K=500, M=40)
    p.iterate_10times()
下面是把上述缓慢的循环改写为 NumPy 向量化（矩阵）运算的参考实现，对学习 Python 计算优化很有借鉴意义。
原作者还写了一个合并不连通超像素的函数（SLIC 可选的后处理步骤），由于代码中并未调用，这里已将其删去。
# Aleena Watson
# Final Project - Computer Vision Simon Niklaus
# Winter 2018 - PSU
import numpy as np
import sys
from skimage import io, color
import tqdm
# Seeds are moved to the lowest-gradient pixel of their 3x3 neighbourhood
# (Sec. 3.2 of the SLIC paper); Eq. 2 of the paper defines the image gradient as:
# G(x,y) = ||I(x+1,y) - I(x-1,y)||^2 + ||I(x,y+1) - I(x,y-1)||^2
# SLIC implements a special case of k-means clustering algorithm.
# Was recommended to use an off the shelf algorithm for clustering but
# because this algorithm is based on this special case of k-means,
# I kept this implementation to stay true to the algorithm.
def generate_pixels():
    """Run SLIC_ITERATIONS rounds of SLIC assignment and centre update.

    Works on module globals. It reads img, step, SLIC_m, SLIC_labimg,
    SLIC_height, SLIC_width and SLIC_ITERATIONS. It updates, in place,
    SLIC_clusters (per-pixel centre index, -1 = unassigned) and
    SLIC_centers (one row [l, a, b, x, y] per centre).
    """
    for _ in tqdm.tqdm(range(SLIC_ITERATIONS)):
        # Each iteration starts "infinitely far" from every centre. A start
        # value of 1 left pixels whose normalised distance to all nearby
        # centres exceeds 1 with a stale label (or -1).
        SLIC_distances = np.full(img.shape[:2], np.inf)
        for j in range(SLIC_centers.shape[0]):
            cx, cy = SLIC_centers[j][3], SLIC_centers[j][4]
            # 2*step x 2*step search window around the centre, clipped to the image.
            x_low, x_high = max(int(cx - step), 0), min(int(cx + step), SLIC_width)
            y_low, y_high = max(int(cy - step), 0), min(int(cy + step), SLIC_height)
            cropimg = SLIC_labimg[y_low:y_high, x_low:x_high]
            # Compare against the centre's stored colour, which the update
            # step sets to the cluster's mean colour.
            color_diff = cropimg - SLIC_centers[j][0:3]
            color_distance = np.sqrt(np.sum(np.square(color_diff), axis=2))
            yy, xx = np.ogrid[y_low:y_high, x_low:x_high]
            pixdist = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)
            # SLIC_m is "m" in the paper: D = sqrt((dc/m)^2 + (ds/S)^2)
            dist = np.sqrt((color_distance / SLIC_m) ** 2 + (pixdist / step) ** 2)
            # Basic slices are views, so these writes go straight into the
            # full-size arrays.
            distance_crop = SLIC_distances[y_low:y_high, x_low:x_high]
            idx = dist < distance_crop
            distance_crop[idx] = dist[idx]
            SLIC_clusters[y_low:y_high, x_low:x_high][idx] = j
        _update_slic_centers()


def _update_slic_centers():
    """Move every non-empty centre to the mean [l, a, b, x, y] of its pixels."""
    n_centers = SLIC_centers.shape[0]
    labels = SLIC_clusters.astype(int).ravel()
    assigned = labels >= 0
    labels = labels[assigned]
    ys, xs = np.indices(SLIC_clusters.shape)
    # Per-pixel feature rows in the same column order as SLIC_centers.
    features = np.column_stack(
        [SLIC_labimg.reshape(-1, 3), xs.ravel(), ys.ravel()])[assigned]
    # One pass per feature column instead of one full-image mask per centre.
    counts = np.bincount(labels, minlength=n_centers)
    sums = np.stack(
        [np.bincount(labels, weights=features[:, c], minlength=n_centers)
         for c in range(features.shape[1])],
        axis=1)
    # Centres that lost all their pixels keep their previous value; dividing
    # by a zero count would make them NaN and crash the next int() call.
    nonempty = counts > 0
    SLIC_centers[nonempty] = sums[nonempty] / counts[nonempty, None]
def display_contours(color):
    """Draw superpixel boundaries onto a copy of the global ``img`` and save it.

    A pixel is marked as a boundary pixel when at least two of its
    8-neighbours belong to a different superpixel. Neighbours that are
    already marked are not counted.

    :param color: value written into every boundary pixel, e.g. [0.0, 0.0, 0.0]
    :return: the image with contours drawn (also saved as SLIC_contours.jpg)
    """
    rgb_img = img.copy()
    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    is_taken = np.zeros(img.shape[:2], bool)
    contours = []
    neighbours = [(-1, 0), (-1, -1), (0, -1), (1, -1),
                  (1, 0), (1, 1), (0, 1), (-1, 1)]
    # Column-major scan (x outer, y inner). is_taken changes during the
    # scan, so the visiting order affects which pixels are marked.
    for i in range(SLIC_width):
        for j in range(SLIC_height):
            nr_p = 0
            for dx, dy in neighbours:
                x, y = i + dx, j + dy
                if 0 <= x < SLIC_width and 0 <= y < SLIC_height:
                    if not is_taken[y, x] and SLIC_clusters[j, i] != SLIC_clusters[y, x]:
                        nr_p += 1
            if nr_p >= 2:
                is_taken[j, i] = True
                contours.append([j, i])
    for j, i in contours:
        rgb_img[j, i] = color
    io.imsave("SLIC_contours.jpg", rgb_img)
    return rgb_img
def display_center():
    """Paint every superpixel with the colour of its cluster centre.

    Saves the result to SLIC_centers.jpg and returns it as an 8-bit RGB array.
    """
    # Vectorised per-pixel lookup of the centre colour. A label of -1
    # (never assigned) indexes the LAST centre, as the per-pixel loop did.
    labels = SLIC_clusters.astype(int)
    lab_img = SLIC_centers[labels, 0:3].astype(np.float64)
    rgb_img = color.lab2rgb(lab_img)
    rgb_u8 = (rgb_img * 255).astype(np.uint8)
    # Save the same 8-bit array that is returned, avoiding imsave's lossy
    # float-to-uint8 conversion.
    io.imsave("SLIC_centers.jpg", rgb_u8)
    return rgb_u8
def find_local_minimum(center, labimg=None):
    """Return the lowest-gradient pixel in the 3x3 window around ``center``.

    The gradient is G(x,y) = ||I(x+1,y) - I(x-1,y)||^2 + ||I(x,y+1) - I(x,y-1)||^2
    (eq. 2 of the SLIC paper), computed on all three LAB channels with
    neighbour coordinates clamped to the image.

    :param center: (x, y) seed position, i.e. (column, row)
    :param labimg: H x W x 3 LAB image; defaults to the global SLIC_labimg
    :return: [x, y] of the minimum; the seed itself unless a neighbour is
        strictly lower
    """
    if labimg is None:
        labimg = SLIC_labimg
    height, width = labimg.shape[:2]
    cx, cy = int(center[0]), int(center[1])

    def gradient(x, y):
        x0, x1 = max(x - 1, 0), min(x + 1, width - 1)
        y0, y1 = max(y - 1, 0), min(y + 1, height - 1)
        dx = labimg[y, x1] - labimg[y, x0]
        dy = labimg[y1, x] - labimg[y0, x]
        return float(np.dot(dx, dx) + np.dot(dy, dy))

    # Start from the seed's own gradient so it only moves to a strictly lower
    # one. A fixed threshold such as 1 is below most LAB gradients and would
    # pin the seed in place.
    loc_min = [cx, cy]
    inside = 0 <= cx < width and 0 <= cy < height
    min_grad = gradient(cx, cy) if inside else np.inf
    for i in range(cx - 1, cx + 2):
        for j in range(cy - 1, cy + 2):
            if not (0 <= i < width and 0 <= j < height):
                continue
            grad = gradient(i, j)
            if grad < min_grad:
                min_grad = grad
                loc_min = [i, j]
    return loc_min
def calculate_centers():
    """Lay out the initial SLIC seeds on a grid of spacing ``step``.

    Grid points start at ``step`` and stop half a step before the image
    border. Each is moved to its local gradient minimum. Returns a list of
    [l, a, b, x, y] rows, ordered column by column.
    """
    half_step = int(step / 2)
    xs = range(step, SLIC_width - half_step, step)
    ys = range(step, SLIC_height - half_step, step)

    def seed_at(x, y):
        mx, my = find_local_minimum(center=(x, y))
        lab = SLIC_labimg[my, mx]
        return [lab[0], lab[1], lab[2], mx, my]

    return [seed_at(x, y) for x in xs for y in ys]
# global variables
if len(sys.argv) < 4:
    sys.exit("usage: {} IMAGE NUM_SUPERPIXELS COMPACTNESS_M".format(sys.argv[0]))
img = io.imread(sys.argv[1])
# rgb2lab and the final hstack both need a 3-channel RGB image: expand
# greyscale and drop an alpha channel (common in PNG files).
if img.ndim == 2:
    img = color.gray2rgb(img)
elif img.shape[2] == 4:
    img = (color.rgba2rgb(img) * 255).round().astype(np.uint8)
# Grid interval S; at least 1 so range(..., step) stays valid.
step = max(1, int((img.shape[0] * img.shape[1] / int(sys.argv[2])) ** 0.5))
SLIC_m = int(sys.argv[3])
SLIC_ITERATIONS = 4
SLIC_height, SLIC_width = img.shape[:2]
SLIC_labimg = color.rgb2lab(img)
SLIC_distances = np.full(img.shape[:2], np.inf)
# Per-pixel centre index; -1 means "not assigned yet".
SLIC_clusters = -1 * np.ones(img.shape[:2])
# Compute the seeds once (they were previously computed twice).
SLIC_centers = np.array(calculate_centers())
SLIC_center_counts = np.zeros(len(SLIC_centers))
if SLIC_centers.size == 0:
    sys.exit("no SLIC seed fits in the image; request more superpixels")

# main
generate_pixels()
img_contours = display_contours([0.0, 0.0, 0.0])
img_center = display_center()
result = np.hstack([img, img_contours, img_center])
io.imsave("my_slic.jpg", result)