SP10707 COT2 - Count on a tree II(求树中两点路径的不同颜色数量+树上莫队)

https://www.luogu.com.cn/problem/SP10707


树上莫队板子。

罗老师的思路:


1、把树的结点用欧拉序转为一维数组
  用DFS遍历树的结点,有两种遍历方式,得到两种欧拉序:
  (1)在每个结点第一次进和最后一次出都加进序列;
  (2)每遇到一个结点就把它加进序列。
  这里用第(1)种形式的欧拉序。下图的例子,欧拉序:{1, 2, 2, 3, 5, 5, 6, 6, 7, 7, 3, 4, 8, 8, 4, 1}。

图8 一棵树

  (u, v)上的路径有哪些结点?首先计算出u、v的lca(u, v)(最近公共祖先),然后讨论两种情况:
  (1)lca(u, v) = u或lca(u, v) = v,即u在v的子树中,或者v在u的子树中。例如u = 1, v = 6,区间是{1, 2, 2, 3, 5, 5, 6},出现2次的结点{2, 5}不属于这条路径,因为它进来了又出去了。只出现一次的结点属于这条路径,即{1, 3, 6}。
  (2)lca(u, v) ≠ u且lca(u, v) ≠ v,即u和v都不在对方的子树上。此时u、v之间的路径需要通过它们的lca,但是lca没有出现在u和v的欧拉序区间内,需要添上。例如u = 5,v = 8,区间是{5, 6, 6, 7, 7, 3, 4, 8},去掉出现2次的结点{6, 7},剩下{5, 3, 4, 8},再加上它们的lca = 1,得路径{5, 3, 4, 8, 1}。再例如u = 5,v = 7,区间是{5, 6, 6, 7},去掉6,剩下{5, 7},再加上它们的lca = 3,得路径{5, 7, 3}。
2、本题的求解步骤
  (1)求树的欧拉序,得到一维数组;求任意两个点的lca。编码时用树链剖分(做两次DFS)求欧拉序和lca。
  (2)把题目的查询(u, v)看成一维数组上的查询。题目要求查询(u, v)内不同的颜色,首先查区间(u, v)内只出现1次的结点,并加上u、v的lca,得到路径上的所有结点,然后在这些结点中统计只出现1次的数字。
  (3)用莫队算法,离线处理所有的查询,然后一起输出。注意分块时,本题的规模是2n,因为每个结点在欧拉序中出现2次;另外每个结点的颜色数值很大,需要离散化。


之前做过dfs序维护子树的权值,通过dfs序转化成一个子树连续序列,从而区间修改对应树上该子树的连续序列范围,从而用线段树维护。

题目链接:https://vjudge.net/problem/HDU-3974/origin(dfs序上线段树)

这题是dfs中预处理欧拉序。然后用欧拉序来讨论在不在子树中的情况下,得到的序列的关系。转到成线性然后用莫队nsqrt(n)维护。

一个比较好的办法:

如何忽略掉区间内出现了两次的点,多记录一个vis[x],表示x这个树节点有没有被加入,每次处理的时候如果vis[x]=0则需要添加节点;如果vis[x]=1则需要删除节点,每次处理之后都对vis[x]异或1就可以了

#include<iostream>
#include<vector>
#include<queue>
#include<cstring>
#include<cmath>
#include<map>
#include<set>
#include<cstdio>
#include<algorithm>
#define debug(a) cout<<#a<<"="<<a<<endl;
using namespace std;
const int maxn=4e5+100;
typedef int LL;
struct Query{
    LL l,r,ll,rr,lca;
    LL id;
}q[100010];
bool vis[maxn];
LL a[maxn],b[maxn];
LL cnt[maxn];
LL answer[maxn];
LL sum=0;
///莫队部分
bool cmp(Query A,Query B){
    if(A.ll!=B.ll){
        return A.ll<B.ll;
    }
    if(A.ll&1) return A.rr>B.rr;
    else return A.rr<B.rr;
}
void add(LL x){
    cnt[a[x]]++;
    if(cnt[a[x]]==1) sum++;
}
void del(LL x){
    cnt[a[x]]--;
    if(cnt[a[x]]==0) sum--;
}
void cal(LL x){
    ///每个节点都有一个颜色
    (!vis[x])?add(x):del(x);
    vis[x]^=1;
}
///树链剖分部分(求LCA)
LL times=0;
LL siz[maxn],top[maxn],dep[maxn],fa[maxn],son[maxn];
vector<LL>g[maxn];
LL numid[maxn*2],st[maxn],ed[maxn];
void predfs(LL u,LL father){
    siz[u]=1;dep[u]=dep[father]+1;
    fa[u]=father;
    st[u]=++times;
    numid[times]=u;
    for(LL i=0;i<g[u].size();i++){
        LL v=g[u][i];
        if(v==father) continue;
        predfs(v,u);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]]){
            son[u]=v;
        }
    }
    ed[u]=++times;numid[times]=u;///欧拉序
}
void dfs(LL u,LL topx){
     top[u]=topx;
     if(!son[u]) return;
     dfs(son[u],topx);
     for(LL i=0;i<g[u].size();i++){
        LL v=g[u][i];
        if(v==fa[u]||v==son[u]) continue;
        dfs(v,v);
     }
}
LL getLCA(LL u,LL v){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        u=fa[top[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    return u;
}
int main(void)
{
  cin.tie(0);std::ios::sync_with_stdio(false);
  LL n,m;cin>>n>>m;
  for(LL i=1;i<=n;i++){
    cin>>a[i];b[i]=a[i];
  }
  sort(b+1,b+1+n);
  LL siz=unique(b+1,b+1+n)-b-1;
  for(LL i=1;i<=n;i++){
    a[i]=lower_bound(b+1,b+1+siz,a[i])-b;///离散化
  }
  for(LL i=1;i<n;i++){
    LL u,v;cin>>u>>v;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  predfs(1,0);
  dfs(1,1);
  LL block=sqrt(n*2);
  for(LL i=1;i<=m;i++){
    LL u,v;cin>>u>>v;
    if(st[u]>st[v]) swap(u,v);///先被访问的放前面
    LL LCA=getLCA(u,v);
    ///u和v在同一个子树内
    if(LCA==u||LCA==v){
        q[i].id=i;
        q[i].l=st[u];q[i].r=st[v];
        q[i].ll=(st[u]-1)/block+1;///l端点分到的块的编号
        q[i].rr=(st[v]-1)/block+1;///r端点分到的块的编号
        q[i].lca=0;
    }
    ///u和v不在同一棵子树内
    else{
        q[i].id=i;
        q[i].l=ed[u];q[i].r=st[v];
        q[i].ll=(ed[u]-1)/block+1;q[i].rr=(st[v]-1)/block+1;
        q[i].lca=LCA;///最后要加上LCA的贡献
    }
  }
  LL L=1;LL R=0;
  sort(q+1,q+1+m,cmp);
  for(LL i=1;i<=m;i++){
     while(L<q[i].l) cal(numid[L++]);///该树的节点
     while(L>q[i].l) cal(numid[--L]);
     while(R<q[i].r) cal(numid[++R]);
     while(R>q[i].r) cal(numid[R--]);
     if(q[i].lca){
        cal(q[i].lca);
     }
     answer[q[i].id]=sum;
     if(q[i].lca){
        cal(q[i].lca);
     }
  }
  for(LL i=1;i<=m;i++){
    cout<<answer[i]<<endl;
  }
return 0;
}

猜你喜欢

转载自blog.csdn.net/zstuyyyyccccbbbb/article/details/110251080
今日推荐