【ACWing】955. 维护数列

题目地址:

https://www.acwing.com/problem/content/957/

请写一个程序,要求维护一个数列,支持以下 6 6 6种操作:(请注意,格式栏 中的下划线_表示实际输入文件中的空格)
在这里插入图片描述

输入格式:
1 1 1行包含两个数 N N N M M M N N N表示初始时数列中数的个数, M M M表示要进行的操作数目。
2 2 2行包含 N N N个数字,描述初始时的数列。
以下 M M M行,每行一条命令,格式参见问题描述中的表格。

输出格式:
对于输入数据中的GET-SUM和MAX-SUM操作,向输出文件依次打印结果,每个答案(数字)占一行。

数据范围:
你可以认为在任何时刻,数列中至少有 1 1 1个数。输入数据一定是正确的,即指定位置的数在数列中一定存在。
50 % 50\% 50%的数据中,任何时刻数列中最多含有 30000 30000 30000个数; 100 % 100\% 100%的数据中,任何时刻数列中最多含有 500000 500000 500000个数。
100 % 100\% 100%的数据中,任何时刻数列中任何一个数字均在 [ − 1000 , 1000 ] [−1000,1000] [1000,1000]内。
100 % 100\% 100%的数据中, M ≤ 20000 M≤20000 M20000,插入的数字总数不超过 4000000 4000000 4000000个,输入文件大小不超过 20 20 20MBytes。

法1:Splay。关于维护序列的问题,最常用的数据结构之一就是Splay树(另一个非常常用的数据结构是FHQ Treap)。这道题涉及到插入一个区间,删除一个区间,修改一个区间,翻转一个区间以及求最大子区间和这些操作,Splay都可以做,插入和删除这两个操作Splay可以直接做(由于要找到插入、删除的位置,每个节点需要存其为根的子树的节点总个数,并且为了可以实现在头部、尾部进行插入和删除,我们需要开两个哨兵节点,即第 0 0 0个数和第 N + 1 N+1 N+1个数,这两个节点不参与计算),修改和翻转需要开两个懒标记,求最大子区间和可以在树节点里增加当前子树所对应的序列的最大前缀和 l s ls ls、最大后缀和 r s rs rs、最大子区间和 m s ms ms以及区间和 s u m sum sum这几个变量,然后利用分治的思想实现。具体实现方式如下:
1、在第 p p p个数字之后,插入 t t t个数。先找到第 p + 1 p+1 p+1的位置(注意要考虑哨兵节点),然后再找到第 p + 2 p+2 p+2的位置,将第 p + 1 p+1 p+1的数的节点Splay到树根,将第 p + 2 p+2 p+2的数的节点Splay到树根右边,接下来只需要在它左边插入区间即可。注意在建树的时候,我们需要建出尽量平衡的树,可以采用递归的方式,取序列中点为树根,然后递归建立左右子树。
2、从第 p p p个数字开始,删除 t t t个数。先找到第 p p p的位置,然后找到第 p + t + 1 p+t+1 p+t+1的位置,将前者Splay到树根,将后者Splay到树根右边,这样其左子树就是要删的区间,直接删去即可。注意这里虽然题目中说插入的总数字个数不超过 4 e 6 4e6 4e6,但是如果频繁的插入删除的话,会新开很多节点,所以我们要考虑将删掉的节点回收利用。这里可以开一个栈,先将所有可以用的节点push进去,用一个就pop出一个用之;如果某个节点不用了,则回收进栈。也就是说在删除子树的时候,需要DFS一遍将每个节点回收到栈里。
3、从第 p p p个数字开始,将 t t t个数修改为 c c c。和上面类似,Splay两次找到要改的区间。直接将其打上懒标记,同时更新其根的信息(注意,这里我们规定树节点的信息是懒标记生效后的值,因为我们要保证后面pushup的时候父亲用的是儿子节点的正确的信息来更新的。所以这里更新根的信息的话,可以先更新 s u m sum sum。假设区间长度是 l l l,这里区间长度其实就是子树节点个数,而 m s , l s ms,ls ms,ls r s rs rs要取决于 c c c的正负。如果 c > 0 c>0 c>0,那显然最大前缀、后缀和子区间和都是 c l cl cl;否则的话,左子树的最大前后缀和都是 0 0 0,右子树一样,最大子区间和就是当前节点自己)。
4、从第 p p p个数字开始,翻转 t t t个数。和上面类似,Splay两次找到要改的区间。直接将其打上懒标记,同时更新根的信息,这里需要将根的 l s ls ls r s rs rs调换,并且将左右子树调换(可以这么理解,在pushdown之前,懒标记的唯一作用就是标记一下左右儿子需要翻转,但当前还未翻转。显然pushdown之前要先对换左右儿子)。
5、从第 p p p个数字开始,求 t t t个数的和。和上面类似Splay两次找到要查询的区间,直接取出和即可。
6、直接查询树根的 m s ms ms

