[树链剖分] SDOI2011 染色

题目描述

输入输出格式

输入格式:

输出格式:

对于每个询问操作,输出一行答案。

输入输出样例

输入样例#1:

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

输出样例#1:

3
1
2

说明

题解

  • 树链剖分进阶题,思路好懂,细节多,难调

  • 树上修改一段区间的值,很容易想到树链剖分,那么修改操作就是常规的使用线段树来实现,但是这里有一些细节。

build

怎么实现建树操作呢,我们需要两个数组s[]和t[],代表当前操作区间的左端点的颜色和右端点的颜色,那么当递归到(l==r)时,将s[]及t[]全部赋为当前颜色,并把当前区间不同颜色个数sum赋值为1(因为只有一个颜色嘛).

Code

void build(int node,int l,int r) {
  if(l==r) {
    sum[node]=1;
    s[node]=t[node]=c[l];
    return;
  }
  int mid=l+r>>1;
  build(node<<1,l,mid);
  build(node<<1|1,mid+1,r);
  push_up(node,l,r);
}

(代码中的push_up函数会在下面讲到)

update

修改操作大致和以前一样,就是当(当前操作区间在修改区间内时),把当前区间的左右端点的颜色都赋值为需要修改的颜色k,个数计为1,并打上懒惰标记

Code

void update(int node,int l,int r,int left,int right,int k) {
  if(l > right || r < left) return;
  if(left<=l && r<=right) {
    s[node]=t[node]=k;
    sum[node]=1;
    lazy[node]=k;
    return;
  }
  int mid=l+r>>1;
  if(lazy[node]) push_down(node,l,r);
  if(left <= mid) update(node<<1,l,mid,left,right,k);
  if(mid < right) update(node<<1|1,mid+1,r,left,right,k);
  push_up(node,l,r);
}

push_up

即统计一下当前区间的个数,有个细节要注意,下面我们来分类讨论一下

<1> 左区间 123 右区间 221 那么此段区间颜色总个数应为左区间颜色个数3+右区间个数2=5,这是第一种情况,没有问题

<2>左区间 122 右区间 231 此时此段区间颜色个数应为左区间个数2+右区间颜色个数3-1(中间有连续的颜色2)=4

相信大家也发现了,我们在统计个数的时候需要记录一下t[左区间]是否等于s[右区间],如果等于总个数就要-1

Code

inline void push_up(int node,int l,int r) {
  s[node]=s[node<<1];
  t[node]=t[node<<1|1];
  sum[node]=sum[node<<1]+sum[node<<1|1];
  if(t[node<<1]==s[node<<1|1]) sum[node]--;
}

check

也没什么好说的,大部分一样,也是要判断当t[左区间]==s[右区间]时,ans- -

Code

int check(int node,int l,int r,int left,int right) {
  if(l>right || r<left) return 0;
  if(left<=l && r<=right) return sum[node];
  if(lazy[node]) push_down(node,l,r);
  int mid=l+r>>1, ans=0;
  if(left<=mid) ans+=check(node<<1,l,mid,left,right);
  if(mid<right) ans+=check(node<<1|1,mid+1,r,left,right);
  if(left<=mid && mid<right && s[node<<1|1]==t[node<<1]) ans--;
  return ans;
}

好了,线段树操作就是这些.接下来讲一些树链剖分部分需要注意的地方

思考一下,我们在普通的树链剖分查询操作中,运用了一种类似于lca的思想,每次将当前查询的点跳到所在链的链顶,还是那个问题,我们每次ans+的都是一段区间的答案,但是当col[top[x]]==col[fa[top[x]]]时,我们依然会将两段区间的值都加上,但实际上此时应该ans- -,但是在线段树中,我们并没有单独记录一个节点的颜色,所以我单独写了个函数来查询

Code

int find(int node,int l,int r,int pos) {//pos为当前需要查询颜色的节点
  if(l == r) return s[node];
  if(lazy[node]) push_down(node,l,r);
  int mid=l+r>>1;
  if(pos<=mid) return find(node<<1,l,mid,pos);
  if(pos>=mid+1) return find(node<<1|1,mid+1,r,pos);
}

这道题是一道好题,既能帮助你理解线段树又能更好的学习树链剖分,就是比较难调,值得一做

AC Code

// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define in(i) (i=read())
using namespace std;
int read(){
  int ans=0,f=1; char i=getchar();
  while(i<'0' || i>'9'){if(i=='-') f=-1;i=getchar();}
  while(i>='0' && i<='9'){ans=(ans<<3)+(ans<<1)+i-'0';i=getchar();}
  return ans*f;
}
const int N=1e5+10;
struct edge {
  int to,next;
}e[N*2];

