bzoj 2243: [SDOI2011]染色 (树链剖分+线段树 区间合并)

2243: [SDOI2011]染色

Time Limit: 20 Sec  Memory Limit: 512 MB
Submit: 9854  Solved: 3725
[Submit][Status][Discuss]

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。

Input

第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

 

Sample Input

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

Sample Output

3
1
2

HINT

数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。

 
思路:
有些区间合并的思想,两个区间不能直接相加,需要比较下他们要相接的两个端点的颜色是否相同,如果相同那么相加的值就要-1,不相同的话直接相加就好了因为是在树上操作,需要用树链剖分处理下,而且更新和查询操作都需要特殊处理下。如果不是在树上操作的话就是一道很简单的线段树了,加了数剖复杂了好多啊。。。
之前lazy标记一直忘了下传,。。找了一天的错。。。
 
实现代码;
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mid ll m = (l+r)>>1
const ll M = 1e5+10;
ll cnt,n,q;
ll siz[M],son[M],fa[M],top[M],rk[M],tid[M],dep[M],a[M],cnt1,head[M],lazy[M<<2];
struct node{ll to,next;}e[M];
struct node1{ll ls,rs,val;};
node1 sum[M<<2];

void add(ll u,ll v){
    e[++cnt1].to = v;e[cnt1].next = head[u];head[u] = cnt1;
    e[++cnt1].to = u;e[cnt1].next = head[v];head[v] = cnt1;
}

void dfs1(ll u,ll faz,ll deep){
     dep[u] = deep;
     fa[u] = faz;
     siz[u] = 1;
     for(ll i = head[u];i;i=e[i].next){
        ll v = e[i].to;
        if(v != fa[u]){
            dfs1(v,u,deep+1);
            siz[u] += siz[v];
            if(son[u] == -1||siz[v] > siz[son[u]])
                son[u] = v;
        }
     }
}

void dfs2(ll u,ll t){
    top[u] = t;
    tid[u] = cnt;
    rk[cnt] = u;
    cnt++;
    if(son[u] == -1) return;
    dfs2(son[u],t);
    for(ll i = head[u];i;i = e[i].next){
        ll v = e[i].to;
        if(v != son[u]&&v != fa[u])
            dfs2(v,v);
    }
}

void pushup(ll rt){
    sum[rt].ls = sum[rt<<1].ls;
    sum[rt].rs = sum[rt<<1|1].rs;
    if(sum[rt<<1].rs==sum[rt<<1|1].ls) sum[rt].val = sum[rt<<1].val+sum[rt<<1|1].val-1;
    else sum[rt].val = sum[rt<<1].val + sum[rt<<1|1].val;
}

void build(ll l,ll r,ll rt){
    if(l == r){
        sum[rt].ls = a[rk[l]];
        sum[rt].rs = a[rk[l]];
        sum[rt].val = 1;
        ////cout<<l<<" "<<rk[l]<<" "<<a[rk[l]]<<endl;
        return ;
    }
    mid;
    build(lson);
    build(rson);
    pushup(rt);
}

void pushdown(ll rt){
    if(lazy[rt]){
    sum[rt<<1].ls = lazy[rt];
    sum[rt<<1|1].rs = lazy[rt];
    sum[rt<<1|1].ls = lazy[rt];
    sum[rt<<1].rs = lazy[rt];
    sum[rt<<1].val = 1;
    sum[rt<<1|1].val = 1;
    lazy[rt<<1] = lazy[rt<<1|1] = lazy[rt];
    lazy[rt] = 0;
    }
}

void update(ll L,ll R,ll c,ll l,ll r,ll rt){
     if(L <= l&&R >= r){
        sum[rt].val = 1;
        lazy[rt] = c;
        sum[rt].ls = c; sum[rt].rs = c;
        return ;
     }
     pushdown(rt);
     ll m = (l + r) >> 1;
     if(L <= m) update(L,R,c,lson);
     if(R > m)  update(L,R,c,rson);
     pushup(rt);
}

node1 query(ll L,ll R,ll l,ll r,ll rt){
    if(L <= l&&R >= r){
        return sum[rt];
    }
    pushdown(rt);
    ll m = (l + r) >> 1;
    if(L > m) return query(L,R,rson);
    if(R <= m) return query(L,R,lson);
    node1 t1 = query(L,m,lson);
    node1 t2 = query(m+1,R,rson);
    node1 t;
    t.ls = t1.ls;t.rs = t2.rs;
    if(t1.rs==t2.ls) t.val = t1.val+t2.val-1;
    else t.val = t1.val+t2.val;
    return t;
}

void cover(ll x,ll y,ll c){
    ll fx = top[x],fy = top[y];
    while(fx!=fy){
        if(dep[fx] < dep[fy]) swap(fx,fy),swap(x,y);
        update(tid[fx],tid[x],c,1,n,1);
        x = fa[fx];fx = top[x];
    }
    if(dep[x] < dep[y]) swap(x,y);
    update(tid[y],tid[x],c,1,n,1);
}

ll ask(ll x,ll y){
    ll sum = 0;
    ll lc = -1,rc=-1;
    ll fx = top[x],fy = top[y];
    node1 t;
    while(fx != fy){
        if(dep[fx] < dep[fy]){
            swap(fx,fy); swap(x,y); swap(lc,rc);
        }
        t = query(tid[fx],tid[x],1,n,1);
        sum += t.val - (lc==t.rs);
        x = fa[fx]; fx = top[x]; lc = t.ls;
    }
    if(dep[x] < dep[y]) swap(x,y),swap(lc,rc);
    t = query(tid[y],tid[x],1,n,1);
    sum += t.val - (lc==t.rs) - (rc==t.ls);  //当前是x-y区间与两端的区间相加,所以需要判两个
    return sum;
}


int main()
{
    ll u,v,x,y,m,z;
    memset(son,-1,sizeof(son));
    scanf("%lld%lld",&n,&m);
    cnt = 1;cnt1 = 1;
    for(ll i = 1;i <= n;i ++) {
        scanf("%lld",&x);
        a[i] = x+1;
    }
    for(ll i = 0;i < n-1;i++){
        scanf("%lld%lld",&u,&v);
        add(u,v);
    }
    dfs1(1,0,1); dfs2(1,1);
    build(1,n,1);
    char op[10];
    while(m--){
        scanf("%s",op);
        if(op[0] == 'Q'){
            scanf("%lld%lld",&x,&y);
            printf("%lld\n",ask(x,y));
        }
        else {
            scanf("%lld%lld%lld",&x,&y,&z);
            z++;
            cover(x,y,z);
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/kls123/p/8955252.html