接下来考虑pushup和pushdown操作。
pushup比较简单,其节点个数等于左右子树节点个数之和加 1 1 1 s u m sum sum等于左右子树 s u m sum sum加自己的值, l s ls ls等于左子树的 l s ls ls与左边的 s u m sum sum加右子树 l s ls ls两者更大者, r s rs rs类似, m s ms ms等于左子树 m s ms ms、右子树 m s ms ms与左子树 r s rs rs加当前节点加右子树 l s ls ls三者的最大者。
pushdown略微麻烦些,如果当前节点有MAKE-SAME的懒标记,那么可以直接忽略其REVERSE懒标记,然后更新左右子树信息,更新方式参考上面的操作 3 3 3;如果没有MAKE-SAME的懒标记但有REVERSE懒标记,则直接更新左右子树信息,更新方式参考上面的操作 4 4 4

此外,由于我们不希望哨兵以及null节点影响答案的正确性,我们将它们的 m s ms ms都取为 − ∞ -\infty 。由于查询第 k k k个数的时候,每次向下走一步之前都要pushdown一下,而且每个操作都有”查询第 k k k个数“这个操作(除了操作 6 6 6),所以当找到第 k k k个数的时候,其与其所有祖宗的信息都已经正确了,并且懒标记都已经被清掉,所以将其Splay到树根不会影响答案的正确性。

代码如下:

#include <iostream>
#include <cstring>
using namespace std;

const int N = 4e6 + 10, INF = 1e9;
int n, m;
struct Node {
    
    
    int s[2], p, v, sz;
    int sum, ls, rs, ms;
    bool rev, same;

    void init(int _v, int _p) {
    
    
        s[0] = s[1] = 0, v = _v, p = _p;
        rev = same = 0;
        sz = 1, sum = ms = v;
        ls = rs = max(0, v);
    }
} tr[N];
// nodes是回收节点的栈,tt是栈顶
int root, nodes[N], tt;
int w[N];

void pushup(int x) {
    
    
    auto &u = tr[x], &l = tr[u.s[0]], &r = tr[u.s[1]];
    u.sz = l.sz + r.sz + 1;
    u.sum = l.sum + r.sum + u.v;
    u.ls = max(l.ls, l.sum + u.v + r.ls);
    u.rs = max(r.rs, r.sum + u.v + l.rs);
    u.ms = max(max(l.ms, r.ms), l.rs + u.v + r.ls);
}

void pushdown(int x) {
    
    
    auto &u = tr[x], &l = tr[u.s[0]], &r = tr[u.s[1]];
    if (u.same) {
    
    
        u.same = u.rev = 0;
        if (u.s[0]) l.same = 1, l.v = u.v, l.sum = l.v * l.sz;
        if (u.s[1]) r.same = 1, r.v = u.v, r.sum = r.v * r.sz;
        if (u.v > 0) {
    
    
            if (u.s[0]) l.ms = l.ls = l.rs = l.sum;
            if (u.s[1]) r.ms = r.ls = r.rs = r.sum;
        } else {
    
    
            if (u.s[0]) l.ms = l.v, l.ls = l.rs = 0;
            if (u.s[1]) r.ms = r.v, r.ls = r.rs = 0;
        }
    } else if (u.rev) {
    
    
        u.rev = 0, l.rev ^= 1, r.rev ^= 1;
        swap(l.ls, l.rs), swap(r.ls, r.rs);
        swap(l.s[0], l.s[1]), swap(r.s[0], r.s[1]);
    }
}

void rotate(int x) {
    
    
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;
    tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
    tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
    tr[x].s[k ^ 1] = y, tr[y].p = x;
    pushup(y), pushup(x);
}

void splay(int x, int k) {
    
    
    while (tr[x].p != k) {
    
    
        int y = tr[x].p, z = tr[y].p;
        if (z != k) 
            if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
            else rotate(y);
        rotate(x); 
    }

    if (!k) root = x;
}

int get_k(int k) {
    
    
    int u = root;
    while (u) {
    
    
        pushdown(u);
        if (k <= tr[tr[u].s[0]].sz) u = tr[u].s[0];
        else if (k > tr[tr[u].s[0]].sz + 1)
            k -= tr[tr[u].s[0]].sz + 1, u = tr[u].s[1];
        else return u;
    }

    return -1;
}

