[Template] Heavy-light decomposition (tree chain partitioning)

  1 #include<bits/stdc++.h>
  2 #define lson(x) (x << 1)
  3 #define rson(x) (x << 1 | 1)
  4 using namespace std;
  5 
  6 const int N = 100010,M = 200010;
  7 int n,m,r,p;
  8 long long a[N];
  9 int d[N],son[N],fa[N],siz[N],top[N],rk[N],id[N],cnt;//树链剖分 
 10 int head[N],edge[M],nxt[M],from[M],to[M],tot;//邻接表 
 11 void add_edge(int x,int y){
 12     from[++tot] = x;
 13     to[tot] = y;
 14     nxt[tot] = head[x];
 15     head[x] = tot;
 16 }
 17 
// Segment tree (range add, range sum modulo p)
 19 struct Tree{
 20     int l,r;
 21     long long sum,tag;  
 22 }t[N*4];
 23 void push_up(int x) { t[x].sum = (t[lson(x)].sum + t[rson(x)].sum) % p; }
 24 void push_down(int x){
 25     if(t[x].tag){
 26         t[lson(x)].tag = (t[lson(x)].tag + t[x].tag) % p; 
 27         t[lson(x)].sum = (t[lson(x)].sum + t[x].tag *(t[lson(x)].r - t[lson(x)].l + 1)) % p;
 28         t[rson(x)].tag = (t[rson(x)].tag + t[x].tag) % p; 
 29         t[rson(x)].sum = (t[rson(x)].sum + t[x].tag *(t[rson(x)].r - t[rson(x)].l + 1)) % p;
 30         t[x].tag = 0;
 31     }
 32 }
 33 void build(int x,int ll,int rr){
 34     t[x].l = ll; t[x].r = rr;
 35     if(ll == rr) { 
 36         t[x].sum = a[rk[ll]]; 
 37         if(t[x].sum > p) t[x].sum %= p; 
 38         return; 
 39     }
 40     int mid = (ll + rr) >> 1;
 41     build(lson(x),ll,mid);
 42     build(rson(x),mid + 1,rr);
 43     push_up(x);
 44 }
 45 void update(int x,int ll,int rr,int k){
 46     if(ll <= t[x].l && t[x].r <= rr){
 47         t[x].sum = (t[x].sum + k * (t[x].r - t[x].l + 1)) % p;
 48         t[x].tag = (t[x].tag + k) % p;
 49         return;
 50     }
 51     push_down(x);
 52     int mid = (t[x].l + t[x].r) >> 1;
 53     if(ll <= mid) update(lson(x),ll,rr,k);
 54     if(mid < rr)  update(rson(x),ll,rr,k);
 55     push_up(x);
 56 }
 57 long long query(int x,int ll,int rr){
 58     if(ll <= t[x].l && t[x].r <= rr) return t[x].sum;
 59     int mid = (t[x].l + t[x].r) >> 1;
 60     long long res = 0;
 61     push_down(x);
 62     if(ll <= mid) res = (res + query(lson(x),ll,rr)) % p;
 63     if(mid < rr)  res = (res + query(rson(x),ll,rr)) % p;
 64     push_up(x);
 65     return res;
 66 }
 67 
// Heavy-light decomposition
 69 void dfs1(int x){
 70     siz[x] = 1;
 71     for(int i = head[x]; i; i = nxt[i])
 72         if(!siz[to[i]]){
 73             d[to[i]] = d[x] + 1;
 74             fa[to[i]] = x;
 75             dfs1(to[i]);
 76             siz[x] +=  siz[to[i]];
 77             if(!son[x] || siz[to[i]] > siz[son[x]])
 78                 son[x] = to[i];
 79         }
 80 }
 81 void dfs2(int x,int t){
 82     top[x] = t;
 83     id[x] = ++cnt;
 84     rk[cnt] = x;
 85     if(!son[x]) return;
 86     dfs2(son[x],t);
 87     for(int i = head[x]; i; i = nxt[i])
 88         if(to[i] != son[x] && to[i] != fa[x])
 89             dfs2(to[i],to[i]);
 90 }
 91 int main(){
 92 //    freopen("testdata (6).in","r",stdin);
 93 //    freopen("test.txt","w",stdout);
 94     cin >> n >> m >> r >> p;
 95     for(int i = 1; i <= n; i++)
 96         scanf("%lld",&a[i]);
 97     for(int i = 1; i < n; i++){
 98         int x,y;
 99         scanf("%d%d",&x,&y);
100         add_edge(x,y);
101         add_edge(y,x);
102     }
103     dfs1(r);
104     dfs2(r,r);
105     build(1,1,n);
106     for(int i = 1; i <= m; i++){
107         int op = 0,x = 0,y = 0,z = 0;
108         scanf("%d%d",&op,&x);
109         if(op == 1){
110             scanf("%d%d",&y,&z);
111             while(top[x] != top[y]){
112                 if(d[top[y]] > d[top[x]]) swap(x,y);
113                 update(1,id[top[x]],id[x],z);
114                 x = fa[top[x]];
115             }
116             if(d[x] > d[y]) swap(x,y);
117                 update(1,id[x],id[y],z);
118         }
119         if(op == 2){
120             scanf("%d",&y);
121             long long ans = 0;
122             while(top[x] != top[y]){
123                 if(d[top[y]] > d[top[x]]) swap(x,y);
124                 ans = (ans + query(1,id[top[x]],id[x])) % p;
125                 x = fa[top[x]];
126             }
127             if(d[x] > d[y]) swap(x,y);
128             ans = (ans + query(1,id[x],id[y]))% p;
129             printf("%lld\n",ans);
130         }
131         if(op == 3){
132             scanf("%d",&z);
133             update(1,id[x],id[x] + siz[x] - 1,z);
134         }
135         if(op == 4){
136             printf("%lld\n",query(1,id[x],id[x] + siz[x] - 1));
137         } 
138     }
139     return 0;
140 }

 

Guess you like

Source: www.cnblogs.com/FoxC/p/11222392.html