CodeForces 1303G - Sum of Prefix Sums(点分治+李超树)

题意

给一颗树,每个点有点权,求树上一条路径 u 1 , u 2 , . . . u k u_1,u_2,...u_k u1,u2,...uk获得点权数组 a u 1 , a u 2 , . . . a u k a_{u_1},a_{u_2},...a_{u_k} au1,au2,...auk。要使得它的前缀和的前缀和最大,求这个最大值。

解题思路

树上路径的问题会想到用点分治去解决,以点分治的想法,如何把当前一条路径和现存的路径信息合并。
考虑两个数组 a 1 , a 2 . . . a n 和 b 1 , b 2 , . . . b m 。 {a_1,a_2...a_n}和{b_1,b_2,...b_m}。 a1,a2...anb1,b2,...bm设元素和为 s u m sum sum,前缀和的和为 x s u m xsum xsum。那么连接成 a 1 , a 2 . . . a n , b 1 , b 2 , . . . b m {a_1,a_2...a_n,b_1,b_2,...b_m} a1,a2...an,b1,b2,...bm之后总的 x s u m = x s u m a + m ∗ s u m a + x s u m b xsum=xsum_a+m*sum_a+xsum_b xsum=xsuma+msuma+xsumb
这是一个关于 m m m的一次函数。我们在李超树上维护 x s u m a ( 常 数 ) xsum_a(常数) xsuma() s u m a ( 斜 率 ) sum_a(斜率) suma(),然后查询的时候使用 s u m b 和 m sum_b和m sumbm去李超树上查询 m m m点的最大值。
然后查询的时候,顺着查一遍,再反着查一遍,就可以得到所有以某个子树中信息的为查询,其他子树的信息作为线性函数存放在李超树中对应 x s u m xsum xsum
注意当整个树的大小为2的时候,还需要特别查一下其他所有信息存在李超树里,然后根节点作为查询信息的值。

#include<vector>
#include<iostream>
#define ll long long
#define P pair<int, int>
#define Pll pair<ll, ll>
#define pb push_back
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const int maxn = 2e5 + 50;
struct LiChao_segment_tree{
    
    
    inline ll f(ll k, int x, ll b){
    
    
        return k*x+b;
    }
    ll k[maxn<<2], b[maxn<<2]; int lz[maxn<<2];
    int N;
    void init(int n){
    
    
        N = n; lz[1] = 1; k[1] = b[1] = 0;
    }
    inline void down(int rt){
    
    
        if(!lz[rt]) return;
        lz[rt<<1] = lz[rt<<1|1] = 1;
        k[rt<<1] = k[rt<<1|1] = b[rt<<1] = b[rt<<1|1] = 0;
        lz[rt] = 0;//AHHHHHHH!!
        return;
    }
    void update(int rt, int l, int r, ll ck, ll cb){
    
    
        if(f(ck, mid, cb) > f(k[rt], mid, b[rt])) swap(k[rt], ck), swap(b[rt], cb);
        if(l == r) return;
        down(rt);
        if(ck > k[rt]) update(rson,ck, cb);//要传的斜率较大,去右边
        else update(lson, ck, cb);//否则去左边
    }
    ll qry(int rt, int l, int r, int x){
    
    
        ll ans = f(k[rt], x, b[rt]);
        if(l == r) return ans;
        down(rt);
        if(x <= mid) ans = max(ans, qry(lson, x));
        else ans = max(ans, qry(rson, x));
        return ans;
    }
}T;
vector<int> g[maxn];
int  sz[maxn], ms[maxn], vis[maxn], totsz;
ll a[maxn];
int n, rt;
ll ans = 0;
void get_rt(int u, int fa){
    
    
    sz[u] = 1; ms[u] = 0;
    for(int v: g[u]){
    
    
        if(v == fa || vis[v]) continue;
        get_rt(v, u);
        sz[u] += sz[v];
        ms[u] = max(ms[u], sz[v]);
    }
    ms[u] = max(ms[u], totsz - sz[u]);
    if(ms[rt] > ms[u]) rt = u;
}
void init()
{
    
    
    scanf("%d", &n);
    for(int i = 1; i < n; ++i){
    
    
        int u ,v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for(int i = 1; i <= n; ++i) scanf("%I64d", &a[i]);
    totsz = n;
}
void qry(int u, int fa, ll sum, ll xsum, int d){
    
    
    int f = 0; d++; sum += a[u]; xsum += sum;
    for(int v: g[u]){
    
    
        if(vis[v] || v == fa) continue;
        f = 1;
        qry(v, u, sum, xsum, d);
    }
    if(f) return;
    ans = max(ans, xsum + T.qry(1, 1, T.N, d));
}
void update(int u, int fa, ll sum, ll xsum, int d){
    
    
    d++; sum += a[u]; xsum += d*a[u];
    int f = 0;
    for(int v: g[u]){
    
    
        if(vis[v] || v == fa) continue;
        f = 1;
        update(v, u, sum, xsum, d);
    }
    if(f) return;
    T.update(1, 1, T.N, sum, xsum);
}
void cal(int u, int cursize){
    
    
    T.init(cursize);
    for(int i = 0; i < g[u].size(); ++i){
    
    
        int v = g[u][i];
        if(vis[v]) continue;
        qry(v, u, a[u], a[u], 1);
        update(v, u, 0, 0, 0);
    }
    ans = max(ans, a[u] + T.qry(1, 1, T.N, 1));
    T.init(cursize);
    for(int i = g[u].size()-1; i >= 0; --i){
    
    
        int v = g[u][i];
        if(vis[v]) continue;
        qry(v, u, a[u], a[u], 1);
        update(v, u, 0, 0, 0);
    }
}
void gao(int u, int cursize)
{
    
    
    vis[u] = 1;
    cal(u, cursize);
    for(int v: g[u]){
    
    
        if(vis[v]) continue;
        totsz = sz[v] < sz[u] ? sz[v] : cursize - sz[u];
        rt = 0;
        get_rt(v, 0);
        gao(rt, totsz);
    }
}
void sol(){
    
    

    rt = 0; ms[rt] = n+1;
    get_rt(1, 0);
    gao(rt, n);
    cout<<ans<<endl;
}
int main()
{
    
    
    //freopen("testdata.in", "r", stdin);
	init();sol();
}
/*
4
4 2
3 2
4 1
1 3 3 7
*/

猜你喜欢

转载自blog.csdn.net/qq_43202683/article/details/104354123