[SDOI2011][BZOJ2243] 染色 - 树链剖分

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c;

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。

请你写一个程序依次完成这m个操作。

Input & Output

Input

第一行包含2个整数n和m,分别表示节点数和操作数;

第二行包含n个正整数表示n个节点的初始颜色

下面n - 1行每行包含两个整数x和y,表示x和y之间有一条无向边。

下面m行每行描述一个操作:

“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;

“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

Sample

Input

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

Output

3
1
2

Solution

维护查询树上路径,可以想到树链剖分,不过在维护答案的时候需要费一点是,如何使线段树满足区间加和性质呢?我们可以记录每个区间最左的颜色和最右的颜色,合并时,如果左儿子的右端点和右儿子的左端点相同,那么ans[x] = ans[lc] + ans[rc] - 1,否则就不减1,因为如果相同则一定会少一个颜色段。这一点在pushup和query的时候都要得到体现。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

using std :: min;
using std :: max;
using std :: swap;
using std :: cin;
using std :: cout;
using std :: endl;
using std :: ios;
using std :: memset;

const int maxn = 100005;

struct node
{
    int tot,lst,rst,tag;
}t[maxn << 2];
struct edge
{
    int to,nxt;
}e[maxn << 1];
int n,m,c[maxn],f[maxn],dfn[maxn],son[maxn],top[maxn],size[maxn],dep[maxn];
int lnk[maxn],edgenum,u,v,cnt;
int w[maxn];

void add(int bgn,int end)
{
    e[++edgenum].to = end;
    e[edgenum].nxt = lnk[bgn];
    lnk[bgn] = edgenum;
}
void dfs(int x,int fa,int d)
{
    size[x] = 1;
    f[x] = fa;
    dep[x] = d;
    for(int p = lnk[x]; p; p = e[p].nxt)
    {
        int y = e[p].to;
        if(y == fa)continue;
        dfs(y, x, d + 1);
        size[x] += size[y];
        if(size[y] > size[son[x]]) son[x] = y;
    }
}
void dfs2(int x,int init)
{
    dfn[x] = ++cnt;
    w[cnt] = c[x];
    top[x] = init;
    if(!son[x])return;
    dfs2(son[x],init);
    for(int p = lnk[x]; p; p = e[p].nxt)
    {
        int y = e[p].to;
        if(y == f[x]||y == son[x])continue;
        dfs2(y,y);
    }
}
void pushdown(int cur)
{
    t[cur << 1].tot = t[cur << 1|1].tot = 1;
    t[cur << 1].tag = t[cur << 1|1].tag = t[cur].tag;
    t[cur << 1].lst = t[cur << 1].rst = t[cur].tag;
    t[cur << 1|1].lst = t[cur << 1|1].rst = t[cur].tag;
    t[cur].tag = 0;
}
void pushup(int cur)
{
    t[cur].tot = t[cur << 1].tot + t[cur << 1|1].tot;
    t[cur].lst = t[cur << 1].lst;
    t[cur].rst = t[cur << 1|1].rst;
    if(t[cur << 1].rst == t[cur << 1|1].lst) t[cur].tot--; //这里
    
}
void build(int cur,int l,int r)
{
    if(l == r)
    {
        t[cur].tot = 1;
        t[cur].lst = t[cur].rst = w[l];
        t[cur].tag = 0;
        return;
    }
    int mid = (l + r) >> 1;
    build(cur << 1, l, mid);
    build(cur << 1|1, mid+1, r);
    pushup(cur);
}
int query(int cur,int l,int r,int L,int R)
{
    int res = 0;
    if(L <= l && r <= R) return t[cur].tot;
    if(t[cur].tag) pushdown(cur);
    int mid = (l + r) >> 1;
    if(L <= mid) res += query(cur << 1, l, mid, L, R);
    if(R > mid) res += query(cur << 1|1, mid+1, r, L, R);
    if(mid >= L && mid < R && t[cur << 1].rst == t[cur << 1|1].lst) res--; //这里
    return res;
}
void update(int cur,int l,int r,int L,int R,int x)
{
    if(L <= l && r <= R)
    {
        t[cur].tot = 1;
        t[cur].lst = t[cur].rst = x;
        t[cur].tag = x;
        return;
    }
    if(t[cur].tag) pushdown(cur);
    int mid = (l + r) >> 1;
    if(L <= mid)update(cur<<1,l,mid,L,R,x);
    if(R > mid)update(cur<<1|1,mid+1,r,L,R,x);
    pushup(cur);
}
int check(int cur,int l,int r,int pos)
{
    if(l == r)return t[cur].lst;
    if(t[cur].tag) pushdown(cur);
    int mid = (l + r) >> 1;
    if(pos > mid) return check(cur << 1|1, mid+1, r, pos);
    else return check(cur << 1,l,mid,pos);
}
int queryt(int x,int y)
{
    int ans = 0, b1,b2;
    while(top[x] != top[y])
    {
        if(dep[top[x]] < dep[top[y]])swap(x,y);
        ans += query(1,1,n,dfn[top[x]],dfn[x]);
        b1 = check(1,1,n,dfn[top[x]]);
        b2 = check(1,1,n,dfn[f[top[x]]]); //和这里
        if(b1 == b2) ans--;
        x = f[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    ans += query(1,1,n,dfn[x],dfn[y]);
    return ans;
 } 
void updt(int x,int y,int c)
{
    while(top[x] != top[y])
    {
        if(dep[top[x]] < dep[top[y]])swap(x,y);
        update(1,1,n,dfn[top[x]],dfn[x],c);
        x = f[top[x]];
    }
    if(dep[x] > dep[y]) swap(x,y);
    update(1,1,n,dfn[x],dfn[y],c);
}

int main()
{
    ios :: sync_with_stdio(false);
    char opt;
    int x,y,z;
    cin >> n >> m;
    for(int i = 1; i <= n; ++i)
        cin >> c[i];
    for(int i = 1; i < n; ++i)
    {
        cin >> u >> v;
        add(u,v);
        add(v,u);
    }
    dfs(1,0,1);
    dfs2(1,1);
    build(1,1,n);
    for(int i = 1; i <= m; ++i)
    {
        cin >> opt;
        if(opt == 'Q')
        {
            cin >> x >> y;
            cout << queryt(x,y) << endl;
        }
        else
        {
            cin >> x >> y >> z;
            updt(x, y, z);
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/nishikino-curtis/p/9047692.html