数据结构-线段树

版权声明:本文为博主原创文章,转载请注明出处 https://blog.csdn.net/love905661433/article/details/83006902

线段树

特点

  • 线段树不是完全二叉树
  • 线段树是平衡二叉树

对于给定区间, 支持更新和查询操作 :

  • 更新 : 更新区间中的一个元素或者一个区间的值
  • 查询 : 查询一个区间[i, j]的最大值, 最小值, 或者区间数字和

使用数组构建线段树

如下图所示数组A, 以求和为例, 根节点A[0-7]存放的就是A[0-3]节点和A[4-7]节点之和, 下面的每个节点存放的值都是该节点对应左右孩子节点的和, 这样就用数组构建出了一个线段树,

1539129786062

  • 可以把线段树当成满二叉树进行处理
  • 对于有n个元素的区间, 数组只需要4n的空间就可以完全存储整颗线段树, 4n的空间会有部分浪费, 最坏的情况可能会有接近2n的空间被浪费
  • 不考虑添加元素

线段树区间查找

如下图所示, 线段树查找步骤如下:

  1. 在0-7的区间内查找2-5, 左右子树都包含部分, 所以在左侧查询2-3, 右侧查询4-5
  2. 继续在0-3的区间查找2-3, 在4-7的区间查找4-5
  3. 将查找到的2-3区间和4-5区间进行一次merge操作, 得到的就是2-5的区间
    1539217298633

线段树更新

线段树更新的方法也很简单, 更新对应位置的值之后, 包含该位置的区间的值也都要进行更新

线段树代码实现

线段树完整代码实现如下 :

package tree.segment;

/**
 * 使用数组实现线段树
 * @author 七夜雪
 *
 * @param <E>
 */
public class SegmentTree<E> {
	
	private Merger<E> merger;
	private E[] tree;
	private E[] data;
	
	@SuppressWarnings("unchecked")
	public SegmentTree (E[] arr, Merger<E> merger){
		this.merger = merger;
		// java中无法直接使用new E[arr.length];这种方式创建泛型数组
		data = (E[])new Object[arr.length];
		for (int i = 0; i < arr.length; i++) {
			data[i] = arr[i];
		}
		
		// 对于有n个元素的区间, 使用数组实现线段树的话, 需要4n的空间来存储
		tree = (E[])new Object[arr.length * 4];
		buildSegmentTree(0, 0, data.length - 1);
	}
	
	/**
	 * 在treeIndex的位置, 创建表示区间[l, r]的线段树
	 * 递归算法
	 * @param treeIndex
	 * @param l
	 * @param r
	 */
	private void buildSegmentTree(int treeIndex, int l, int r){
		// 递归到底的情况
		if (l == r) {
			tree[treeIndex] = data[l];
			return;
		}
		
		int leftTreeIndex = leftChild(treeIndex);
		int rightTreeIndex = rightChild(treeIndex);
		int mid = l + (r - l) / 2;
		buildSegmentTree(leftTreeIndex, l, mid);
		buildSegmentTree(rightTreeIndex, mid + 1, r);
		// 根据具体场景自定义merge方法
		tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
	}
	
	/**
	 * 计算index节点左孩子的位置
	 * @param index
	 * @return
	 */
	private int leftChild(int index){
		return 2 * index + 1;
	}
	
	/**
	 * 计算index节点左孩子的位置
	 * @param index
	 * @return
	 */
	private int rightChild(int index){
		return 2 * index + 2;
	}
	
	/**
	 * 查询QueryL~QueryR之间的区间
	 * @param queryL
	 * @param queryR
	 * @return
	 */
	public E query(int queryL, int queryR){
		if (queryL < 0 || queryL >=data.hashCode() || 
			queryR < 0 || queryR >= data.length ||
			queryL > queryR) {
			throw new IllegalArgumentException("无效的区间[" + queryL + ", " + queryR + "]");
		}
		
		return query(0, 0, data.length - 1 , queryL, queryR);
	}
	
