[树链剖分][线段树] Jzoj P4388 染色

Description

 

Input

Output

 

Sample Input

4 6
0 1 2
2 1 3
2 2
1 3
2 2
2 3
2 1
1 3

Sample Output

0
3
0
4
 

Data Constraint

题解

  • 终于没有打水法辽(不过调的我心态爆炸)
  • 我们可以发现,询问节点x到所有黑点的距离和:∑dis[i]+dis[x]-2*dis[lca(i,x)]=∑dis[i]+cnt*dis[x]-2*∑dis[lca(i,x)] (cnt为黑点数)
  • 然后所有的黑点i的dis[i]和很好求,询问点x的和为dis[x]∗n,也是很好求的,关键就是∑dis[lca(i,x)] 
  • 我们可以把询问点x到根的路径上的权值设为d[i],其他的设为0
  • 那么所有黑点i到根的路径权值和就是∑dis[lca(i,x)]
  • 换种想法,可以把所有点到根的路径的每条边权值都加上为d[i],那么x到根的路径和也是∑dis[lca(i,x)]
  • 设每条边的权值设为d[i]∗k,我们就相当于每次给一条路径上的所有k加1,树剖+线段树维护即可

     

代码

 1 #include <cstdio>
 2 #include <iostream>
 3 #include <cstring>
 4 #include <cmath>
 5 #define ll long long
 6 using namespace std;
 7 const ll N=2e5+10,M=110;
 8 struct tree{ll s,l,r,w,v;}t[N*4];
 9 struct edge{ll to,from;}e[N];
10 ll dist[N],lazy[N*4];
11 ll n,m,L,R,p,op,cnt,tot,num,ans,sum,v[N],size[N],top[N],son[N],fa[N],id[N],rank[N],head[N];
12 bool col[N];
13 void insert(ll x,ll y) { e[++cnt].to=y,e[cnt].from=head[x],head[x]=cnt; }
14 void dfs1(ll x)
15 {
16     size[x]=1;
17     for (ll i=head[x];i;i=e[i].from)
18         if (e[i].to!=fa[x])
19         {
20             fa[e[i].to]=x,dist[e[i].to]=dist[x]+v[e[i].to],dfs1(e[i].to),size[x]+=size[e[i].to];
21             if (size[e[i].to]>size[son[x]]) son[x]=e[i].to;
22         }
23 }
24 void dfs2(ll x,ll pre)
25 {
26     id[x]=++tot,rank[tot]=x,top[x]=pre;
27     if (son[x]) dfs2(son[x],pre);
28     for (ll i=head[x];i;i=e[i].from) if (e[i].to!=fa[x]&&e[i].to!=son[x]) dfs2(e[i].to,e[i].to); 
29 }
30 void build(ll d,ll l,ll r)
31 {
32     t[d].l=l,t[d].r=r;
33     if (l==r) 
34     {
35         t[d].w=lazy[d]=t[d].s=0,t[d].v=v[rank[l]];
36         return;
37     }
38     ll mid=l+r>>1;
39     build(d*2,l,mid),build(d*2+1,mid+1,r),t[d].v=t[d*2].v+t[d*2+1].v;
40 }
41 void add(ll x,ll k) { t[x].s+=t[x].v*k,t[x].w+=k,lazy[x]+=k; }
42 void downdate(ll x)
43 {
44     if (!lazy[x]) return;
45     add(x*2,lazy[x]),add(x*2+1,lazy[x]),lazy[x]=0;
46 }
47 void update(ll x) { t[x].s=t[x*2].s+t[x*2+1].s,t[x].w=t[x*2].w+t[x*2+1].w; }
48 void work(ll d)
49 {
50     downdate(d);
51     if (L<=t[d].l&&t[d].r<=R)
52     {
53         if (op==1) add(d,1); else p+=t[d].s;
54         return;
55     }
56     ll mid=t[d].l+t[d].r>>1;
57     if (L<=mid) work(d*2); 
58     if (mid<R) work(d*2+1);
59     update(d);
60 }
61 void find(ll x) { for (;x;x=fa[top[x]]) L=id[top[x]],R=id[x],work(1); }
62 int main()
63 {
64     scanf("%lld%lld",&n,&m);
65     for (ll i=2,x;i<=n;i++) scanf("%lld",&x),insert(x+1,i);
66     for (ll i=2;i<=n;i++) scanf("%lld",&v[i]);
67     dfs1(1),dfs2(1,1),build(1,1,n),memset(col,1,sizeof(col));
68     for (ll i=1,x;i<=m;i++)
69     {
70         scanf("%lld%lld",&op,&x),x++;
71         if (op==1) 
72         {
73             if (col[x]) find(x),num++,sum+=dist[x];
74             col[x]=0;
75         }
76         else p=0,find(x),ans=sum-2*p+dist[x]*num,printf("%lld\n",ans);
77     }
78 }

猜你喜欢

转载自www.cnblogs.com/Comfortable/p/11126995.html