最优雅的数据结构之一——并查集DisjointSetUnion(基于Java语言实现)

介绍:

作用:

       一种树型的数据结构;

       用于处理一些不相交的集合的合并及查询

 

主要操作:

       合并(Union):把两个不相交的集合合并为一个集合;

       查询(Find):查询两个元素是否在同一个集合中。

 

实现方式:

1、数组实现(用于Quick Find

       - 查找时间复杂度O(1);

       - 合并时间复杂度O(n)。

2、数组+单链表实现(Quick Union)

       将每一个元素,看做是一个节点;将相同集合的节点串成一条链表。   

       - 初始化:把每个节点所在集合初始化为其自身(节点前置指针指向自己)——O(n);

       - 合并:O(1);

       - 查询:O(length),length是链表的长度。

3、数组+森林(本文将采取的实现方式)

       用有根树来表示集合:每棵树表示一个集合,树中的节点对应一个元素。

 

按秩(Rank)合并

       - 用于解决不断进行Union操作可能会导致森林退化成链表的情况,从而影响Find性能。

       - 用一个数组rank[]记录每个根节点对应的树的深度(如果不是根节点,其rank相当于以它作为根节点的子树的深度)。一开始,把所有元素的rank(秩)设为1。合并时比较两个根节点,把rank较小者往较大者上合并

 

路径压缩(Path Compression)

      在合并操作的查询过程中,把沿途的每个节点的父节点都设为根节点(递归实现,使并查集尽可能不是一定)是一个菊花图(只有两层的树的俗称))。


注意:

       路径压缩和按秩合并如果一起使用,查询和合并操作的时间复杂度都接近O(rank)(rank是森林的深度,严格意义上:O(log*n)→iterated logarithm),但是很可能会破坏rank的准确性(即数组中每个元素作为根节点的森林的深度可能不在准确,但每个集合的森林深度是准确的,并不影响并查集的使用)。

扩展:

 

       O(log*n)近乎是O(1)级别的(没O(1)快),但比O(log n)快很多。

 

实现:

DisjointSetUnion.java

package disjointsetunion;

public interface DisjointSetUnion<E> {
    void union(int index1,int index2);
    int find(int index);
}

DisjointSetUnionImpl.java

package disjointsetunion.impl;

import disjointsetunion.DisjointSetUnion;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * 并查集
 *
 * @param <E> 并查集中元素类型
 */
public class DisjointSetUnionImpl<E> implements DisjointSetUnion<E> {
    /**
     * 每个元素代表的Node节点
     */
    List<Node<E>> parent;
    /**
     * 每个Node节点的秩,即每个结点以自己为根节点的森林的深度
     * (每个元素的秩并不是准确的)
     */
    List<Integer> rank;

    public DisjointSetUnionImpl() {
        parent = new ArrayList<>();
        rank = new ArrayList<>();
    }

    public DisjointSetUnionImpl(List<List<E>> sets) {
        init(sets);
    }

    public boolean add(E e) {
        rank.add(1);
        return parent.add(new Node<>(parent.size(), e));
    }

    public Node<E> get(int index) {
        return parent.get(index);
    }

    public int indexOf(E element) {
        return parent.indexOf(new Node<>(0, element));
    }

    public int lastIndexOf(E element) {
        return parent.lastIndexOf(new Node<>(0, element));
    }


    public void clear() {
        parent.clear();
        rank.clear();
    }

    /**
     * 批量初始化方法
     */
    public void init(List<List<E>> sets) {
        parent = new ArrayList<>();
        rank = new ArrayList<>();
        for (List<E> set : sets) {
            if (set == null)
                continue;
            if (set.isEmpty())
                continue;
            Node<E> first = new Node<>(parent.size(), set.remove(0));
            parent.add(first);
            int root = rank.size();
            rank.add(1);
            for (E element : set) {
                parent.add(new Node<>(parent.size(), first, element));
                rank.add(1);
                rank.set(root, rank.get(root) + 1);
            }
        }
    }

    /**
     * 合并index1索引对应的Node结点的集合和index2索引对应的Node结点的集合
     */
    @Override
    public void union(int index1, int index2) {
        int rootIndex1 = find(index1);
        int rootIndex2 = find(index2);
        if (rootIndex1 == rootIndex2)
            return;
        /*
         * 按秩合并
         */
        int root1 = rank.get(rootIndex1);
        int root2 = rank.get(rootIndex2);
        if (root1 <= root2)
            parent.get(rootIndex1).prev = parent.get(rootIndex2);
        else
            parent.get(rootIndex2).prev = parent.get(rootIndex1);
        if (root1 == root2)
            rank.set(rootIndex2, rank.get(rootIndex2) + 1);
    }

    /**
     * 查询index1索引对应的Node结点所在的集合的根节点的索引值
     *
     * @return index1索引对应的Node结点所在的集合的根节点的索引值
     */
    @Override
    public int find(int index) {
        rangeCheck(index);
        Node<E> node = parent.get(index);
        /*
         * 路径压缩
         *
         */
        if (node != node.prev)
            node.prev = parent.get(find(node.prev.id));
        return node.prev.id;
    }

    /**
     * 判断两个元素是否在同一集合中
     *
     * @return true表示两个元素在同一个集合中
     */
    public boolean isConnected(int index1, int index2) {

        return find(index1) == find(index2);
    }

    @Override
    public String toString() {
        return parent.toString() + "\n" + rank.toString();
    }

    private void rangeCheck(int index) {
        int size = parent.size();
        if (index >= size)
            throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
    }

    /**
     * 封装并查集中每个元素的Node结点
     *
     * @param <E> 元素类型
     */
    private static class Node<E> {
        int id;
        E item;
        Node<E> prev;

        Node(int id, Node<E> prev, E element) {
            this.id = id;
            this.item = element;
            this.prev = prev;
        }

        Node(int id, E element) {
            this.id = id;
            this.item = element;
            this.prev = this;
        }

        @Override
        public String toString() {
            return "Node{" +
                    "id=" + id +
                    ", item=" + item +
                    ", prev→" + prev.id +
                    '}';
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Node<?> node = (Node<?>) o;
            return Objects.equals(item, node.item);
        }

    }
}

测试:

Test.java

package disjointsetunion;

import disjointsetunion.impl.DisjointSetUnionImpl;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class Test {
    public static void main(String[] args) {
        ArrayList<List<String>> sets = new ArrayList<>();
        sets.add(Stream.of("1", "2", "3").collect(Collectors.toList()));
        sets.add(Stream.of("4", "5", "6").collect(Collectors.toList()));
        DisjointSetUnionImpl<String> disjointSetUnion = new DisjointSetUnionImpl<>(sets);
        System.out.println(disjointSetUnion);
        System.out.println(disjointSetUnion.find(0));
        disjointSetUnion.union(1, 5);
        System.out.println(disjointSetUnion);
        System.out.println(disjointSetUnion.find(2));
        System.out.println(disjointSetUnion);
        System.out.println(disjointSetUnion.isConnected(0, 1));
        System.out.println(disjointSetUnion.isConnected(0, 5));
    }
}

结果:

猜你喜欢

转载自blog.csdn.net/qq_40100414/article/details/118087230