int n,m,len,cnt;
int dep[N],dfn[N],son[N],size[N],fa[N],head[N],top[N],col[N];
int sum[N*4],s[N*4],t[N*4],c[N*4],lazy[N*4];

void add(int a,int b) {
  e[++len].to=b; e[len].next=head[a];
  head[a]=len;
}

 void dfs1(int u) {
  size[u]=1;
  for(int i=head[u];i;i=e[i].next) {
    int to=e[i].to;
    if(!dep[to]) {
      dep[to]=dep[u]+1;
      fa[to]=u; dfs1(to);
      size[u]+=size[to];
      if(size[to]>size[son[u]]) son[u]=to;
    }
  }
}

void dfs2(int u,int t) {
  top[u]=t,dfn[u]=++cnt,c[cnt]=col[u];
  if(son[u]) dfs2(son[u],t);
  for(int i=head[u];i;i=e[i].next) {
    int to=e[i].to;
    if(to!=fa[u] && to!=son[u]) dfs2(to,to);
  }
}

inline void push_up(int node,int l,int r) {
  s[node]=s[node<<1];
  t[node]=t[node<<1|1];
  sum[node]=sum[node<<1]+sum[node<<1|1];
  if(t[node<<1]==s[node<<1|1]) sum[node]--;
}

inline void push_down(int node,int l,int r) {
  lazy[node<<1]=lazy[node];
  lazy[node<<1|1]=lazy[node];
  s[node<<1]=t[node<<1]=lazy[node];
  s[node<<1|1]=t[node<<1|1]=lazy[node];
  sum[node<<1]=sum[node<<1|1]=1;
  lazy[node]=0;
}

void build(int node,int l,int r) {
  if(l==r) {
    sum[node]=1;
    s[node]=t[node]=c[l];
    return;
  }
  int mid=l+r>>1;
  build(node<<1,l,mid);
  build(node<<1|1,mid+1,r);
  push_up(node,l,r);
}

void update(int node,int l,int r,int left,int right,int k) {
  if(l > right || r < left) return;
  if(left<=l && r<=right) {
    s[node]=t[node]=k;
    sum[node]=1;
    lazy[node]=k;
    return;
  }
  int mid=l+r>>1;
  if(lazy[node]) push_down(node,l,r);
  if(left <= mid) update(node<<1,l,mid,left,right,k);
  if(mid < right) update(node<<1|1,mid+1,r,left,right,k);
  push_up(node,l,r);
}

int check(int node,int l,int r,int left,int right) {
  if(l>right || r<left) return 0;
  if(left<=l && r<=right) return sum[node];
  if(lazy[node]) push_down(node,l,r);
  int mid=l+r>>1, ans=0;
  if(left<=mid) ans+=check(node<<1,l,mid,left,right);
  if(mid<right) ans+=check(node<<1|1,mid+1,r,left,right);
  if(left<=mid && mid<right && s[node<<1|1]==t[node<<1]) ans--;
  return ans;
}

int find(int node,int l,int r,int pos) {
  if(l == r) return s[node];
  if(lazy[node]) push_down(node,l,r);
  int mid=l+r>>1;
  if(pos<=mid) return find(node<<1,l,mid,pos);
  if(pos>=mid+1) return find(node<<1|1,mid+1,r,pos);
}

int main(){
  in(n);in(m);
  for(int i=1;i<=n;i++) in(col[i]);
  for(int i=1;i<n;i++) {
    int a,b;
    in(a); in(b);
    add(a,b); add(b,a);
  }
  dep[1] = 1;  dfs1(1); dfs2(1,0); build(1,1,n);
  char op[3];int a,b,x;
  for(int i=1;i<=m;i++) {
    scanf("%s",op);
    if(op[0]=='C') {
      in(a);in(b);in(x);
      int fx=top[a],fy=top[b];
      while(fx!=fy) {
    if(dep[fx]<dep[fy])
      swap(a,b), swap(fx,fy);
    update(1,1,n,dfn[fx],dfn[a],x);
    a=fa[fx];fx=top[a];
      }
      if(dep[a]>dep[b]) swap(a,b);
      update(1,1,n,dfn[a],dfn[b],x);
    }
    else {
      in(a);in(b);
      int ans=0, fx=top[a],fy=top[b];
      while(fx!=fy) {
    if(dep[fx]<dep[fy]) swap(a,b), swap(fx,fy);
    ans+=check(1,1,n,dfn[fx],dfn[a]);
    if(find(1,1,n,dfn[fx]) == find(1,1,n,dfn[fa[fx]])) ans--;
    a=fa[fx];fx=top[a];
      }
      if(dep[a]>dep[b]) swap(a,b);
      ans+=check(1,1,n,dfn[a],dfn[b]);
      printf("%d\n",ans);
    }
  }
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/real-l/p/9230754.html