【挖坑】【GSS】GSS7:树链剖分中的最大子段和

版权声明:Powered By Fighter https://blog.csdn.net/qq_30115697/article/details/88827176

Can you answer these queries?

GSS系列是spoj出品的一套数据结构好毒瘤题,主要以线段树、平衡树和树链剖分为背景,进行了一些操作的魔改,使得难度远超模板题,但对于思维有极大的提升。

所以我会选择一些在我能力范围内的题挖坑选讲,构成一个GSS系列。至于剩下那些,等我成为巨佬弄懂了再说吧。

GSS7:树链剖分中的最大子段和

原题传送门(洛谷)

本题前置芝士:

  1. GSS1:区间最大子段和。(题解:CSDN个人博客
  2. 树链剖分

题意

给定一棵树,动态修改两点路径上的点权,查询两点路径上的最大子段和。

口胡

最大子段和又双叒叕升级啦!!

这次变成了树上修改和查询。

其实树链剖分的精髓部分就在于线段树以及最后跳链的部分,那么我们就从这两个地方入手。线段树上的最大子段和并没有什么特殊的地方,直接用GSS1的做法即可。主要难点在于跳链查询。

我们会发现,跳链的时候区间不一定是连续的,那么最后对答案的合并就是一个较大的问题。再进一步思考,会发现由于查询的是最大子段和,所以合并的时候要是把左右区间搞反,就会直接挂掉。所以我们考虑在跳链的时候分类讨论。

我们设两个变量分别记录上一次跳 x x 和上一次跳 y y 的结果,都作为当前跳链的右区间,然后查询相应的重链上的结果,作为左区间,最后合并到 x x y y 在同一条重链上为止,最后分类讨论跳 x x 还是跳 y y ,在合并的时候为了合并的正确性,我们把所有跳 x x 的结果合并后的总区间翻转(实际上只交换了 l s ls r s rs ),再与 y y 的答案合并。

由于本题合并操作较多,建议写一个merge函数。

还有,应该不需要图解了吧。(作者画图画到自闭)

代码

#include <bits/stdc++.h>
#define MAX 100005
#define INF (ll)1e16
#define ll long long
#define int ll
#define lc(x) (x<<1)
#define rc(x) (x<<1|1)
#define mid ((l+r)>>1)
using namespace std;

int n, q, cnt, tot;
int head[MAX], Next[MAX*2], vet[MAX*2];
int sz[MAX], top[MAX], f[MAX], son[MAX], d[MAX], id[MAX], rk[MAX];
ll val[MAX];

void add(int x, int y) {
    cnt++;
    Next[cnt] = head[x];
    head[x] = cnt;
    vet[cnt] = y;
}

void dfs1(int x, int fa) {
    d[x] = d[fa]+1, sz[x] = 1, f[x] = fa;
    for(int i = head[x]; i; i = Next[i]) {
        int v = vet[i];
        if(v == fa) continue;
        dfs1(v, x);
        sz[x] += sz[v];
        if(sz[v] > sz[son[x]])
            son[x] = v;
    }
}

void dfs2(int x, int t) {
    top[x] = t, id[x] = ++tot, rk[tot] = x;
    if(!son[x]) return;
    dfs2(son[x], t);
    for(int i = head[x]; i; i = Next[i]) {
        int v = vet[i];
        if(v == son[x] || v == f[x]) continue;
        dfs2(v, v);
    }
}

/****Segment Tree***/
struct node {
    ll sum, ls, rs, mx;
    bool vis;
    node() {
        sum = ls = rs = mx = vis = 0;
    }
} s[MAX*4];
ll tag[MAX*4];

inline node merge(node a, node b) {
    node res;
    res.sum = a.sum+b.sum;
    res.ls = max(a.ls, a.sum+b.ls);
    res.rs = max(b.rs, b.sum+a.rs);
    res.mx = max(max(a.mx, b.mx), a.rs+b.ls);
    return res;
}

inline void push_up(int x) {
    s[x] = merge(s[lc(x)], s[rc(x)]);
}

inline void mark(int p, int l, int r, ll k){
    s[p].sum = (r-l+1)*k;
    s[p].ls = s[p].rs = s[p].mx = max(s[p].sum, 0ll);
    tag[p] = k, s[p].vis = 1;
}

inline void push_down(int p, int l, int r) {
    if(!s[p].vis) return;
    mark(lc(p), l, mid, tag[p]);
    mark(rc(p), mid+1, r, tag[p]);
    tag[p] = s[p].vis = 0;
}

void build(int p, int l, int r) {
    s[p].vis = 0;
    if(l == r) {
        s[p].sum = val[rk[l]];
        s[p].ls = s[p].rs = s[p].mx = max(s[p].sum, 0ll);
        return;
    }
    build(lc(p), l, mid);
    build(rc(p), mid+1, r);
    push_up(p);
}

void update(int p, int l, int r, int ul, int ur, ll k) {
    if(l>=ul && r<=ur) {
        mark(p, l, r, k);
        return;
    }
    push_down(p, l, r);
    if(mid >= ul) update(lc(p), l, mid, ul, ur, k);
    if(mid < ur) update(rc(p), mid+1, r, ul, ur, k);
    push_up(p);
}

node query(int p, int l, int r, int ul, int ur) {
    if(l>=ul && r<=ur) {
        return s[p];
    }
    push_down(p, l, r);
    if(mid < ul) {
        return query(rc(p), mid+1, r, ul, ur);
    } else if(mid >= ur) {
        return query(lc(p), l, mid, ul, ur);
    } else {
        node t1 = query(lc(p), l, mid, ul, ur);
        node t2 = query(rc(p), mid+1, r, ul, ur);
        return merge(t1, t2);
    }
}
/******End******/

void modify(int x, int y, ll k) {
    while(top[x] != top[y]) {
        if(d[top[x]] < d[top[y]]) {
            swap(x, y);
        }
        update(1,1,n, id[top[x]], id[x], k);
        x = f[top[x]];
    }
    if(id[x] > id[y]) swap(x, y);
    update(1,1,n, id[x], id[y], k);
}

ll get_mx(int x, int y) {
    node qx, qy, res;
    while(top[x] != top[y]) {
        if(d[top[x]] > d[top[y]]) {
            res = query(1,1,n, id[top[x]], id[x]);
            qx = merge(res, qx);
            x = f[top[x]];
        } else {
            res = query(1,1,n, id[top[y]], id[y]);
            qy = merge(res, qy);
            y = f[top[y]];
        }
    }
    if(id[x] < id[y]) {
        res = query(1,1,n, id[x], id[y]);
        qy = merge(res, qy);
    } else {
        res = query(1,1,n, id[y], id[x]);
        qx = merge(res, qx);
    }
    swap(qx.ls, qx.rs);
    res = merge(qx, qy);
    return res.mx;
}

signed main() {
    cin >> n;
    for(int i = 1; i <= n; i++) {
        scanf("%lld", &val[i]);
    }
    int x, y;
    for(int i = 1; i < n; i++) {
        scanf("%lld%lld", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    cin >> q;
    int t;
    ll k;
    for(int i = 1; i <= q; i++){
        scanf("%lld%lld%lld", &t, &x, &y);
        if(t == 1){
            printf("%lld\n", get_mx(x,y));
        }
        else{
            scanf("%lld", &k);
            modify(x, y, k);
        }
    }

    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_30115697/article/details/88827176