codeforces1467 E. Distinctive Roots in a Tree(树上差分)

E. Distinctive Roots in a Tree

树上差分

  • 如果当前节点u的某一棵子树中的某个节点的值和当前节点相同,那么除了当前节点这一棵子树节点,其他节点(其他子树以及u上面的节点)一定不满足要求。

  • 如果当前节点子树之外的节点(u上面的节点)与当前节点值相同,那么当前子树节点不满足要求。

如何知道当前子树中的节点是否与当前节点相同?
dfs过程中记录进入该子树之前某值的个数与出子树后该值的个数进行比较,如果比之前多,说明子树中存在该值。

如何知道当前节点所有子树之外的节点是否存在与当前值相同的节点?
如果不存在,说明当前节点所有子树出现该值个数和应该与总个数相同

对于不能作为答案的节点标记一下即可,由于只有子树操作考虑dfs序,区间修改单点查询差分即可。

#define IO ios::sync_with_stdio(false);cin.tie();cout.tie(0)
#pragma GCC optimize(2)
#include<set>
#include<map>
#include<cmath>
#include<stack>
#include<queue>
#include<random>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<unordered_map>
#include<unordered_set>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int N=200010,mod=1e9+7;
int h[N],e[2*N],ne[2*N],idx;
int a[N],cnt[N],num[N],s[N],n;
map<int,int> mp;
int find(int x)
{
    
    
    if(!mp.count(x)) mp[x]=++idx;
    return mp[x];
}
void add(int a,int b)
{
    
    
    e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void update(int l,int r,int x)
{
    
    
    s[l]+=x,s[r+1]-=x;
}
int dfn[N],timestamp,sz[N];
void dfs(int u,int fa)
{
    
    
    dfn[u]=++timestamp;sz[u]=1;
    int now=cnt[a[u]];// 差分统计u子树出现a[u]的次数
    cnt[a[u]]++;
    for(int i=h[u];i!=-1;i=ne[i])
    {
    
    
        int j=e[i];
        if(j==fa) continue;
        int pre=cnt[a[u]];//进入子树前
        dfs(j,u);
        sz[u]+=sz[j];
        if(cnt[a[u]]>pre) //说明j子树出现了a[u]
            update(1,n,1),update(dfn[j],dfn[j]+sz[j]-1,-1);
    }
    if(cnt[a[u]]-now!=num[a[u]])//差分统计u子树出现a[u]的次数 不等于总个数
        update(dfn[u],dfn[u]+sz[u]-1,1);//1表示不能作为答案
    
}
int main()
{
    
    
    IO;
    int T=1;
    //cin>>T;
    while(T--)
    {
    
    
        memset(h,-1,sizeof h);
        cin>>n;
        for(int i=1;i<=n;i++) 
        {
    
    
            cin>>a[i];
            a[i]=find(a[i]);
            num[a[i]]++;//总个数
        }
        idx=0;
        for(int i=1;i<n;i++)
        {
    
    
            int a,b;
            cin>>a>>b;
            add(a,b),add(b,a);
        }
        dfs(1,-1);
        int res=0;
        for(int i=1;i<=n;i++)
        {
    
    
            s[i]+=s[i-1];
            if(!s[i]) res++;
        }
        cout<<res<<'\n';
    }
    return 0;
}

要加油哦~

猜你喜欢

转载自blog.csdn.net/Fighting_Peter/article/details/112619393