	/**
	 * 从treeIndex节点开始, 在l~r的范围内查找QueryL~QueryR之间的区间
	 * @param treeIndex
	 * @param queryL
	 * @param queryR
	 * @return
	 */
	private E query(int treeIndex, int l, int r, int queryL, int queryR){
		// 递归终结条件, 左右边界相同时, 表示找到了对应的区间
		if (l == queryL && r == queryR) {
			return tree[treeIndex];
		}
		
		int mid = l + (r - l) / 2;
		int leftTreeIndex = leftChild(treeIndex);
		int rightTreeIndex = rightChild(treeIndex);
		// 要查找的区间右边界小于mid时, 说明只需要到左子树进行查找即可
		if (queryR <= mid) {
			return query(leftTreeIndex, l, mid, queryL, queryR);
		// 要查找的区间左边界大于mid时, 说明只需要到右子树进行查找即可
		} else if (queryL > mid){
			return query(rightTreeIndex, mid + 1, r, queryL, queryR);
		// queryL <=mid < queryR这种情况需要对左右子树分别进行查找
		} else { // queryL <=mid < queryR
			return merger.merge(query(leftTreeIndex, l, mid, queryL, mid), query(rightTreeIndex, mid + 1, r, mid + 1, queryR));
		}
	}
	
	/**
	 * 更新位置index的值
	 * @param index
	 * @param value
	 */
	public void set(int index, E value){
		 if(index < 0 || index >= data.length)
	            throw new IllegalArgumentException("下标越界");

	        data[index] = value;
	        set(0, 0, data.length - 1, index, value);
	}
	
	/**
	 * 在以treeIndex为根的线段树中更新index的值为e
	 * 递归算法
	 * @param treeIndex
	 * @param l
	 * @param r
	 * @param index
	 */
	private void set(int treeIndex, int l, int r, int index, E value){
		// 递归终止条件
		if (l == r) {
			tree[treeIndex] = value;
			return;
		}
		
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        int mid = l + (r - l) / 2;
		if (index <= mid) {
			set(leftTreeIndex, l, mid, index, value);
		} else { // index > mid
			set(rightTreeIndex, mid + 1, r, index, value);
		}
        
		// 因为所有包含index区间的值都要更新, 所以需要对treeIndex节点进行一次merge操作
		tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
	}
	
	// size
	public int getSize(){
		return data.length;
	}
	
	// get
	public E get(int index){
		if (index < 0 || index >=data.length) {
			throw new IllegalArgumentException("无效的位置 : " + index);
		}
		return data[index];
	}

	@Override
	public String toString() {
		StringBuilder res = new StringBuilder(); 
		res.append("SegmentTree [");
		for (int i = 0; i < tree.length; i++) {
			if (tree[i] != null) {
				res.append(tree[i]);
			} else {
				res.append("null");
			}
			
			if (i != tree.length -1) {
				res.append(", ");
			}
		}
		res.append("]");
		return res.toString();
	}
	
	

	
}

使用的merger融合器代码如下 :

package tree.segment;

/**
 * 融合器
 * 用于将两个元素融合成一个元素
 * 配合线段树的合并操作使用
 * @FunctionalInterface这个注解是jdk8中函数式接口声明, 加不加不影响
 * @author 七夜雪
 *
 */
@FunctionalInterface
public interface Merger<E> {
	E merge(E a, E b);
}

使用Junit进行简单测试的代码如下 :

package tree.segment;

import org.junit.Test ;

public class SegmentTreeTest {
	
	@Test
	public void testBuild(){
		Integer[] nums = {2, 3, 4, -1 , -2, 3};
		// jdk8的lambda表达式写法
		SegmentTree<Integer> segment = new SegmentTree<>(nums, (a, b) -> a + b);
		System.out.println(segment) ;
		System.out.println(segment.query(1, 3)) ;
	}
	
	@Test
	public void testBuildSet(){
		Integer[] nums = {2, 3, 4, -1 , -2, 3};
		// jdk8的lambda表达式写法
		SegmentTree<Integer> segment = new SegmentTree<>(nums, (a, b) -> a + b);
		System.out.println(segment) ;
		segment.set(3, 1);
		segment.set(4, 2);
		System.out.println(segment) ;
	}
		
}

猜你喜欢

转载自blog.csdn.net/love905661433/article/details/83006902