基于图的图像分割

一、前言

最近一段时间在复现基于区域的对比度方法(region-based contrast 简称RC)的显著性检测。【原文请点击】 其中,遇到了问题。主要是用到了基于图的图像分割。
显著性检测RC算法:其前期工作就是利用Graph-Based Image Segmentation的分割算法。主要涉及了图网络的一些知识。【关于图网络请点击】

Graph-Based Image Segmentation是2004年由Felzenszwalb发表在IJCV上的一篇文章,主要介绍了一种基于图表示(graph-based)的图像分割方法。图像分割(Image Segmentation)的主要目的也就是将图像(image)分割成若干个特定的、具有独特性质的区域(region),然后从中提取出感兴趣的目标(object)。而图像区域之间的边界定义是图像分割算法的关键,论文给出了一种在图表示(graph-based)下图像区域之间边界的定义的判断标准(predicate),其分割算法就是利用这个判断标准(predicate)使用贪心选择(greedy decision)来产生分割(segmentation)。
该算法在时间效率上,基本上与图像(Image)的图(Graph)表示的边(edge)数量成线性关系,而图像的图表示的边与像素点成正比,也就说图像分割的时间效率与图像的像素点个数成线性关系。这个算法有一个非常重要的特性,它能保持低变化(low-variability)区域(region)的细节,同时能够忽略高变化(high-variability)区域(region)的细节。这个性质很特别也很重要,对图像有一个很好的分割效果(能够找出视觉上一致的区域,简单讲就是高变化区域有一个很好聚合(grouping),能够把它们分在同一个区域),这也是为什么那么多人引用该论文的原因吧。
【论文请点击】

无论在分割领域还是显著性检测上,都是能够捕捉视觉上重要的区域(perceptually important regions)。举个栗子:在下图1左侧有个红三角,下图2左侧有个更大的红三角,我们可以认为图2的红三角更显眼(更加的靠左侧),
在这里插入图片描述

二、算法理论

该论文主要有两个关键点:

  1. 图像(image)的图(graph)表示;
  2. 最小生成树(Minimun Spanning Tree)。
2.1 构建图

图像(image)的图表示是指将图像(image)表达成图论中的图(graph)。具体说来就是,把图像中的每一个像素点看成一个顶点 v i V v_i \in V (node或vertex),像素点之间的关系对(可以自己定义其具体关系,一般来说是指相邻关系)构成图的一条边 e i E e_i \in E ,这样就构建好了一个图 G = ( V , E ) G = (V,E)
相邻的两个像素点像素颜色值的差异构成边 ( v i , v j ) (v_i, v_j) 的权值 w ( v i , v j ) w(v_i, v_j) 。其中权值越小,表示像素点之间的相似度就越高,反之,相似度就越低。图每条边的权值是基于像素点之间的关系,可以是像素点之间的灰度值差,也可以是像素点(RGB)之间的距离:
灰度值素点之间距离: g r e y [ x 1 , y 1 ] g r e y [ x 2 , y 2 ] ∣\mathrm{grey} [x_1,y_1]−\mathrm{grey}[x_2,y_2]∣
像素点(RGB)之间的距离:
i m g G = ( i m g [ x 1 , y 1 , 0 ] i m g [ x 2 , y 2 , 0 ] ) 2 i m g B = ( i m g [ x 1 , y 1 , 1 ] i m g [ x 2 , y 2 , 1 ] ) 2 i m g R = ( i m g [ x 1 , y 1 , 2 ] i m g [ x 2 , y 2 , 2 ] ) 2 . \begin{matrix} \mathrm{imgG} = & (\mathrm{img}[x_1,y_1,0]−\mathrm{img}[x_2,y_2,0])^2 \\ \mathrm{imgB} = & (\mathrm{img}[x_1,y_1,1]−\mathrm{img}[x_2,y_2,1])^2\\ \mathrm{imgR} = & (\mathrm{img}[x_1,y_1,2]−\mathrm{img}[x_2,y_2,2])^2 \end{matrix}.

d i s t = ( i m g G + i m g B + i m g R ) dist = \sqrt(\mathrm{imgG}+\mathrm{imgB}+\mathrm{imgR})