int build(int l, int r, int p) {
    
    
    int mid = l + r >> 1;
    int u = nodes[tt--];
    tr[u].init(w[mid], p);
    if (l < mid) tr[u].s[0] = build(l, mid - 1, u);
    if (r > mid) tr[u].s[1] = build(mid + 1, r, u);
    pushup(u);
    return u;
}

void dfs(int u) {
    
    
    if (!u) return;
    dfs(tr[u].s[0]);
    dfs(tr[u].s[1]);
    // 将当前节点放入栈,以供下次重复使用
    nodes[++tt] = u;
}

int main() {
    
    
    for (int i = 1; i < N; i++) nodes[++tt] = i;
    scanf("%d%d", &n, &m);
    tr[0].ms = w[0] = w[n + 1] = -INF;
    for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
    root = build(0, n + 1, 0);

    char op[20];
    while (m--) {
    
    
        scanf("%s", op);
        if (!strcmp(op, "INSERT")) {
    
    
            int posi, tot;
            scanf("%d%d", &posi, &tot);
            for (int i = 1; i <= tot; i++) scanf("%d", &w[i]);
            int l = get_k(posi + 1), r = get_k(posi + 2);
            splay(l, 0), splay(r, l);
            int u = build(1, tot, r);
            tr[r].s[0] = u;
            pushup(r), pushup(l);
        } else if (!strcmp(op, "DELETE")) {
    
    
            int posi, tot;
            scanf("%d%d", &posi, &tot);
            int l = get_k(posi), r = get_k(posi + tot + 1);
            splay(l, 0), splay(r, l);
            dfs(tr[r].s[0]);
            tr[r].s[0] = 0;
            pushup(r), pushup(l);
        } else if (!strcmp(op, "MAKE-SAME")) {
    
    
            int posi, tot, c;
            scanf("%d%d%d", &posi, &tot, &c);
            int l = get_k(posi), r = get_k(posi + tot + 1);
            splay(l, 0), splay(r, l);
            auto &son = tr[tr[r].s[0]];
            son.same = 1, son.v = c, son.sum = c * son.sz;
            if (c > 0) son.ms = son.ls = son.rs = son.sum;
            else son.ms = c, son.ls = son.rs = 0;
            pushup(r), pushup(l);
        } else if (!strcmp(op, "REVERSE")) {
    
    
            int posi, tot;
            scanf("%d%d", &posi, &tot);
            int l = get_k(posi), r = get_k(posi + tot + 1);
            splay(l, 0), splay(r, l);
            auto &son = tr[tr[r].s[0]];
            son.rev ^= 1;
            swap(son.s[0], son.s[1]);
            swap(son.ls, son.rs);
            pushup(r), pushup(l);
        } else if (!strcmp(op, "GET-SUM")) {
    
    
            int posi, tot;
            scanf("%d%d", &posi, &tot);
            int l = get_k(posi), r = get_k(posi + tot + 1);
            splay(l, 0), splay(r, l);
            auto &son = tr[tr[r].s[0]];
            printf("%d\n", son.sum);
        } else printf("%d\n", tr[root].ms);
    }

    return 0;
}

操作 1 , 2 1,2 1,2时间复杂度 O ( log ⁡ n + t ) O(\log n+t) O(logn+t),操作 3 , 4 , 5 3,4,5 3,4,5时间 O ( log ⁡ n ) O(\log n) O(logn),操作 6 6 6时间 O ( 1 ) O(1) O(1),空间 O ( n ) O(n) O(n)

法2:FHQ Treap。思路和上面完全类似,分裂操作可以分裂出要修改的区间,接下来的修改方式与Splay一样。与Splay略微不一样的地方在于,FHQ Treap是每次分裂的时候都pushdown一下。代码如下:

#include <iostream>
#include <cstring>
using namespace std;

const int N = 4e6 + 10, INF = 1e9;
int n, m;
struct Node {
    
    
    int l, r, v, sz, rnd;
    int sum, ls, rs, ms;
    bool rev, same;

    void init(int _v) {
    
    
        l = r = 0;
        v = _v;
        rev = same = 0;
        sz = 1, sum = ms = v;
        ls = rs = max(0, v);
        rnd = rand();
    }
} tr[N];
int root, nodes[N], tt;
int x, y, z;
int w[N];

void pushup(int x) {
    
    
    auto &u = tr[x], &l = tr[tr[x].l], &r = tr[tr[x].r];
    u.sz = l.sz + r.sz + 1;
    u.sum = l.sum + r.sum + u.v;
    u.ls = max(l.ls, l.sum + u.v + r.ls);
    u.rs = max(r.rs, r.sum + u.v + l.rs);
    u.ms = max(max(l.ms, r.ms), l.rs + u.v + r.ls);
}

