CF1303G Sum of Prefix Sums

点分治+李超树

因为题目要求的是树上所有路径,所以用点分治维护

因为在点分治的过程中相当于将树上经过当前$root$的一条路径分成了两段

那么先考虑如何计算两个数组合并后的答案

记数组$a$,$b$,求得是将$b$数组接到$a$数组的答案

其$a$,$b$的sum of prefix sums分别为$sa$,$sb$,其中$a$数组所有元素的和为$sum$,$b$数组长度为$l$

然后整合一下原来计算的式

其实对于一个数组$P$的sum of prefix sums就是

$n*p_{1}+(n-1)*p_{2}+(n-2)*p_{3}+...+2*p_{n-1}+1*p_{n}$

照着这个式子推出来,将$b$数组接到$a$数组的答案是

$sa+sb+sum*l$

然后这里可以将$sa$看做截距,$sum$看做斜率,$l$为$x$坐标,最终答案为$y$坐标

求的就是每一个$sa$为截距,$sum$为斜率的线段在某一点的最大取值

那么用李超树维护即可

要注意对于树上两个节点$u$,$v$

$u$到$v$的答案和$v$到$u$的答案是不一样的

所以要在合并子树的时候要正着扫一遍,反着扫一遍

还有如果是只有一颗子树需要特判

  1 #pragma GCC optimize(2)
  2 #include <bits/stdc++.h>
  3 #define int long long
  4 using namespace std;
  5 const int N=150100;
  6 int n,a[N];
  7 int sz[N],vi[N],dfn,MAX,root;
  8 vector <int> e[N];
  9 struct line
 10 {
 11     int k,b;
 12 };
 13 struct node
 14 {
 15     line tag;
 16     int ti;
 17 }sh[N*4];
 18 int cal(int x,line a)
 19 {
 20     return a.k*x+a.b;
 21 }
 22 void change(int x,int l,int r,line k)
 23 {
 24     if (sh[x].ti!=dfn)
 25     {
 26         sh[x].tag=k;
 27         sh[x].ti=dfn;
 28         return;
 29     }
 30     if (cal(l,k)>=cal(l,sh[x].tag) && cal(r,k)>=cal(r,sh[x].tag))
 31     {
 32         sh[x].tag=k;
 33         return;
 34     }
 35     if (cal(l,k)<=cal(l,sh[x].tag) && cal(r,k)<=cal(r,sh[x].tag))
 36       return;
 37     int mid=(l+r)>>1;
 38     if (cal(mid,k)>cal(mid,sh[x].tag)) swap(k,sh[x].tag);
 39     if (cal(l,k)>cal(l,sh[x].tag)) change(x+x,l,mid,k);
 40     else change(x+x+1,mid+1,r,k);
 41 }
 42 int query(int x,int l,int r,int wh)
 43 {
 44     int ans=(sh[x].ti==dfn)?cal(wh,sh[x].tag):0;
 45     if (l==r) return ans;
 46     int mid=(l+r)>>1;
 47     if (wh<=mid) ans=max(ans,query(x+x,l,mid,wh));
 48     else ans=max(ans,query(x+x+1,mid+1,r,wh));
 49     return ans;
 50 }
 51 //李超树
 52 void dfs_insert(int x,int fa,int de,line now)
 53 {
 54     now.k+=a[x];
 55     now.b+=de*a[x];
 56     change(1,1,n,now);
 57     for (register int i=0;i<(int)e[x].size();i++)
 58     {
 59         int u=e[x][i];
 60         if (vi[u] || u==fa) continue;
 61         dfs_insert(u,x,de+1,now);
 62     }
 63 }
 64 void dfs_query(int x,int fa,int de,int sb,int s)
 65 {
 66     sb+=a[x]+s;
 67     s+=a[x];
 68     MAX=max(MAX,query(1,1,n,de)+sb);
 69     for (register int i=0;i<(int)e[x].size();i++)
 70     {
 71         int u=e[x][i];
 72         if (vi[u] || u==fa) continue;
 73         dfs_query(u,x,de+1,sb,s);
 74     }
 75 }
 76 void dfs_size(int x,int fa)
 77 {
 78     sz[x]=1;
 79     for (register int i=0;i<(int)e[x].size();i++)
 80     {
 81         int u=e[x][i];
 82         if (vi[u] || u==fa) continue;
 83         dfs_size(u,x);
 84         sz[x]+=sz[u];
 85     }
 86 }
 87 void dfs_root(int x,int fa,int tot)
 88 {
 89     bool bl=1;
 90     for (register int i=0;i<(int)e[x].size();i++)
 91     {
 92         int u=e[x][i];
 93         if (vi[u] || u==fa) continue;
 94         dfs_root(u,x,tot);
 95         if (sz[u]>tot/2) bl=0;
 96     }
 97     if (tot-sz[x]>tot/2) bl=0;
 98     if (bl) root=x;
 99 }
100 void dfs(int x,int fa,int de,int sa,int sb,int s)
101 {
102     sa+=de*a[x];
103     sb+=s+a[x];
104     s+=a[x];
105     MAX=max(MAX,sb);
106     MAX=max(MAX,sa);
107     for (register int i=0;i<(int)e[x].size();i++)
108     {
109         int u=e[x][i];
110         if (vi[u] || u==fa) continue;
111         dfs(u,x,de+1,sa,sb,s);
112     }
113 }
114 void divide(int x)//点分治
115 {
116     dfn++;
117     vi[x]=1;
118     int cnt=0;
119     for (register int i=0;i<(int)e[x].size();i++)
120     {
121         int u=e[x][i];
122         if (vi[u]) continue;
123         cnt++;
124         line tmp;
125         tmp.k=tmp.b=0;
126         if (cnt!=1) dfs_query(u,x,2,a[x],a[x]);
127         dfs_insert(u,x,1,tmp);
128     }
129     bool bl=(cnt==1);
130     dfn++;cnt=0;
131     for (register int i=(int)e[x].size()-1;i>=0;i--)//反着扫描
132     {
133         int u=e[x][i];
134         if (vi[u]) continue;
135         if (bl) dfs(u,x,2,a[x],a[x],a[x]);//只有一颗子树时的特判
136         cnt++;
137         line tmp;
138         tmp.k=tmp.b=0;
139         if (cnt!=1) dfs_query(u,x,2,a[x],a[x]);
140         dfs_insert(u,x,1,tmp);
141     }
142     for (register int i=0;i<(int)e[x].size();i++)
143     {
144         int u=e[x][i];
145         if (vi[u]) continue;
146         dfs_size(u,x);
147         dfs_root(u,x,sz[u]);
148         divide(root);
149     }
150 }
151 signed main()
152 {
153     scanf("%lld",&n);
154     for (int i=1;i<n;i++)
155     {
156         int u,v;
157         scanf("%lld%lld",&u,&v);
158         e[u].push_back(v);
159         e[v].push_back(u);
160     }
161     for (int i=1;i<=n;i++)
162     {
163         scanf("%lld",&a[i]);
164         MAX=max(MAX,a[i]);
165     }
166     dfs_size(1,-1);
167     dfs_root(1,-1,sz[1]);
168     divide(root);
169     printf("%lld\n",MAX);
170 }
View Code

猜你喜欢

转载自www.cnblogs.com/huangchenyan/p/12358085.html