2.2 分割图

将图像表达成图之后,接下来就是要如何分割这个图。将每个节点(像素点)看成单一的区域,然后进行合并。使用最小生成树方法合并像素点,然后构成一个个区域。大致意思就是讲图(Graph)简化,相似的区域在一个分支(Branch)上面(有一条最边连接),大大减少了图的边数。
图(Graph)分割是将 G = ( V , E ) G = (V,E) 分割成一系列不相交的部分 (component)C,每个C都构成一个子图G。这些子图的之间相互独立 (disjoint),主要是指它们之间没有公共的点。

2.3 算法的实现
  1. 分割区域(Component)的内部差(internal difference)。可以先假定图G已经简化成了最小生成树 MST,一个分割区域C 包含若干个顶点 ,顶点之间通过最小生成树的边连接。这个内部差就是指分割区域C中包含的最大边的权值。
    最大边的权值
  2. 分割区域(Component)之间的差别(Difference),是指两个分割区域之间顶点相互连接的最小边的权值。
    difference between two components
    如果两个分割部分之间没有边连接,定义 D i f ( C 1 , C 2 ) = Dif(C1,C2) = ∞ 。分割区域的差别可以有很多种定义的方式,可以选择中位置,或者其他的分位点(quantile,中位置是0.5分位点),但是选取其他的方式将会使得这个问题成为一个NP-hard问题。
  3. 分割区域(Component)边界的一种判断标准(predicate)。判断两个分割区域之间是否有明显的边界,主要是判断两个分割部分之间的差别Dif相对于和中较小的那个值MInt的大小,这里引入了一个阈值函数τ 来控制两者之间的差值。下面给出这个判断标准的定义:
    在这里插入图片描述
    其中,是指最小的分割内部差,其定义如下:
    minimum internal difference
    阈值函数 τ \tau 主要是为了更好的控制分割区域边界的定义。比较直观的理解,小分割区域的边界定义要强于大分割区域,否则可以将小分割区域继续合并形成大区域。在这里给出的阈值函数与区域的大小有关。
    在这里插入图片描述
    |C|是指分割部分顶点的个数(或者像素点个数),k是一个参数,可以根据不同的需求(主要根据图像的尺寸)进行调节。
2.4 几个分割概念
  1. 如果一个分割S,存在图(Graph)的分割区域之间,没有明显的边界,那么就说这个分割S“太精细”(too fine)。也就是说它们之间没有明显的分界线,硬要把它们分割开来的话,有点过头,也就是说分得太细。
  2. 如果一个分割S,存在一个合适的调整(refinement)S’使得S不是”太精细“,那么就说这个分割S”太粗糙“(too coarse)。简单来讲就是,分割程度的还不够(粒度还比较大),可以继续分割,这样刚开始的那个分割就是”太粗糙“(too coarse)了。

对于一个图graph来说,都存在一个分割S,既不是”太精细“(too fine)也不是”太粗糙“(too coarse)。

2.5 算法步骤
  1. 对于图G的所有边,按照权值进行排序(升序)
  2. S[0]是一个原始分割,相当于每个顶点当做是一个分割区域
  3. q = 1,2,…,m 重复3的操作(m为边的条数,也就是每次处理一条边)
  4. 根据上次 S [ q 1 ] S[q-1] 的构建。选择一条边o[q](vi,vj),如果vi和vj在分割的互不相交的区域中,比较这条边的权值与这两个分割区域之间的最小分割内部差MInt,如果o[q](vi,vj) < MInt,那么合并这两个区域,其他区域不变;如果否,什么都不做。
  5. 最后得到的就是所求的分割 S = S[m]

三、代码实现

C++实现代码请查看:http://cs.brown.edu/people/pfelzens/segment/

class Node:
    def __init__(self, parent, rank=0, size=1):
        self.parent = parent
        self.rank = rank
        self.size = size

    def __repr__(self):
        return '(parent=%s, rank=%s, size=%s)' % (self.parent, self.rank, self.size)

