题意
给一颗树,每个点有点权,求树上一条路径 u 1 , u 2 , . . . u k u_1,u_2,...u_k u1,u2,...uk获得点权数组 a u 1 , a u 2 , . . . a u k a_{u_1},a_{u_2},...a_{u_k} au1,au2,...auk。要使得它的前缀和的前缀和最大,求这个最大值。
解题思路
树上路径的问题会想到用点分治去解决,以点分治的想法,如何把当前一条路径和现存的路径信息合并。
考虑两个数组 a 1 , a 2 . . . a n 和 b 1 , b 2 , . . . b m 。 {a_1,a_2...a_n}和{b_1,b_2,...b_m}。 a1,a2...an和b1,b2,...bm。设元素和为 s u m sum sum,前缀和的和为 x s u m xsum xsum。那么连接成 a 1 , a 2 . . . a n , b 1 , b 2 , . . . b m {a_1,a_2...a_n,b_1,b_2,...b_m} a1,a2...an,b1,b2,...bm之后总的 x s u m = x s u m a + m ∗ s u m a + x s u m b xsum=xsum_a+m*sum_a+xsum_b xsum=xsuma+m∗suma+xsumb。
这是一个关于 m m m的一次函数。我们在李超树上维护 x s u m a ( 常 数 ) xsum_a(常数) xsuma(常数)和 s u m a ( 斜 率 ) sum_a(斜率) suma(斜率),然后查询的时候使用 s u m b 和 m sum_b和m sumb和m去李超树上查询 m m m点的最大值。
然后查询的时候,顺着查一遍,再反着查一遍,就可以得到所有以某个子树中信息的为查询,其他子树的信息作为线性函数存放在李超树中对应 x s u m xsum xsum。
注意当整个树的大小为2的时候,还需要特别查一下其他所有信息存在李超树里,然后根节点作为查询信息的值。
#include<vector>
#include<iostream>
#define ll long long
#define P pair<int, int>
#define Pll pair<ll, ll>
#define pb push_back
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const int maxn = 2e5 + 50;
struct LiChao_segment_tree{
inline ll f(ll k, int x, ll b){
return k*x+b;
}
ll k[maxn<<2], b[maxn<<2]; int lz[maxn<<2];
int N;
void init(int n){
N = n; lz[1] = 1; k[1] = b[1] = 0;
}
inline void down(int rt){
if(!lz[rt]) return;
lz[rt<<1] = lz[rt<<1|1] = 1;
k[rt<<1] = k[rt<<1|1] = b[rt<<1] = b[rt<<1|1] = 0;
lz[rt] = 0;//AHHHHHHH!!
return;
}
void update(int rt, int l, int r, ll ck, ll cb){
if(f(ck, mid, cb) > f(k[rt], mid, b[rt])) swap(k[rt], ck), swap(b[rt], cb);
if(l == r) return;
down(rt);
if(ck > k[rt]) update(rson,ck, cb);//要传的斜率较大,去右边
else update(lson, ck, cb);//否则去左边
}
ll qry(int rt, int l, int r, int x){
ll ans = f(k[rt], x, b[rt]);
if(l == r) return ans;
down(rt);
if(x <= mid) ans = max(ans, qry(lson, x));
else ans = max(ans, qry(rson, x));
return ans;
}
}T;
vector<int> g[maxn];
int sz[maxn], ms[maxn], vis[maxn], totsz;
ll a[maxn];
int n, rt;
ll ans = 0;
void get_rt(int u, int fa){
sz[u] = 1; ms[u] = 0;
for(int v: g[u]){
if(v == fa || vis[v]) continue;
get_rt(v, u);
sz[u] += sz[v];
ms[u] = max(ms[u], sz[v]);
}
ms[u] = max(ms[u], totsz - sz[u]);
if(ms[rt] > ms[u]) rt = u;
}
void init()
{
scanf("%d", &n);
for(int i = 1; i < n; ++i){
int u ,v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
for(int i = 1; i <= n; ++i) scanf("%I64d", &a[i]);
totsz = n;
}
void qry(int u, int fa, ll sum, ll xsum, int d){
int f = 0; d++; sum += a[u]; xsum += sum;
for(int v: g[u]){
if(vis[v] || v == fa) continue;
f = 1;
qry(v, u, sum, xsum, d);
}
if(f) return;
ans = max(ans, xsum + T.qry(1, 1, T.N, d));
}
void update(int u, int fa, ll sum, ll xsum, int d){
d++; sum += a[u]; xsum += d*a[u];
int f = 0;
for(int v: g[u]){
if(vis[v] || v == fa) continue;
f = 1;
update(v, u, sum, xsum, d);
}
if(f) return;
T.update(1, 1, T.N, sum, xsum);
}
void cal(int u, int cursize){
T.init(cursize);
for(int i = 0; i < g[u].size(); ++i){
int v = g[u][i];
if(vis[v]) continue;
qry(v, u, a[u], a[u], 1);
update(v, u, 0, 0, 0);
}
ans = max(ans, a[u] + T.qry(1, 1, T.N, 1));
T.init(cursize);
for(int i = g[u].size()-1; i >= 0; --i){
int v = g[u][i];
if(vis[v]) continue;
qry(v, u, a[u], a[u], 1);
update(v, u, 0, 0, 0);
}
}
void gao(int u, int cursize)
{
vis[u] = 1;
cal(u, cursize);
for(int v: g[u]){
if(vis[v]) continue;
totsz = sz[v] < sz[u] ? sz[v] : cursize - sz[u];
rt = 0;
get_rt(v, 0);
gao(rt, totsz);
}
}
void sol(){
rt = 0; ms[rt] = n+1;
get_rt(1, 0);
gao(rt, n);
cout<<ans<<endl;
}
int main()
{
//freopen("testdata.in", "r", stdin);
init();sol();
}
/*
4
4 2
3 2
4 1
1 3 3 7
*/