P2486 [SDOI2011]染色 区间合并+树链剖分(加深对线段树的理解)

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 const  int M=3e5+5;
  4 struct node{
  5     int l,r,cnt,lazy;
  6     node(int l1=0,int r1=0,int cnt1=0,int lazy1=0):l(l1),r(r1),cnt(cnt1),lazy(lazy1){}
  7 }tree[M<<2];
  8 int fa[M],sz[M],deep[M],dfn[M],son[M],to[M],a[M],top[M],cnt,n;
  9 char s[2];
 10 vector<int>g[M];
 11 void dfs1(int u,int from){
 12     fa[u]=from;
 13     sz[u]=1;
 14     deep[u]=deep[from]+1;
 15     for(int i=0;i<g[u].size();i++){
 16 
 17         int v=g[u][i];
 18         if(v!=from){
 19             dfs1(v,u);
 20             sz[u]+=sz[v];
 21             if(sz[v]>sz[son[u]])
 22                 son[u]=v;
 23         }
 24         
 25     }
 26 }
 27 void dfs2(int u,int t){
 28     top[u]=t;
 29     dfn[u]=++cnt;
 30     to[cnt]=u;
 31     if(!son[u])
 32         return ;
 33     dfs2(son[u],t);
 34     for(int i=0;i<g[u].size();i++){
 35         int v=g[u][i];
 36         if(v!=fa[u]&&v!=son[u])
 37             dfs2(v,v);
 38     }
 39 }
 40 void up(int root){
 41     tree[root].cnt=tree[root<<1].cnt+tree[root<<1|1].cnt;
 42     if(tree[root<<1].r==tree[root<<1|1].l)
 43         tree[root].cnt--;
 44     tree[root].l=tree[root<<1].l;
 45     tree[root].r=tree[root<<1|1].r;
 46 }
 47 void build(int root,int l,int r){
 48     tree[root].lazy=0;
 49     if(l==r){
 50         tree[root].l=tree[root].r=a[to[l]];
 51         tree[root].cnt=1;
 52         return ;
 53     }
 54     int midd=(l+r)>>1;
 55     build(root<<1,l,midd);
 56     build(root<<1|1,midd+1,r);
 57     up(root);
 58 }
 59 void pushdown(int root){
 60     tree[root<<1]=tree[root<<1|1]=node(tree[root].l,tree[root].r,1,tree[root].lazy);
 61     tree[root].lazy=0;
 62 }
 63 void update(int L,int R,int x,int root,int l,int r){
 64     if(L<=l&&r<=R){
 65         tree[root]=node(x,x,1,x);
 66         return ;
 67     }
 68     if(tree[root].lazy)
 69         pushdown(root);
 70     int midd=(l+r)>>1;
 71     if(L<=midd)
 72         update(L,R,x,root<<1,l,midd);
 73     if(R>midd)
 74         update(L,R,x,root<<1|1,midd+1,r);
 75     up(root);
 76 }
 77 void add(int u,int v ,int w){
 78     int fu=top[u],fv=top[v];
 79     while(fu!=fv){
 80         if(deep[fu]>=deep[fv])
 81             update(dfn[fu],dfn[u],w,1,1,n),u=fa[fu],fu=top[u];
 82         else
 83             update(dfn[fv],dfn[v],w,1,1,n),v=fa[fv],fv=top[v];
 84     }
 85     if(dfn[u]<=dfn[v])
 86         update(dfn[u],dfn[v],w,1,1,n);
 87     else
 88         update(dfn[v],dfn[u],w,1,1,n);
 89 }
 90 node meger(node a,node b){
 91     if(!a.cnt)
 92         return b;
 93     if(!b.cnt)
 94         return a;
 95     node ans=node(0,0,0,0);
 96     ans.cnt=a.cnt+b.cnt;
 97     if(a.r==b.l)
 98         ans.cnt--;
 99     ans.l=a.l;
100     ans.r=b.r;
101     return ans;
102 }
103 node query(int L,int R,int root,int l,int r){
104     if(L<=l&&r<=R){
105         return tree[root];
106     }
107     if(tree[root].lazy)
108         pushdown(root);
109     int midd=(l+r)>>1;
110     node ans;
111     if(L<=midd)
112         ans=query(L,R,root<<1,l,midd);
113     if(R>midd)
114         ans=meger(ans,query(L,R,root<<1|1,midd+1,r));
115     up(root);
116     return ans;
117 }
118 int solve(int u,int v){
119     node l,r;
120     int fv=top[v],fu=top[u];
121     while(fv!=fu){
122         if(deep[fu]>=deep[fv])
123             l=meger(query(dfn[fu],dfn[u],1,1,n),l),u=fa[fu],fu=top[u];
124         else
125             r=meger(query(dfn[fv],dfn[v],1,1,n),r),v=fa[fv],fv=top[v];
126     }
127     if(dfn[u]<=dfn[v])
128         r=meger(query(dfn[u],dfn[v],1,1,n),r);
129     else
130         l=meger(query(dfn[v],dfn[u],1,1,n),l);
131     swap(l.l,l.r);
132     l=meger(l,r);
133     return l.cnt;
134 }
135 int main(){
136     int m;
137     scanf("%d%d",&n,&m);
138     for(int i=1;i<=n;i++)
139         scanf("%d",&a[i]);
140     for(int i=1;i<n;i++){
141         int u,v;
142         scanf("%d%d",&u,&v);
143         g[u].push_back(v);
144         g[v].push_back(u);
145     }//cout<<"!!"<<endl;
146     dfs1(1,1);
147     dfs2(1,1);
148     
149     build(1,1,n);
150     while(m--){
151         int u,v,w;
152         scanf("%s",s);
153         if(s[0]=='Q'){
154             scanf("%d%d",&u,&v);
155             printf("%d\n",solve(u,v));
156         }
157         else{
158             scanf("%d%d%d",&u,&v,&w);
159             add(u,v,w);
160         }
161     }
162     return 0;
163 }
View Code

猜你喜欢

转载自www.cnblogs.com/starve/p/10840183.html