class Forest:
    def __init__(self, num_nodes):
        self.nodes = [Node(i) for i in range(num_nodes)]
        self.num_sets = num_nodes

    def size_of(self, i):
        return self.nodes[i].size

    def find(self, n):
        temp = n
        while temp != self.nodes[temp].parent:
            temp = self.nodes[temp].parent

        self.nodes[n].parent = temp
        return temp

    def merge(self, a, b):
        if self.nodes[a].rank > self.nodes[b].rank:
            self.nodes[b].parent = a
            self.nodes[a].size = self.nodes[a].size + self.nodes[b].size
        else:
            self.nodes[a].parent = b
            self.nodes[b].size = self.nodes[b].size + self.nodes[a].size

            if self.nodes[a].rank == self.nodes[b].rank:
                self.nodes[b].rank = self.nodes[b].rank + 1

        self.num_sets = self.num_sets - 1

    def print_nodes(self):
        for node in self.nodes:
            print(node)

def create_edge(img, width, x, y, x1, y1, diff):
    vertex_id = lambda x, y: y * width + x
    w = diff(img, x, y, x1, y1)
    return (vertex_id(x, y), vertex_id(x1, y1), w)

def build_graph(img, width, height, diff, neighborhood_8=False):
    graph_edges = []
    for y in range(height):
        for x in range(width):
            if x > 0:
                graph_edges.append(create_edge(img, width, x, y, x-1, y, diff))
            if y > 0:
                graph_edges.append(create_edge(img, width, x, y, x, y-1, diff))
            if neighborhood_8:
                if x > 0 and y > 0:
                    graph_edges.append(create_edge(img, width, x, y, x-1, y-1, diff))
                if x > 0 and y < height-1:
                    graph_edges.append(create_edge(img, width, x, y, x-1, y+1, diff))
    return graph_edges

def remove_small_components(forest, graph, min_size):
    for edge in graph:
        a = forest.find(edge[0])
        b = forest.find(edge[1])

        if a != b and (forest.size_of(a) < min_size or forest.size_of(b) < min_size):
            forest.merge(a, b)
    return  forest

# segment_graph(graph_edges, size[0]*size[1], K, min_comp_size, threshold)
def segment_graph(graph_edges, num_nodes, const, min_size, threshold_func):
    # Step 1: initialization
    # [(parent,rank,size) for i in range(num_nodes)]
    forest = Forest(num_nodes)

    weight = lambda edge: edge[2]
    sorted_graph = sorted(graph_edges, key=weight)
    threshold = [ threshold_func(1, const) for _ in range(num_nodes) ]

    # Step 2: merging
    for edge in sorted_graph:
        parent_a = forest.find(edge[0])
        parent_b = forest.find(edge[1])
        a_condition = weight(edge) <= threshold[parent_a]
        b_condition = weight(edge) <= threshold[parent_b]

        if parent_a != parent_b and a_condition and b_condition:
            forest.merge(parent_a, parent_b)
            a = forest.find(parent_a)
            threshold[a] = weight(edge) + threshold_func(forest.nodes[a].size, const)
    return remove_small_components(forest, sorted_graph, min_size)
numpy实现

主要使用numpy数组,去除了类的使用

