CF1467E. Distinctive Roots in a Tree

题意:

一棵有点权的无根树,定义特殊点:该点到树上任意一点路径上的点权不重复。问特殊点的数量。

题解:

因为树上两点路径的唯一性,对于两个点权相同的点 u 1 , u 2 u_1,u_2 u1,u2,显然只有这两点之间的点可以作为特殊点。因为两端的点,比如在 u 1 u_1 u1一端的点,它到 u 2 u2 u2的路径必然经过 u 1 u1 u1,就会导致点权重复。

任取一个作为根节点。定义特殊值:只有特殊点的特殊值为0。我们考虑当前节点 u u u。树会被分成三个部分,点集 A = { x ∣ d f n [ x ] < d f n [ u ] } A=\{x|dfn[x]<dfn[u]\} A={ xdfn[x]<dfn[u]},也就是在 u u u之前已经遍历的点;点集 B = { x ∣ x 是 u 子 树 上 的 节 点 } B=\{x|x是u子树上的节点\} B={ xxu};点集 C = { x ∣ x ∉ A ∨ B } C=\{x|x\notin A\lor B\} C={ xx/AB},也就是还没遍历到的、不是 u u u子树的节点。

  1. ∃ x ∈ A ∨ C , a [ x ] = a [ u ] \exists x\in A\lor C ,a[x]=a[u] xAC,a[x]=a[u],那么 B B B中所有点的都不是特殊点。将 u u u子树上所有点特殊值加1。
  2. ∀ x ∈ A ∨ C , a [ x ] ≠ a [ u ] ; ∃ y ∈ B ∩ y ≠ u , a [ y ] = a [ u ] \forall x\in A\lor C,a[x]\neq a[u];\exists y\in B \cap y\neq u,a[y]=a[u] xAC,a[x]=a[u];yBy=u,a[y]=a[u],那么的 y y y所在的儿子子树中存在特殊点,且 A 、 C A、C AC u u u的其他所有儿子子树的所有点都不是特殊点。将树上所有点特殊值加1,再将 y y y所在这棵儿子子树的特殊值减1。因为在遍历到 y y y的时候,根据1. ,已经对 y y y的子树特殊值加1了,就实现了只有 u u u y y y中间的点特殊值没有改变。

观察发现,每一次的特殊值的改变都是以子树为单位进行的。树上差分

对于1. 的实现,首先预处理出所有点权的出现次数 m p 1 mp1 mp1,用 m p 2 mp2 mp2在遍历的时候记录点权出现的次数。遍历到 u u u的时候,可以记录一下 t m p = m p 2 [ a [ u ] ] tmp=mp2[a[u]] tmp=mp2[a[u]],代表 A A A中点权为 a [ u ] a[u] a[u]的点的数量;遍历完 u u u的子树后,此时 m p 2 [ a [ u ] ] mp2[a[u]] mp2[a[u]]表示 A ∨ B A\lor B AB a [ u ] a[u] a[u]的出现次数;那么 m p 2 [ a [ u ] ] − t m p mp2[a[u]]-tmp mp2[a[u]]tmp代表着 B B B a [ u ] a[u] a[u]的出现次数。如果 m p 2 [ a [ u ] ] − t m p < m p 1 [ a [ u ] ] mp2[a[u]]-tmp<mp1[a[u]] mp2[a[u]]tmp<mp1[a[u]],说明 A ∨ C A\lor C AC中存在点权为 a [ u ] a[u] a[u]的点,执行1. 。

AC代码:

#include <bits/stdc++.h>
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define lep(i,a,b) for(int i=(a);i>=(b);i--) 
#define sci(x) scanf("%d",&(x))
#define scl(x) scanf("%lld",&(x))
#define scs(x) scanf("%s",(x))
#define pri(x) printf("%d\n",(x))
#define prl(x) printf("%lld\n",(x))
#define prs(x) printf("%s\n",(x))
#define pii pair<int,int>
#define pll pair<long long,long long>
#define All(x) x.begin(),x.end() 
#define ms(a,b) memset(a,b,sizeof(a)) 
#define INF 0x3f3f3f3f
#define INFF 0x3f3f3f3f3f3f3f3f
#define multi int T;scanf("%d",&T);while(T--) 
using namespace std;
typedef long long ll;
typedef double db;
const int N=2e5+5;
const int mod=10007;
const db eps=1e-6;                                                                            
const db pi=acos(-1.0);
int n,cnt=0,a[N],dfn[N],sz[N],num[N];
map<int,int>mp1,mp2;
vector<int>tr[N];
void upd(int st,int ed,int val){
    num[st]+=val;
    num[ed+1]-=val;
}
void dfs(int u,int fa=-1){
    dfn[u]=++cnt;
    sz[u]=1;
    int tmp=mp2[a[u]];
    ++mp2[a[u]];
    for(auto v:tr[u]){
        if(v==fa) continue;
        int cur=mp2[a[u]];
        dfs(v,u);
        if(mp2[a[u]]!=cur){
            upd(dfn[v],dfn[v]+sz[v]-1,-1);
            upd(1,n,1);
        }
        sz[u]+=sz[v];
    }
    tmp=mp2[a[u]]-tmp;
    if(tmp<mp1[a[u]]) upd(dfn[u],dfn[u]+sz[u]-1,1);
}
int main(){
    #ifndef ONLINE_JUDGE
    freopen("D:\\work\\data.in","r",stdin);
    #endif
    cin>>n;
    rep(i,1,n){
        cin>>a[i];
        ++mp1[a[i]];
    }
    rep(i,1,n-1){
        int x,y;
        cin>>x>>y;
        tr[x].push_back(y);
        tr[y].push_back(x);
    }
    dfs(1);
    int ans=0;
    rep(i,1,n){
        num[i]+=num[i-1];
        if(!num[i]) ++ans;
    }
    cout<<ans<<endl;
}

猜你喜欢

转载自blog.csdn.net/Luowaterbi/article/details/112666065