题目描述
输入输出格式
输入格式:
输出格式:
对于每个询问操作,输出一行答案。
输入输出样例
输入样例#1:
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
输出样例#1:
3
1
2
说明
题解
树链剖分进阶题,思路好懂,细节多,难调
树上修改一段区间的值,很容易想到树链剖分,那么修改操作就是常规的使用线段树来实现,但是这里有一些细节。
build
怎么实现建树操作呢,我们需要两个数组s[]和t[],代表当前操作区间的左端点的颜色和右端点的颜色,那么当递归到(l==r)时,将s[]及t[]全部赋为当前颜色,并把当前区间不同颜色个数sum赋值为1(因为只有一个颜色嘛).
Code
void build(int node,int l,int r) {
if(l==r) {
sum[node]=1;
s[node]=t[node]=c[l];
return;
}
int mid=l+r>>1;
build(node<<1,l,mid);
build(node<<1|1,mid+1,r);
push_up(node,l,r);
}
(代码中的push_up函数会在下面讲到)
update
修改操作大致和以前一样,就是当(当前操作区间在修改区间内时),把当前区间的左右端点的颜色都赋值为需要修改的颜色k,个数计为1,并打上懒惰标记
Code
void update(int node,int l,int r,int left,int right,int k) {
if(l > right || r < left) return;
if(left<=l && r<=right) {
s[node]=t[node]=k;
sum[node]=1;
lazy[node]=k;
return;
}
int mid=l+r>>1;
if(lazy[node]) push_down(node,l,r);
if(left <= mid) update(node<<1,l,mid,left,right,k);
if(mid < right) update(node<<1|1,mid+1,r,left,right,k);
push_up(node,l,r);
}
push_up
即统计一下当前区间的个数,有个细节要注意,下面我们来分类讨论一下
<1> 左区间 123 右区间 221 那么此段区间颜色总个数应为左区间颜色个数3+右区间个数2=5,这是第一种情况,没有问题
<2>左区间 122 右区间 231 此时此段区间颜色个数应为左区间个数2+右区间颜色个数3-1(中间有连续的颜色2)=4
相信大家也发现了,我们在统计个数的时候需要记录一下t[左区间]是否等于s[右区间],如果等于总个数就要-1
Code
inline void push_up(int node,int l,int r) {
s[node]=s[node<<1];
t[node]=t[node<<1|1];
sum[node]=sum[node<<1]+sum[node<<1|1];
if(t[node<<1]==s[node<<1|1]) sum[node]--;
}
check
也没什么好说的,大部分一样,也是要判断当t[左区间]==s[右区间]时,ans- -
Code
int check(int node,int l,int r,int left,int right) {
if(l>right || r<left) return 0;
if(left<=l && r<=right) return sum[node];
if(lazy[node]) push_down(node,l,r);
int mid=l+r>>1, ans=0;
if(left<=mid) ans+=check(node<<1,l,mid,left,right);
if(mid<right) ans+=check(node<<1|1,mid+1,r,left,right);
if(left<=mid && mid<right && s[node<<1|1]==t[node<<1]) ans--;
return ans;
}
好了,线段树操作就是这些.接下来讲一些树链剖分部分需要注意的地方
思考一下,我们在普通的树链剖分查询操作中,运用了一种类似于lca的思想,每次将当前查询的点跳到所在链的链顶,还是那个问题,我们每次ans+的都是一段区间的答案,但是当col[top[x]]==col[fa[top[x]]]时,我们依然会将两段区间的值都加上,但实际上此时应该ans- -,但是在线段树中,我们并没有单独记录一个节点的颜色,所以我单独写了个函数来查询
Code
int find(int node,int l,int r,int pos) {//pos为当前需要查询颜色的节点
if(l == r) return s[node];
if(lazy[node]) push_down(node,l,r);
int mid=l+r>>1;
if(pos<=mid) return find(node<<1,l,mid,pos);
if(pos>=mid+1) return find(node<<1|1,mid+1,r,pos);
}
这道题是一道好题,既能帮助你理解线段树又能更好的学习树链剖分,就是比较难调,值得一做
AC Code
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define in(i) (i=read())
using namespace std;
int read(){
int ans=0,f=1; char i=getchar();
while(i<'0' || i>'9'){if(i=='-') f=-1;i=getchar();}
while(i>='0' && i<='9'){ans=(ans<<3)+(ans<<1)+i-'0';i=getchar();}
return ans*f;
}
const int N=1e5+10;
struct edge {
int to,next;
}e[N*2];
int n,m,len,cnt;
int dep[N],dfn[N],son[N],size[N],fa[N],head[N],top[N],col[N];
int sum[N*4],s[N*4],t[N*4],c[N*4],lazy[N*4];
void add(int a,int b) {
e[++len].to=b; e[len].next=head[a];
head[a]=len;
}
void dfs1(int u) {
size[u]=1;
for(int i=head[u];i;i=e[i].next) {
int to=e[i].to;
if(!dep[to]) {
dep[to]=dep[u]+1;
fa[to]=u; dfs1(to);
size[u]+=size[to];
if(size[to]>size[son[u]]) son[u]=to;
}
}
}
void dfs2(int u,int t) {
top[u]=t,dfn[u]=++cnt,c[cnt]=col[u];
if(son[u]) dfs2(son[u],t);
for(int i=head[u];i;i=e[i].next) {
int to=e[i].to;
if(to!=fa[u] && to!=son[u]) dfs2(to,to);
}
}
inline void push_up(int node,int l,int r) {
s[node]=s[node<<1];
t[node]=t[node<<1|1];
sum[node]=sum[node<<1]+sum[node<<1|1];
if(t[node<<1]==s[node<<1|1]) sum[node]--;
}
inline void push_down(int node,int l,int r) {
lazy[node<<1]=lazy[node];
lazy[node<<1|1]=lazy[node];
s[node<<1]=t[node<<1]=lazy[node];
s[node<<1|1]=t[node<<1|1]=lazy[node];
sum[node<<1]=sum[node<<1|1]=1;
lazy[node]=0;
}
void build(int node,int l,int r) {
if(l==r) {
sum[node]=1;
s[node]=t[node]=c[l];
return;
}
int mid=l+r>>1;
build(node<<1,l,mid);
build(node<<1|1,mid+1,r);
push_up(node,l,r);
}
void update(int node,int l,int r,int left,int right,int k) {
if(l > right || r < left) return;
if(left<=l && r<=right) {
s[node]=t[node]=k;
sum[node]=1;
lazy[node]=k;
return;
}
int mid=l+r>>1;
if(lazy[node]) push_down(node,l,r);
if(left <= mid) update(node<<1,l,mid,left,right,k);
if(mid < right) update(node<<1|1,mid+1,r,left,right,k);
push_up(node,l,r);
}
int check(int node,int l,int r,int left,int right) {
if(l>right || r<left) return 0;
if(left<=l && r<=right) return sum[node];
if(lazy[node]) push_down(node,l,r);
int mid=l+r>>1, ans=0;
if(left<=mid) ans+=check(node<<1,l,mid,left,right);
if(mid<right) ans+=check(node<<1|1,mid+1,r,left,right);
if(left<=mid && mid<right && s[node<<1|1]==t[node<<1]) ans--;
return ans;
}
int find(int node,int l,int r,int pos) {
if(l == r) return s[node];
if(lazy[node]) push_down(node,l,r);
int mid=l+r>>1;
if(pos<=mid) return find(node<<1,l,mid,pos);
if(pos>=mid+1) return find(node<<1|1,mid+1,r,pos);
}
int main(){
in(n);in(m);
for(int i=1;i<=n;i++) in(col[i]);
for(int i=1;i<n;i++) {
int a,b;
in(a); in(b);
add(a,b); add(b,a);
}
dep[1] = 1; dfs1(1); dfs2(1,0); build(1,1,n);
char op[3];int a,b,x;
for(int i=1;i<=m;i++) {
scanf("%s",op);
if(op[0]=='C') {
in(a);in(b);in(x);
int fx=top[a],fy=top[b];
while(fx!=fy) {
if(dep[fx]<dep[fy])
swap(a,b), swap(fx,fy);
update(1,1,n,dfn[fx],dfn[a],x);
a=fa[fx];fx=top[a];
}
if(dep[a]>dep[b]) swap(a,b);
update(1,1,n,dfn[a],dfn[b],x);
}
else {
in(a);in(b);
int ans=0, fx=top[a],fy=top[b];
while(fx!=fy) {
if(dep[fx]<dep[fy]) swap(a,b), swap(fx,fy);
ans+=check(1,1,n,dfn[fx],dfn[a]);
if(find(1,1,n,dfn[fx]) == find(1,1,n,dfn[fa[fx]])) ans--;
a=fa[fx];fx=top[a];
}
if(dep[a]>dep[b]) swap(a,b);
ans+=check(1,1,n,dfn[a],dfn[b]);
printf("%d\n",ans);
}
}
return 0;
}