void pushdown(int x) {
    
    
    auto &u = tr[x], &l = tr[tr[x].l], &r = tr[tr[x].r];
    if (u.same) {
    
    
        u.same = u.rev = 0;
        if (u.l) l.same = 1, l.v = u.v, l.sum = l.v * l.sz;
        if (u.r) r.same = 1, r.v = u.v, r.sum = r.v * r.sz;
        if (u.v > 0) {
    
    
            if (u.l) l.ms = l.ls = l.rs = l.sum;
            if (u.r) r.ms = r.ls = r.rs = r.sum;
        } else {
    
    
            if (u.l) l.ms = l.v, l.ls = l.rs = 0;
            if (u.r) r.ms = r.v, r.ls = r.rs = 0;
        }
    } else if (u.rev) {
    
    
        u.rev = 0, l.rev ^= 1, r.rev ^= 1;
        swap(l.ls, l.rs), swap(r.ls, r.rs);
        swap(l.l, l.r), swap(r.l, r.r);
    }
}

int build(int l, int r) {
    
    
    int mid = l + r >> 1;
    int u = nodes[tt--];
    tr[u].init(w[mid]);
    if (l < mid) tr[u].l = build(l, mid - 1);
    if (r > mid) tr[u].r = build(mid + 1, r);
    pushup(u);
    return u;
}

void split(int u, int sz, int &x, int &y) {
    
    
    if (!u) x = y = 0;
    else {
    
    
        pushdown(u);
        if (tr[tr[u].l].sz < sz) {
    
    
            x = u;
            split(tr[u].r, sz - tr[tr[u].l].sz - 1, tr[u].r, y);
        } else {
    
    
            y = u;
            split(tr[u].l, sz, x, tr[u].l);
        }
        pushup(u);
    }
}

int merge(int x, int y) {
    
    
    if (!x || !y) return x | y;
    if (tr[x].rnd > tr[y].rnd) {
    
    
        pushdown(x);
        tr[x].r = merge(tr[x].r, y);
        pushup(x);
        return x;
    } else {
    
    
        pushdown(y);
        tr[y].l = merge(x, tr[y].l);
        pushup(y);
        return y;
    }
}

void dfs(int u) {
    
    
    if (!u) return;
    dfs(tr[u].l);
    dfs(tr[u].r);
    nodes[++tt] = u;
}

int main() {
    
    
    for (int i = 1; i < N; i++) nodes[++tt] = i;
    tr[0].ms = -INF;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
    root = build(1, n);

    char op[20];
    while (m--) {
    
    
        scanf("%s", op);
        if (!strcmp(op, "INSERT")) {
    
    
            int posi, tot;
            scanf("%d%d", &posi, &tot);
            for (int i = 1; i <= tot; i++) scanf("%d", &w[i]);
            // 分裂开,建树,再合并回来
            split(root, posi, x, y);
            int u = build(1, tot);
            root = merge(merge(x, u), y);
        } else if (!strcmp(op, "DELETE")) {
    
    
            int posi, tot;
            scanf("%d%d", &posi, &tot);
            // 分裂开,直接删掉中间的区间,再合并回来
            split(root, posi - 1, x, y);
            split(y, tot, y, z);
            dfs(y);
            root = merge(x, z);
        } else if (!strcmp(op, "MAKE-SAME")) {
    
    
            int posi, tot, c;
            scanf("%d%d%d", &posi, &tot, &c);
            split(root, posi - 1, x, y);
            split(y, tot, y, z);
            auto &u = tr[y];
            u.same = 1, u.v = c, u.sum = c * u.sz;
            if (c > 0) u.ms = u.ls = u.rs = u.sum;
            else u.ms = c, u.ls = u.rs = 0;
            root = merge(merge(x, y), z);
        } else if (!strcmp(op, "REVERSE")) {
    
    
            int posi, tot;
            scanf("%d%d", &posi, &tot);
            split(root, posi - 1, x, y);
            split(y, tot, y, z);
            auto &u = tr[y];
            u.rev ^= 1;
            swap(u.l, u.r);
            swap(u.ls, u.rs);
            root = merge(merge(x, y), z);
        } else if (!strcmp(op, "GET-SUM")) {
    
    
            int posi, tot;
            scanf("%d%d", &posi, &tot);
            split(root, posi - 1, x, y);
            split(y, tot, y, z);
            printf("%d\n", tr[y].sum);
            root = merge(merge(x, y), z);
        } else printf("%d\n", tr[root].ms);
    }

    return 0;
}

所有操作时间复杂度与上面相同。

猜你喜欢

转载自blog.csdn.net/qq_46105170/article/details/121540816
今日推荐