题解 P3372 【【模板】线段树 1】

线段树模板题

所以,我偏不用线段树

奇了怪了

主要思路:平衡树——Splay

Splay是可以很好的维护区间的。

我这里主要讲如何用Splay维护区间。

我们知道Splay是严格按照中序遍历的顺序的,用rotate操作并不会改变这种性质,所以我们我们可以考虑一下一棵二叉树的中序遍历的特点。

如果我们把左端点splay到树根,把右端点splay到树根的右儿子位置,我们再做下中序遍历,,,(可以自行脑补)

是不是根的右儿子的左子树的信息就是这段区间的信息?

所以我们用Splay维护区间时我们是提取区间。

完整代码:

(福利:其中还有插入节点,删除节点等操作哦QwQ,附带节点垃圾桶,可回收旧结点)

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <vector>
using namespace std;
#define go(i, j, n, k) for (int i = j; i <= n; i += k)
#define fo(i, j, n, k) for (int i = j; i >= n; i -= k)
#define rep(i, x) for (int i = h[x]; i; i = e[i].nxt)
#define mn 100010
#define ld long double
#define fi first
#define se second
#define inf 1 << 30
#define ll long long
#define root 1, n, 1
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
inline ll read()
{
    ll x = 0, f = 1;
    char ch = getchar();
    while (ch > '9' || ch < '0')
    {
        if (ch == '-')
            f = -f;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}
vector<int> ljt;
struct tree
{
    int ch[2], sze, fa;
    ll w, sum, col;
    tree(int _sze = 0, int _fa = 0, int _w = 0, int _sum = 0, int _col = 0)
        : sze(_sze), fa(_fa), w(_w), sum(_sum), col(_col) { ch[1] = ch[0] = 0; }
} z[mn];
int n, m, tot, a[mn];
inline int newnode(int v = 0)
{
    int rt;
    if (ljt.empty())
        rt = ++tot;
    else
        rt = ljt.back(), ljt.pop_back();
    z[rt].fa = z[rt].ch[0] = z[rt].ch[1] = 0;
    z[rt].sze = 1;
    z[rt].w = z[rt].sum = v;
    return rt;
}
inline void deletenode(int rt)
{
    z[rt].fa = z[rt].ch[0] = z[rt].ch[1] = z[rt].w = z[rt].sum = 0;
    z[rt].sze = 1;
    ljt.push_back(rt);
}
inline void update(int rt)
{
    z[rt].sum = z[rt].w, z[rt].sze = 1;
    if(z[rt].ch[0])
        z[rt].sum += z[z[rt].ch[0]].sum, z[rt].sze += z[z[rt].ch[0]].sze;
    if(z[rt].ch[1])
        z[rt].sum += z[z[rt].ch[1]].sum, z[rt].sze += z[z[rt].ch[1]].sze;
}
inline void push_col(int rt)
{
    if(z[rt].col)
    {
        z[z[rt].ch[0]].col += z[rt].col;
        z[z[rt].ch[1]].col += z[rt].col;
        z[z[rt].ch[0]].sum += z[z[rt].ch[0]].sze * z[rt].col;
        z[z[rt].ch[1]].sum += z[z[rt].ch[1]].sze * z[rt].col;
        z[z[rt].ch[0]].w += z[rt].col;
        z[z[rt].ch[1]].w += z[rt].col;
        z[rt].col = 0;
    }
}
inline int iden(int rt)
{
    return z[z[rt].fa].ch[0] == rt ? 0 : 1;
}
inline void conn(int x, int y, int son)
{
    z[x].fa = y;
    z[y].ch[son] = x;
}
inline void rotate(int x)//敲好记的rotate函数!
{
    int y = z[x].fa;
    int moot = z[y].fa;
    int yson = iden(x);
    int mootson = iden(y);
    int B = z[x].ch[yson ^ 1];
    conn(B, y, yson), conn(y, x, yson ^ 1), conn(x, moot, mootson);
    update(y), update(x);
}
inline void splay(int x, int &k)//传址要注意
{
    if (x == k)
        return;
    int p = z[k].fa;
    while (z[x].fa != p)
    {
        push_col(x);
        int y = z[x].fa;
        if (z[y].fa != p)
            rotate(iden(y) ^ iden(x) ? x : y);
        rotate(x);
    }
    k = x;
}
inline int findkth(int rt, int k)
{
    while (233)
    {
        push_col(rt);
        if (z[rt].ch[0] && k <= z[z[rt].ch[0]].sze)
            rt = z[rt].ch[0];
        else
        {
            if (z[rt].ch[0])
                k -= z[z[rt].ch[0]].sze;
            if (!--k)
                return rt;
            rt = z[rt].ch[1];
        }
    }
}
inline void insert(int &rt, int p, int v)//传址要注意
{
    int x = findkth(rt, p);
    splay(x, rt);
    int y = findkth(rt, p + 1);
    int ooo = z[rt].ch[1];
    splay(y, ooo);
    z[y].ch[0] = newnode(v);
    z[z[y].ch[0]].fa = y;
    update(z[y].ch[0]), update(y), update(x);
}
inline void erase(int &rt, int p)//传址要注意
{
    int y = findkth(rt, p);
    splay(y, rt);
    int x = findkth(rt, p + 1);
    int ooo = z[rt].ch[1];
    splay(x, ooo);
    int oo = z[x].ch[1];
    z[oo].fa = y;
    z[y].ch[1] = oo;
    deletenode(x);
    update(y);
}
inline int getRange(int &rt, int l, int r)//传址要注意
{
    int x = findkth(rt, l);
    splay(x, rt);
    int y = findkth(rt, r + 2);
    int ooo = z[rt].ch[1];
    splay(y, ooo);
    return z[y].ch[0];
}
inline void modify(int &rt, int l, int r, ll v)//传址要注意
{
    int x = getRange(rt, l, r);
    z[x].col += v;
    z[x].w += v;
    z[x].sum += z[x].sze * v;
    update(z[rt].ch[1]);
    update(rt);
}
inline ll query(int &rt, int l, int r)//传址要注意
{
    int x = getRange(rt, l, r);
    return z[x].sum;
}
inline void build(int rt, int l, int r)
{
    int m = (l + r) >> 1;
    z[rt].w = a[m];
    if (l <= m - 1)
    {
        z[rt].ch[0] = newnode();
        z[z[rt].ch[0]].fa = rt;
        build(z[rt].ch[0], l, m - 1);
    }
    if (m + 1 <= r)
    {
        z[rt].ch[1] = newnode();
        z[z[rt].ch[1]].fa = rt;
        build(z[rt].ch[1], m + 1, r);
    }
    update(rt);
}
inline void debug(int rt)//debug专用,利用中序遍历
{
    //if(!z[rt].ch[0] && !z[rt].ch[1])
    //  return;
    if (rt == 0)
        return;
    debug(z[rt].ch[0]);
    printf("%d %d %d\n", z[rt].w, z[rt].sum, z[rt].sze);
    debug(z[rt].ch[1]);
}
int main()
{
    n = read();
    m = read();
    go(i, 1, n, 1) a[i] = read();
    int rot = ++tot;
    build(rot, 0, n + 1);
    //debug(rot);
    //cout << query(rot, 1, n) << "\n";
    go(i, 1, m, 1)
    {
        int s = read(), x = read(), y = read();
        if (s == 1)
        {
            int v = read();
            modify(rot, x, y, v);
        }
        else
        {
            cout << query(rot, x, y) << "\n";
        }
    }
    return 0;
}

本蒟蒻的第一篇Splay的题解,若有不恰当的地方请大家指教

猜你喜欢

转载自www.cnblogs.com/yizimi/p/10056279.html