ABC163 F - path pass i(树形dp)

题意:

在这里插入图片描述

解法:

直接计算经过颜色k的路径很难算,考虑计算不经过颜色k的路径,然后用总路径去除.

如果只需要算一种颜色,那么我们可以直接以颜色k的点作为图的分割点,
然后dfs出每个连通块的大小,设某个连通块的大小为t,那么对答案的贡献为t*(t-1)/2+t=t*(t+1)/2.
复杂度是O(n).

但是这题要求我们计算出所有颜色的答案,如果对每种颜色都暴力做一遍,复杂度是O(n^2),显然不行.

以下我们假设树根为1,即令其为有根树.

考虑颜色k分割之后的形态:

在这里插入图片描述

图中蓝点是颜色为k的点,假设我们要计算颜色k的答案,那么只需要计算出每个颜色k下面的连通块大小.

在这里插入图片描述

考虑计算每个点作为连通块父节点(即图中的蓝色点)的答案.
设点x的颜色为k,那么只需要找到x下面每条链中,第一个颜色为k的点,统计他们的size和,
设x的子节点为v,x往下每条链中第一个遇到的颜色为k的点的sz和为del[k],
设t=sz[v]-del[k],那么点x对ans[k]的贡献为t*(t+1)/2.

找某个点向下第一个颜色为k的节点,可以用类似栈的玩意实现,
我的做法是定义一个cnt[],向下搜索的时cnt[a[x]]++,回溯的时候cnt[a[x]]--,
用cnt[a[x]]是否为0来判断点x是否为连通块父节点下,第一个颜色为a[x]的点.

还要注意的是点1的父节点可以作为任意颜色的连通块的父节点,最后记得也计算一边.

思路很简单,具体实现稍微麻烦一点,
因为要计算每个点作为连通块父节点的答案,因此向下递归的时候需要清空一些地方,
这些地方需要提前存下来,回溯的时候在恢复,
具体实现见代码.
x,v,del图示:

在这里插入图片描述

code:

#include <bits/stdc++.h>
#define int long long
#define PI pair<int,int>
using namespace std;
const int maxm=2e6+5;
vector<int>g[maxm];
int ans[maxm];
int cnt[maxm];
int del[maxm];
int sz[maxm];
int a[maxm];
int n;
void dfs1(int x,int fa){
    
    
    sz[x]=1;
    for(int v:g[x]){
    
    
        if(v==fa)continue;
        dfs1(v,x);
        sz[x]+=sz[v];
    }
}
void dfs(int x,int fa){
    
    
    //保存父节点状态
    cnt[a[x]]++;
    int temp=cnt[a[x]];
    int temp2=del[a[x]];
    //
    for(int v:g[x]){
    
    
        if(v==fa)continue;
        //计算以x为连通块父节点,子树v中的答案
        cnt[a[x]]=0;
        del[a[x]]=0;
        dfs(v,x);
        int t=sz[v]-del[a[x]];
        ans[a[x]]+=t*(t+1)/2;
    }
    //恢复父节点状态
    cnt[a[x]]=temp;
    del[a[x]]=temp2;
    cnt[a[x]]--;
    if(cnt[a[x]]==0)del[a[x]]+=sz[x];//x是遇到的第一个a[x]点
}
void solve(){
    
    
    cin>>n;
    for(int i=1;i<=n;i++){
    
    
        cin>>a[i];
    }
    for(int i=1;i<n;i++){
    
    
        int a,b;cin>>a>>b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    dfs1(1,0);
    dfs(1,0);
    for(int i=1;i<=n;i++){
    
    
        int t=n-del[i];
        ans[i]+=t*(t+1)/2;
    }
    int tot=n*(n+1)/2;
    for(int i=1;i<=n;i++){
    
    
        ans[i]=tot-ans[i];
    }
    for(int i=1;i<=n;i++){
    
    
        cout<<ans[i]<<endl;
    }
}
signed main(){
    
    
    ios::sync_with_stdio(0);
    solve();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_44178736/article/details/115016068
今日推荐