【模板】线段树

线段树是一种神奇的数据结构呀~~~
由于线段树在OI中的运用十分灵活,没有固定性的模板,这里就给出能够完成以下操作的线段树:
1.给一段区间加上一个值
2.询问一个区间内数值的总和
很水对吧,所以这才叫模板。。。。
l a z y t a g 是个好东西,要养成写 l a z y t a g 的好习惯。

先码上一个链表版本的线段树:

链表线段树很简单(不是假的),像我这样的蒟蒻就是因为静态数组线段树写炸了才写链表线段树的。

外模板

struct SegTree {
    int l, r;   //区间左右端点
    Segtree *lc, *rc;  //指向左子树和右子树的指针
    long long sum, add;  //用于维护区间和

};

几个构造函数(方便函数使用)

    SegTree(int l, int r, SegTree *lc, SegTree *rc)
        : l(l), r(r), lc(lc), rc(rc), sum(lc->sum + rc->sum), add(0) {} //非子结点的构造函数,自带pushup()
    SegTree(int l, int r, SegTree *lc, SegTree *rc, long long sum)
        : l(l), r(r), lc(lc), rc(rc), sum(sum), add(0) {} //子结点构造函数

build函数

    static SegTree *build(int l, int r) { //静态成员函数
        if (l > r) return NULL;
        else if (l == r) return new SegTree(l, r, NULL, NULL, a[l]); //叶子节点构造
        else {
            int mid = (l + r) >> 1;
            return new SegTree(l, r, build(l, mid), build(mid + 1, r)); //非叶子节点递归构造
        }
    }

updata

    void update(int l, int r, long long rhs) {
        if (l > this->r || r < this->l) return; //若当前区间和要处理的区间没有交集,就return
        else if (l <= this->l && this->r <= r) cover(rhs); //当前区间和要处理的区间重合,直接覆盖标记
        else {
            pushdown();
            lc->update(l, r, rhs); //递归
            rc->update(l, r, rhs);
            pushup();
        }
        return;
    }

query

    long long query(int l, int r) {
        if (l > this->r || r < this->l) return 0;
        if (l <= this->l && this->r <= r) return sum;
        pushdown();
        return lc->query(l, r) + rc->query(l, r); //递归统计
    }

cover覆盖标记

    void cover(long long rhs) {
        add += rhs;
        sum += rhs * (r - l + 1); 
    }

pushdown标记下传

    void pushdown() {
        lc->cover(add);
        rc->cover(add);
        add = 0;
    }

pushup

    void pushup() {
        sum = lc->sum + rc->sum; 
    }

完整代码

struct SegTree {
    int l, r;
    SegTree *lc, *rc;
    long long sum, add;
    SegTree(int l, int r, SegTree *lc, SegTree *rc)
        : l(l), r(r), lc(lc), rc(rc), sum(lc->sum + rc->sum), add(0) {}
    SegTree(int l, int r, SegTree *lc, SegTree *rc, long long sum)
        : l(l), r(r), lc(lc), rc(rc), sum(sum), add(0) {}
    void cover(long long rhs) {
        add += rhs;
        sum += rhs * (r - l + 1); 
    }
    void pushup() {
        sum = lc->sum + rc->sum; 
    }
    void pushdown() {
        lc->cover(add);
        rc->cover(add);
        add = 0;
    }
    void update(int l, int r, long long rhs) {
        if (l > this->r || r < this->l) return;
        else if (l <= this->l && this->r <= r) cover(rhs);
        else {
            pushdown();
            lc->update(l, r, rhs);
            rc->update(l, r, rhs);
            pushup();
        }
        return;
    }
    long long query(int l, int r) {
        if (l > this->r || r < this->l) return 0;
        if (l <= this->l && this->r <= r) return sum;
        pushdown();
        return lc->query(l, r) + rc->query(l, r);
    }
    static SegTree *build(int l, int r) {
        if (l > r) return NULL;
        else if (l == r) return new SegTree(l, r, NULL, NULL, a[l]);
        else {
            int mid = (l + r) >> 1;
            return new SegTree(l, r, build(l, mid), build(mid + 1, r));
        }
    }
};

相似的,我直接给出不带注释的数组版本的线段树(左移2位千万不要忘!!!)。

struct SegTree {
    static const int MAXSIZE = 100000 + 5;
    long long sum[MAXSIZE << 2], add[MAXSIZE << 2];
    void cover(int o, int l, int r, long long rhs) {
        add[o] += rhs;
        sum[o] += rhs * (r - l + 1);
    }
    void pushdown(int o, int l, int r) {
        int mid = (l + r) >> 1;
        cover(o << 1, l, mid, add[o]);
        cover(o << 1 | 1, mid + 1, r, add[o]);
        add[o] = 0;
    }
    void pushup(int o) {
        sum[o] = sum[o << 1] + sum[o << 1 | 1];
    }
    void update(int o, int l, int r, int L, int R, long long rhs) {
        if (L > r || R < l) return;
        else if (L <= l && r <= R) cover(o, l, r, rhs);
        else {
            pushdown(o, l, r);
            int mid = (l + r) >> 1;
            update(o << 1, l, mid, L, R, rhs);
            update(o << 1 | 1, mid + 1, r, L, R, rhs);
            pushup(o);
        }
    }
    long long query(int o, int l, int r, int L, int R) {
        if (L > r || R < l) return 0;
        if (L <= l && r <= R) return sum[o];
        pushdown(o, l, r);
        int mid = (l + r) >> 1;
        return query(o << 1, l, mid, L, R) + query(o << 1 | 1, mid + 1, r, L, R);
    }
    void build(int o, int l, int r) {
        if (l > r) return;
        else if (l == r) sum[o] = a[l], add[o] = 0;
        else {
            int mid = (l + r) >> 1;
            build(o << 1, l, mid);
            build(o << 1 | 1, mid + 1, r);
            pushup(o);
        }  
    }
};

猜你喜欢

转载自blog.csdn.net/diogenes_/article/details/80199532
今日推荐