HDU - 5909 Tree Cutting(树形dp + FWT)

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5909

题目大意:题目定义一棵树的价值为树上所有结点权值的异或和。现在给你一棵带权树,树上点的权值都在范围[0,m-1]内,问你这个树有多少子树的价值为k,k=[0,1,2,3,...,m-1]。

题目思路:考虑做树形dp,dp[u][j]表示以 u 为根节点的树中异或和为 j 的子树有多少个,那么就可以得出如下的状态转移方程

dp[u][j]=dp[u][j]+dp[u][j\oplus k]*dp[son[u]][k]

但如果暴力转移这个方程的话,复杂度是O(m^2)的。

现在我们令f[j]=dp[u][j\oplus k]*dp[son[u]][k],这个式子满足FWT,我们就可以用FWT来加速这个计算方程,使得状态转移方程的复杂度降到O(m*logm)。做完dp最后再统计下答案即可,整体复杂度为O(n*m*logm)。

具体实现看代码:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define debug(x) cout<<"["<<x<<"]"<<endl
using namespace std;
typedef unsigned long long ull;
typedef unsigned int UI;
typedef long long ll;
typedef pair<int,int>pii;
const int MX = 2000 + 7;
const int MOD = 1e9 + 7;
const int inv2 = 500000004;

void FWT_xor(ll *a,int N,int opt){
    for(int i=1;i<N;i<<=1){
        for(int p=i<<1,j=0;j<N;j+=p){
            for(int k=0;k<i;++k){
                int X=a[j+k],Y=a[i+j+k];
                a[j+k]=(X+Y)%MOD;a[i+j+k]=(X+MOD-Y)%MOD;
                if(opt==-1){
                    a[j+k]=1ll*a[j+k]*inv2%MOD;
                    a[i+j+k]=1ll*a[i+j+k]*inv2%MOD;
                }
            }
        }
    }
}

int n,m,_;
int val[MX];
ll dp[MX][MX],va[MX],vb[MX],res[MX];
vector<int>E[MX];
void dfs(int u,int fa){
    dp[u][val[u]] = 1;
    for(auto v:E[u]){
        if(v == fa) continue;
        dfs(v,u);
        for(int i = 0;i < m;i++){
            va[i] = dp[u][i];
            vb[i] = dp[v][i];
        }
        FWT_xor(va,m,1);FWT_xor(vb,m,1);
        for(int i = 0;i < m;i++) va[i] = (va[i] * vb[i]) % MOD;
        FWT_xor(va,m,-1);
        for(int i = 0;i < m;i++) dp[u][i] = (dp[u][i] + va[i]) % MOD;
    }
    for(int i = 0;i < m;i++) res[i] = (res[i] + dp[u][i]) % MOD;
}

int main(){
    //FIN;
    for(scanf("%d",&_);_;_--){
        clr(dp);clr(res);
        scanf("%d%d",&n,&m);
        for(int i = 1;i <= n;i++){
            scanf("%d",&val[i]);
            E[i].clear();
        }
        for(int i = 1;i < n;i++){
            int u, v;
            scanf("%d%d",&u,&v);
            E[u].pb(v);E[v].pb(u);
        }
        dfs(1,0);
        for(int i = 0;i < m;i++)
            printf("%lld%c",res[i],i==m-1?'\n':' ');
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/Lee_w_j__/article/details/82055927