def segment_graph(height_width, num, edges, c=20.0, min_size=200):
    u_array = np.zeros((height_width, 3), dtype=np.int32)
    u_array[:, 1] = np.array(range(height_width), dtype=np.int32)
    u_array[:, 2] = np.ones(height_width, dtype=np.int32)
    thresholds_copy = np.full(height_width,c,dtype=np.float32)
    loop_range = range(num)

    for i in loop_range:
        edge = edges[i]
        a = edge['a']
        while a!=u_array[a,1]:
            a =edge['a']= u_array[a, 1]
        b = edge['b']
        while b!=u_array[b,1]:
            b =edge['b']= u_array[b, 1]
        if a != b:
            if edge['w'] <= thresholds_copy[a] and edge['w'] <= thresholds_copy[b]:
                if (u_array[a, 0] > u_array[b, 0]):
                    u_array[b, 1] = a
                    u_array[a, 2] += u_array[b, 2]
                else:
                    u_array[a, 1] = b
                    u_array[b, 2] += u_array[a, 2]
                    if u_array[a, 0] == u_array[b, 0]:
                        u_array[b, 0] += 1
                while a != u_array[edge['a'], 1]:
                    a = edge['a'] = u_array[edge['a'], 1]
                thresholds_copy[a] = edge['w'] + c/u_array[a,2]
    for i in loop_range:
        while (edges[i]['a'] != u_array[edges[i]['a'],1]):
            edges[i]['a'] = u_array[edges[i]['a'],1]
        while (edges[i]['b'] != u_array[edges[i]['b'],1]):
            edges[i]['b'] = u_array[edges[i]['b'],1]
        if ((edges[i]['a'] != edges[i]['b']) and ((u_array[edges[i]['a'],2] < min_size) or (u_array[edges[i]['b'],2] < min_size))):
            if (u_array[edges[i]['a'], 0] > u_array[edges[i]['b'], 0]):
                u_array[edges[i]['b'], 1] = edges[i]['a']
                u_array[edges[i]['a'], 2] += u_array[edges[i]['b'], 2]
            else:
                u_array[edges[i]['a'], 1] = edges[i]['b']
                u_array[edges[i]['b'], 2] += u_array[edges[i]['a'], 2]
                if u_array[edges[i]['a'], 0] == u_array[edges[i]['b'], 0]:
                    u_array[edges[i]['b'], 0] += 1
    return u_array


# ===========================SegmentImage==========================================
# 像素间的差异度量
def diff(img3f, x1, y1, x2, y2):
    p1 = img3f[y1, x1]
    p2 = img3f[y2, x2]
    return np.sqrt(np.sum(np.power(p1 - p2, 2)))


def SegmentImage(smImg3f, c=20.0, min_size=200):
    height, width = smImg3f.shape[:2]
    edges = np.zeros((height-1)*(width-1)*4+(height-1)+(width-1),
                     dtype={'names': ['a', 'b','w'],'formats': ['i4', 'i4','f4']})
    num = 0
    width_range = range(width)
    height_range = range(height)
    for y in height_range:
        for x in width_range:
            if x < width - 1:
                edges[num]['a'] = y * width + x
                edges[num]['b'] = y * width + (x + 1)
                edges[num]['w'] = diff(smImg3f, x, y, x + 1, y)
                num += 1
            if y < height - 1:
                edges[num]['a'] = y * width + x
                edges[num]['b'] = (y + 1) * width + x
                edges[num]['w'] = diff(smImg3f, x, y, x, y + 1)
                num += 1
            if (x < (width - 1)) and (y < (height - 1)):
                edges[num]['a'] = y * width + x
                edges[num]['b'] = (y + 1) * width + (x + 1)
                edges[num]['w'] = diff(smImg3f, x, y, x + 1, y + 1)
                num += 1
            if (x < (width - 1)) and y > 0:
                edges[num]['a'] = y * width + x
                edges[num]['b'] = (y - 1) * width + (x + 1)
                edges[num]['w'] = diff(smImg3f, x, y, x + 1, y - 1)
                num += 1
    edges = np.sort(edges, order='w')
    u_array = segment_graph(width * height, num, edges, c=20.0, min_size=200)
    marker = {}
    imgIdx = np.zeros((smImg3f.shape[0], smImg3f.shape[1]), np.int32)
    idxNum = 0
    for y in height_range:
        for x in width_range:
            comp = y * width + x
            while (comp != u_array[comp, 1]):
                comp = u_array[comp, 1]
            if comp not in marker.keys():
                marker[comp] = idxNum
                idxNum += 1
            idx = marker[comp]
            imgIdx[y, x] = idx
    return idxNum, imgIdx

cython实现
主要用来加速该算法运算时间
【关于详情请点击查看】

特别鸣谢:
https://blog.csdn.net/u014796085/article/details/83449972
https://blog.csdn.net/surgewong/article/details/39008861
C++:https://blog.csdn.net/ttransposition/article/details/38024605

发布了386 篇原创文章 · 获赞 592 · 访问量 72万+

猜你喜欢

转载自blog.csdn.net/wsp_1138886114/article/details/103368278
今日推荐