介绍:
作用:
一种树型的数据结构;
用于处理一些不相交的集合的合并及查询。
主要操作:
合并(Union):把两个不相交的集合合并为一个集合;
查询(Find):查询两个元素是否在同一个集合中。
实现方式:
1、数组实现(用于Quick Find)
- 查找时间复杂度O(1);
- 合并时间复杂度O(n)。
2、数组+单链表实现(Quick Union)
将每一个元素,看做是一个节点;将相同集合的节点串成一条链表。
- 初始化:把每个节点所在集合初始化为其自身(节点前置指针指向自己)——O(n);
- 合并:O(1);
- 查询:O(length),length是链表的长度。
3、数组+森林(本文将采取的实现方式)
用有根树来表示集合:每棵树表示一个集合,树中的节点对应一个元素。
按秩(Rank)合并
- 用于解决不断进行Union操作可能会导致森林退化成链表的情况,从而影响Find性能。
- 用一个数组rank[]记录每个根节点对应的树的深度(如果不是根节点,其rank相当于以它作为根节点的子树的深度)。一开始,把所有元素的rank(秩)设为1。合并时比较两个根节点,把rank较小者往较大者上合并。
路径压缩(Path Compression)
在合并操作的查询过程中,把沿途的每个节点的父节点都设为根节点(递归实现,使并查集尽可能(不是一定)是一个菊花图(只有两层的树的俗称))。
注意:
路径压缩和按秩合并如果一起使用,查询和合并操作的时间复杂度都接近O(rank)(rank是森林的深度,严格意义上:O(log*n)→iterated logarithm),但是很可能会破坏rank的准确性(即数组中每个元素作为根节点的森林的深度可能不在准确,但每个集合的森林深度是准确的,并不影响并查集的使用)。
扩展:
O(log*n)近乎是O(1)级别的(没O(1)快),但比O(log n)快很多。
实现:
DisjointSetUnion.java
package disjointsetunion;
public interface DisjointSetUnion<E> {
void union(int index1,int index2);
int find(int index);
}
DisjointSetUnionImpl.java
package disjointsetunion.impl;
import disjointsetunion.DisjointSetUnion;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
/**
* 并查集
*
* @param <E> 并查集中元素类型
*/
public class DisjointSetUnionImpl<E> implements DisjointSetUnion<E> {
/**
* 每个元素代表的Node节点
*/
List<Node<E>> parent;
/**
* 每个Node节点的秩,即每个结点以自己为根节点的森林的深度
* (每个元素的秩并不是准确的)
*/
List<Integer> rank;
public DisjointSetUnionImpl() {
parent = new ArrayList<>();
rank = new ArrayList<>();
}
public DisjointSetUnionImpl(List<List<E>> sets) {
init(sets);
}
public boolean add(E e) {
rank.add(1);
return parent.add(new Node<>(parent.size(), e));
}
public Node<E> get(int index) {
return parent.get(index);
}
public int indexOf(E element) {
return parent.indexOf(new Node<>(0, element));
}
public int lastIndexOf(E element) {
return parent.lastIndexOf(new Node<>(0, element));
}
public void clear() {
parent.clear();
rank.clear();
}
/**
* 批量初始化方法
*/
public void init(List<List<E>> sets) {
parent = new ArrayList<>();
rank = new ArrayList<>();
for (List<E> set : sets) {
if (set == null)
continue;
if (set.isEmpty())
continue;
Node<E> first = new Node<>(parent.size(), set.remove(0));
parent.add(first);
int root = rank.size();
rank.add(1);
for (E element : set) {
parent.add(new Node<>(parent.size(), first, element));
rank.add(1);
rank.set(root, rank.get(root) + 1);
}
}
}
/**
* 合并index1索引对应的Node结点的集合和index2索引对应的Node结点的集合
*/
@Override
public void union(int index1, int index2) {
int rootIndex1 = find(index1);
int rootIndex2 = find(index2);
if (rootIndex1 == rootIndex2)
return;
/*
* 按秩合并
*/
int root1 = rank.get(rootIndex1);
int root2 = rank.get(rootIndex2);
if (root1 <= root2)
parent.get(rootIndex1).prev = parent.get(rootIndex2);
else
parent.get(rootIndex2).prev = parent.get(rootIndex1);
if (root1 == root2)
rank.set(rootIndex2, rank.get(rootIndex2) + 1);
}
/**
* 查询index1索引对应的Node结点所在的集合的根节点的索引值
*
* @return index1索引对应的Node结点所在的集合的根节点的索引值
*/
@Override
public int find(int index) {
rangeCheck(index);
Node<E> node = parent.get(index);
/*
* 路径压缩
*
*/
if (node != node.prev)
node.prev = parent.get(find(node.prev.id));
return node.prev.id;
}
/**
* 判断两个元素是否在同一集合中
*
* @return true表示两个元素在同一个集合中
*/
public boolean isConnected(int index1, int index2) {
return find(index1) == find(index2);
}
@Override
public String toString() {
return parent.toString() + "\n" + rank.toString();
}
private void rangeCheck(int index) {
int size = parent.size();
if (index >= size)
throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
}
/**
* 封装并查集中每个元素的Node结点
*
* @param <E> 元素类型
*/
private static class Node<E> {
int id;
E item;
Node<E> prev;
Node(int id, Node<E> prev, E element) {
this.id = id;
this.item = element;
this.prev = prev;
}
Node(int id, E element) {
this.id = id;
this.item = element;
this.prev = this;
}
@Override
public String toString() {
return "Node{" +
"id=" + id +
", item=" + item +
", prev→" + prev.id +
'}';
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Node<?> node = (Node<?>) o;
return Objects.equals(item, node.item);
}
}
}
测试:
Test.java
package disjointsetunion;
import disjointsetunion.impl.DisjointSetUnionImpl;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class Test {
public static void main(String[] args) {
ArrayList<List<String>> sets = new ArrayList<>();
sets.add(Stream.of("1", "2", "3").collect(Collectors.toList()));
sets.add(Stream.of("4", "5", "6").collect(Collectors.toList()));
DisjointSetUnionImpl<String> disjointSetUnion = new DisjointSetUnionImpl<>(sets);
System.out.println(disjointSetUnion);
System.out.println(disjointSetUnion.find(0));
disjointSetUnion.union(1, 5);
System.out.println(disjointSetUnion);
System.out.println(disjointSetUnion.find(2));
System.out.println(disjointSetUnion);
System.out.println(disjointSetUnion.isConnected(0, 1));
System.out.println(disjointSetUnion.isConnected(0, 5));
}
}