版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/liufengwei1/article/details/86595767
因为要求得是非空子树的异或和方案数,那么f[u][0...m-1]表示以u为根节点必取u这个节点的所有值的方案数,这里的子树表示的不是满子树,而是选择他的根节点和部分子树中的节点来异或。那么计算出u的子节点的f[v][0...m-1]后,我们对f[u][]和f[v][]跑一次fwt,就知道假如选择u这个连通块中的方案和v这个连通块中的方案异或可以得到多少新的方案。加到f[u][]中去,因为原来的就是不选择v这个点及其子树的方案。
由于要求所有非空子树的异或和的方案数,那么每个点都可以选或者不选。及每个f[u][0...m-1]都要加到ans[0...m-1]中去
#include<bits/stdc++.h>
#define maxl 1030
const int mod=1e9+7;
int n,m,cnt;
long long rev;
int a[maxl<<1],b[maxl<<1],c[maxl<<1];
int val[maxl],ehead[maxl],ans[maxl];
int f[maxl][maxl];
struct ed
{
int to,nxt;
}e[maxl<<1];
long long qp(long long a,long long b)
{
long long ans=1,cnt=a;
while(b)
{
if(b&1)
ans=(ans*cnt)%mod;
cnt=(cnt*cnt)%mod;
b>>=1;
}
return ans;
}
inline void add(int u,int v)
{
e[++cnt].to=v;e[cnt].nxt=ehead[u];ehead[u]=cnt;
}
inline void prework()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
for(int j=0;j<m;j++)
f[i][j]=0;
ehead[i]=0;
}
for(int i=0;i<m;i++)
ans[i]=0;
for(int i=1;i<=n;i++)
scanf("%d",&val[i]),f[i][val[i]]=1;
int u,v;
cnt=0;
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
}
void fwt(int a[],int n)
{
for(int d=1;d<n;d<<=1)
for(int m=d<<1,i=0;i<n;i+=m)
for(int j=0;j<d;j++)
{
int x=a[i+j],y=a[i+j+d];
//a[i+j]=(x+y)%mod,a[i+j+d]=(x-y+mod)%mod;
a[i+j]=(x+y)%mod,a[i+j+d]=(x-y+mod)%mod;
//and:a[i+j]=x+y;
//or:a[i+j+d]=x+y;
}
}
void ufwt(int a[],int n)
{
for(int d=1;d<n;d<<=1)
for(int m=d<<1,i=0;i<n;i+=m)
for(int j=0;j<d;j++)
{
int x=a[i+j],y=a[i+j+d];
//a[i+j]=1LL*(x+y)*rev%mod,a[i+j+d]=(1LL*(x-y)*rev%mod+mod)%mod;
a[i+j]=(1LL*(x+y)*rev)%mod,a[i+j+d]=(1LL*((x-y+mod)%mod)*rev)%mod;
//and:a[i+j]=x-y;
//or:a[i+j+d]=y-x;
//rev是2在mod下的逆元
}
}
void solve(int a[],int b[],int len)
{
fwt(a,len);fwt(b,len);
for(int i=0;i<len;i++)
a[i]=1LL*a[i]*b[i]%mod;
ufwt(a,len);
}
inline void dfs(int u,int fa)
{
int v;
for(int i=ehead[u];i;i=e[i].nxt)
{
v=e[i].to;
if(v==fa)
continue;
dfs(v,u);
for(int j=0;j<m;j++)
a[j]=f[u][j],b[j]=f[v][j];
solve(a,b,m);
for(int j=0;j<m;j++)
f[u][j]=(f[u][j]+a[j])%mod;
}
for(int i=0;i<m;i++)
ans[i]=(ans[i]+f[u][i])%mod;
}
inline void mainwork()
{
dfs(1,0);
}
inline void print()
{
for(int i=0;i<m;i++)
printf("%d%c",ans[i],(i==m-1)?'\n':' ');
}
int main()
{
int t;
scanf("%d",&t);rev=qp(2,mod-2);
for(int i=1;i<=t;i++)
{
prework();
mainwork();
print();
}
return 0;
}