什么是线段树

线段树的概念

        线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。 
        对于线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a,(a+b)/2],右儿子表示的区间为[(a+b)/2+1,b]。因此线段树是平衡二叉树,最后的子节点数目为N,即整个线段区间的长度。

线段树的应用

        线段树 segmentTree 是一个二叉树,每个结点保存数组 nums 在区间 [left, right] 的最小值、最大值或者总和等信息。

线段树的实现

        线段树可以用树也可以用数组(堆式存储)来实现。对于数组实现,假设根结点的下标为 0,如果一个结点在数组的下标为 node,那么它的左子结点下标为 node×2+1,右子结点下标为 node×2+2。

我们来看一道题

给你一个数组 nums ,请你完成两类查询。

  1. 其中一类查询要求 更新 数组 nums 下标对应的值
  2. 另一类查询要求返回数组 nums 中索引 left 和索引 right 之间( 包含 )的nums元素的  ,其中 left <= right

建树build 函数

我们在结点 node 保存数组 nums 在区间 [left, right]的总和。

  • left = right 时,结点 node 是叶子结点,它保存的值等于 nums[left]。
  • left < right 时,结点 node 的左子结点保存区间 [left,\frac{left+right}{2}]的总和,右子结点保存区间 [\frac{left+right}{2}+1,right] 的总和,那么结点 node 保存的值等于它的两个子结点保存的值之和。

        假设 nums 的大小为 n,我们规定根结点node=0 保存区间 [0, n - 1] 的总和,然后自下而上递归地建树。

void build(int index,int left,int right,vector<int>& nums)
    {
        if(left == right)
        {
            segmentTree[index] = nums[left];
            return;
        }
        int mid = left + (right-left)/2;
        build(index*2+1,left,mid,nums);//左孩子结点
        build(index*2+2,mid+1,right,nums);//右孩子结点
        /*父节点的值为它的两个孩子结点的值的和*/
        segmentTree[index] = segmentTree[index*2+1]+segmentTree[index*2+2];
    }

单点修改 change 函数

        当我们要修改 nums[index] 的值时,我们找到对应区间[index,index] 的叶子结点,直接修改叶子结点的值为 val,并自下而上递归地更新父结点的值。

void change(int index,int val,int node,int left,int right)
    {
        if(left == right)//找到了叶子结点
        {
            segmentTree[node] = val;
            return ;
        }
        int mid = left + (right-left)/2;
        if(index <= mid)//在左孩子中寻找
        {
            change(index,val,node*2+1,left,mid);
        }
        else//在右孩子中寻找
        {
            change(index,val,node*2+2,mid+1,right);
        }
        /*更新父节点的值为它的两个子结点的值的和*/
        segmentTree[node] = segmentTree[node*2+1] + segmentTree[node*2+2];
    }

范围求和 range 函数

给定区间 [left,right] 时,我们将区间 [left,right] 拆成多个结点对应的区间。

  • 如果结点 node 对应的区间与 [left,right] 相同,可以直接返回该结点的值,即当前区间和。
  • 如果结点 node 对应的区间与 [left,right] 不同,设左子结点对应的区间的右端点为 m,那么将区间 [left,right] 沿点 m 拆成两个区间,分别计算左子结点和右子结点。

我们从根结点开始递归地拆分区间 [left,right]。

int range(int L,int R,int node,int left,int right)
    {
        if(left == L && right == R)
        {
            return segmentTree[node];
        }
        /*利用二分法来查找*/
        int mid = left+(right-left)/2;
        if(R <= mid)//说明在左孩子里
        {
            return range(L,R,node*2+1,left,mid);
        }
        else if(L > mid)//说明在右孩子里
        {
            return range(L,R,node*2+2,mid+1,right);
        }
        else{//这种情况是指 左右孩子结点中个包含一部分值
            return range(L,mid,node*2+1,left,mid) + range(mid+1,R,node*2+2,mid+1,right);
        }
    }

完整代码

class NumArray {
private:
    vector<int> segmentTree;
    int n;

    void build(int index,int left,int right,vector<int>& nums)
    {
        if(left == right)
        {
            segmentTree[index] = nums[left];
            return;
        }
        int mid = left + (right-left)/2;
        build(index*2+1,left,mid,nums);
        build(index*2+2,mid+1,right,nums);
        segmentTree[index] = segmentTree[index*2+1]+segmentTree[index*2+2];
    }

    void change(int index,int val,int node,int left,int right)
    {
        if(left == right)
        {
            segmentTree[node] = val;
            return ;
        }
        int mid = left + (right-left)/2;
        if(index <= mid)
        {
            change(index,val,node*2+1,left,mid);
        }
        else
        {
            change(index,val,node*2+2,mid+1,right);
        }
        segmentTree[node] = segmentTree[node*2+1] + segmentTree[node*2+2];
    }

    int range(int L,int R,int node,int left,int right)
    {
        if(left == L && right == R)
        {
            return segmentTree[node];
        }
        int mid = left+(right-left)/2;
        if(R <= mid)
        {
            return range(L,R,node*2+1,left,mid);
        }
        else if(L > mid)
        {
            return range(L,R,node*2+2,mid+1,right);
        }
        else{
            return range(L,mid,node*2+1,left,mid) + range(mid+1,R,node*2+2,mid+1,right);
        }
    }
public:
    NumArray(vector<int>& nums):n(nums.size()),segmentTree(nums.size()*4)
    {
        build(0,0,n-1,nums);
    }
    
    void update(int index, int val) {
        change(index,val,0,0,n-1);
    }
    
    int sumRange(int left, int right) {
        return range(left,right,0,0,n-1);
    }
};

时间复杂度 O(log n)

猜你喜欢

转载自blog.csdn.net/ThinPikachu/article/details/123955348