一道比较综合的题,看了题解才会做。
先观察问题的性质:题目所给的染色方式,最终会导致一条链上颜色一定是连续的一段段的出现,考虑用 LCT 维护颜色相同的链根本没有想到,每一棵 splay 维护的是原树上颜色相同的一段链,由于每次染色直接染到根,这相当于 操作,那么每当有操作1时,就 ,这样 LCT 的 操作不能在其它地方用了,否则会破坏维护的集合。
在这棵 LCT 上,每一个点的答案,就等于它不停向上走经过的虚边的边数 + 1(因为虚边代表颜色出现了变化)。
对于操作一:在 的过程中,就是不断地改变虚边和实边,当一条虚边变成实边时,这条虚边连接的子树所有节点的答案 - 1,因为它们经过的虚边少了一条,同理实边变虚边时,子树的答案 + 1。这个维护可以用线段树 + 原树的 dfs 序,splay 上的虚边和实边并不是原树的边,因此要维护一下每棵 splay 的最左端点,方便找到一棵子树的根。
对于操作二:设
表示
到根这个路径上的答案,稍加思考仔细思考可以得到操作二的答案等于
对于操作三:用线段树维护一下区间最值。
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 4e5 + 10;
typedef long long ll;
#define pii pair<int,int>
#define fir first
#define sec second
int n,m;
int op,id,x,y,u,v,w;
int p[maxn][21],dep[maxn],st[maxn],ed[maxn],cnt,dfn[maxn];
map<int,int> mp[maxn];
inline int read(){
int w=0,q=0; char c=getchar(); while((c<'0'||c>'9') && c!='-') c=getchar();
if(c=='-') q=1,c=getchar(); while (c>='0'&&c<='9') w=w*10+c-'0',c=getchar(); return q?-w:w;
}
struct Graph {
int head[maxn],nxt[maxn << 1],to[maxn << 1],cnt;
void init() {
cnt = 0;
memset(head,-1,sizeof head);
}
void add(int u,int v) {
nxt[cnt] = head[u];
to[cnt] = v;
head[u] = cnt++;
nxt[cnt] = head[v];
to[cnt] = u;
head[v] = cnt++;
}
}G;
struct seg_tree {
#define lson rt << 1,l,mid
#define rson rt << 1 | 1,mid + 1,r
int val[maxn << 2],add[maxn << 2];
void build(int rt,int l,int r) {
add[rt] = 0;
if (l == r) {
val[rt] = dep[dfn[l]];
return ;
}
int mid = l + r >> 1;
build(lson); build(rson);
val[rt] = max(val[rt << 1],val[rt << 1 | 1]);
}
void pushdown(int rt) {
if (!add[rt]) return ;
int ls = rt << 1, rs = rt << 1 | 1;
add[ls] += add[rt], val[ls] += add[rt];
add[rs] += add[rt], val[rs] += add[rt];
add[rt] = 0;
}
void update(int L,int R,int v,int rt,int l,int r) { //区间更新
if (L <= l && r <= R) {
val[rt] += v;
add[rt] += v;
return ;
}
pushdown(rt);
int mid = l + r >> 1;
if (L <= mid) update(L,R,v,lson);
if (mid + 1 <= R) update(L,R,v,rson);
val[rt] = max(val[rt << 1],val[rt << 1 | 1]);
}
int qry(int L,int R,int rt,int l,int r) {
if (L > R || L <= 0 || R <= 0) return 0;
if (L <= l && r <= R) return val[rt];
pushdown(rt);
int mid = l + r >> 1;
int ans = 0;
if (L <= mid) ans = max(ans,qry(L,R,lson));
if (mid + 1 <= R) ans = max(ans,qry(L,R,rson));
return ans;
}
}seg;
struct LCT { //用splay维护原森林的连通,用到了splay的操作以及数组
#define ls ch[x][0]
#define rs ch[x][1]
#define inf 0x3f3f3f3f
int ch[maxn][2]; //ch[u][0] 表示 左二子,ch[u][1] 表示右儿子
int f[maxn]; //当前节点的父节点
int lt[maxn]; //维护一条链的左端点
inline bool get(int x) {
return ch[f[x]][1] == x;
}
inline bool isroot(int x) {
return (ch[f[x]][0] != x) && (ch[f[x]][1] != x);
}
void pushup(int x) {
lt[x] = x;
if (ls) lt[x] = lt[ls];
}
inline void rotate(int x) { //旋转操作,根据 x 在 f[x] 的哪一侧进行左旋和右旋
int old = f[x], oldf = f[old];
int whichx = get(x);
if(!isroot(old)) ch[oldf][ch[oldf][1] == old] = x; //如果 old 不是根节点,就要修改 oldf 的子节点信息
ch[old][whichx] = ch[x][whichx ^ 1];
ch[x][whichx ^ 1] = old;
f[ch[old][whichx]] = old;
f[old] = x; f[x] = oldf;
pushup(old); pushup(x);
}
inline void splay(int x) { //将 x 旋到所在 splay 的根
for(int fa = f[x]; !isroot(x); rotate(x), fa = f[x]) { //再把x翻上来
if(!isroot(fa)) //如果fa非根,且x 和 fa是同一侧,那么先翻转fa,否则先翻转x
rotate((get(x) == get(fa)) ? fa : x);
}
}
inline void access(int x) { //access操作将x 到 根路径上的边修改为重边
int lst = 0;
while(x > 0) {
splay(x);
if (ch[x][1]) {
int rt = lt[ch[x][1]];
seg.update(st[rt],ed[rt],1,1,1,n);
}
ch[x][1] = lst;
if (lst) {
int rt = lt[lst];
seg.update(st[rt],ed[rt],-1,1,1,n);
}
pushup(x);
lst = x; x = f[x];
}
}
}lct;
void dfs(int u,int fa) {
dep[u] = dep[fa] + 1;
st[u] = ++cnt; dfn[cnt] = u;
for (int i = 1; i <= 20; i++)
p[u][i] = p[p[u][i - 1]][i - 1];
for (int i = G.head[u]; i + 1; i = G.nxt[i]) {
int v = G.to[i];
if (v == fa) continue;
p[v][0] = u;
lct.f[v] = u; //lct建树
dfs(v,u);
}
ed[u] = cnt;
}
int getlca(int x,int y) {
if (dep[x] < dep[y]) swap(x,y);
for (int i = 20; i >= 0; i--) {
if (dep[p[x][i]] >= dep[y])
x = p[x][i];
}
if (x == y) return x;
for (int i = 20; i >= 0; i--) {
if (p[x][i] != p[y][i]) {
x = p[x][i], y = p[y][i];
}
}
return p[x][0];
}
int main() {
G.init();
n = read(); m = read();
for (int i = 1; i <= n; i++)
lct.lt[i] = i;
for (int i = 1; i < n; i++) {
u = read(); v = read();
G.add(u,v);
}
dfs(1,0);
seg.build(1,1,n);
while (m--) {
op = read();
if (op == 1) {
x = read();
lct.access(x);
} else if (op == 2) {
x = read(); y = read();
int lca = getlca(x,y);
int tx = seg.qry(st[x],st[x],1,1,n), ty = seg.qry(st[y],st[y],1,1,n), tz = seg.qry(st[lca],st[lca],1,1,n);
int ans = tx + ty - 2 * tz + 1;
printf("%d\n",ans);
} else {
x = read();
printf("%d\n",seg.qry(st[x],ed[x],1,1,n));
}
}
return 0;
}