链接
题解
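假设给定一个序列 $a_1,a_2,\dots,a_m$,怎么计算题目说的那个式子 $\sum_{x<y}(a_x\oplus a_y)^2$ 呢?把异或值按二进制位展开:
$$\sum_{x<y}(a_x\oplus a_y)^2=\sum_{x<y}\Big(\sum_{i}2^i\,[a_{x,i}\neq a_{y,i}]\Big)^2$$
上式中 $a_{x,i}$ 表示 $a_x$ 的第 $i$ 个二进制位上的数($0$ 或 $1$)
交换一下求和顺序,可以得到
$$\sum_{i}\sum_{j}2^{i+j}\sum_{x<y}[a_{x,i}\neq a_{y,i}]\,[a_{x,j}\neq a_{y,j}]$$
枚举 $i,j$,就成了在树上统计这两位都不相同的点对个数,“长链剖分+后缀和”老套路即可。由于 $(i,j)$ 与 $(j,i)$ 的贡献相同,代码中只枚举 $i<j$ 并乘以 $2$,最后再单独加上 $i=j$ 的项。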
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
// ---- Contest template: constants, shorthand macros and type aliases ----
// Conventional "infinity" values for int / long long and a float tolerance
// (not referenced by the code visible here).
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
// Array capacities: maxn sizes vertex-indexed arrays, maxe sizes edge-indexed arrays.
#define maxn 100010
#define maxe 200010
// Zero-fills x. x must be an array object (not a pointer): the byte count comes from sizeof(x).
#define cl(x) memset(x,0,sizeof(x))
// Inclusive loops over an already-declared variable i: rep counts a..b upward,
// drep counts a..b downward. The bound b is re-evaluated on every iteration.
#define rep(i,a,b) for(i=a;i<=b;i++)
#define drep(i,a,b) for(i=a;i>=b;i--)
// Container emplace shorthands.
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
// std::pair member shorthands.
#define fi first
#define se second
// Debug helper: prints "expr = value" to stderr.
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
// Short integer / pair type aliases used throughout the file.
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
// Reads the next integer from stdin.
// Every non-digit character before the number is skipped; each '-' among the
// skipped characters flips the sign. The digits are appended to x, which
// defaults to 0. If input ends before any digit is seen, returns (+/-)x
// (0 by default) instead of waiting forever for a digit that never comes.
long long read(long long x=0)
{
    int c;
    long long f = 1;
    // Stop at EOF: isdigit(EOF) is false, so without this check the loop never ends.
    for (c = getchar(); c != EOF && !isdigit(c); c = getchar())
        if (c == '-') f = -f;
    for (; c != EOF && isdigit(c); c = getchar())
        x = x * 10 + (c - '0');
    return f * x;
}
// Forward-star adjacency list with 1-based edge ids; edge id 0 ends a list.
struct Graph
{
    // head[u]: most recently added edge leaving u; next[e]: following edge in
    // the same list; to[e] / w[e]: target vertex and weight of edge e.
    int etot, head[maxn], to[maxe], next[maxe], w[maxe];

    // Resets the edge counter and empties the lists of vertices 1..N.
    void clear(int N)
    {
        etot = 0;
        for (int u = N; u >= 1; --u) head[u] = 0;
    }

    // Prepends the directed edge a -> b (weight c) to a's list.
    void adde(int a, int b, int c=0)
    {
        const int e = ++etot;
        to[e] = b;
        w[e] = c;
        next[e] = head[a];
        head[a] = e;
    }

    // Loops over every edge id p leaving vertex _ in graph __.
    #define forp(_,__) for(auto p=__.head[_];p;p=__.next[p])
}G;
// Global state shared by main, dfs and getpre.
int n;            // number of vertices, ids 1..n, tree rooted at 1
int pos[maxn];    // pos[u]: row of c holding depth 0 of u; row pos[u]+d is depth d below u
int tot;          // next unused row of c, consumed whenever a new chain starts
int a[maxn];      // input value of each vertex
int K;            // depth bound handed to getpre
int b[maxn];      // b[u]: bit k1 of a[u] in bit 0, bit k2 of a[u] in bit 1
int k1;           // lower bit of the pair handled by the current pass
int k2;           // higher (or equal) bit of the pair handled by the current pass
int c[maxn][4];   // c[row][t]: suffix count of vertices whose code b is t
ull ans[maxn];    // ans[u]: accumulated answer for vertex u
struct Longest_Chain_Decomposition
{
int len[maxn], son[maxn], depth[maxn], istop[maxn];
void dfs(Graph& G, int u, int fa)
{
son[u]=0;
len[u]=1;
depth[u]=depth[fa]+1;
istop[u]=false;
forp(u,G)
{
int v(G.to[p]); if(v==fa)continue;
dfs(G,v,u);
if(len[v]+1>len[u])len[u]=len[v]+1, son[u]=v;
}
forp(u,G)
{
int v(G.to[p]); if(v==fa)continue;
if(v!=son[u])istop[v]=true;
}
}
void run(Graph& G, int root)
{
tot=0;
depth[0]=0, dfs(G,root,0);
istop[root]=true;
}
}lcd;
// Number of vertices with code k in u's subtree at depth <= L below u
// (u itself is depth 0). c[pos[u]+d][k] holds the count at depth >= d, so the
// answer is a difference of two suffix counts. Called from dfs once u's
// subtree has been merged into its rows of c.
int getpre(int u, int L, int k)
{
    if (L < 0) return 0;                    // empty window; never index before u's rows
    const int base = pos[u];
    // Written as L < len-1 (not L+1 < len) so a huge L cannot overflow.
    if (L < lcd.len[u] - 1) return c[base][k] - c[base + L + 1][k];
    return c[base][k];                      // window reaches the whole subtree
}
// Counting pass for the current bit pair (k1, k2), run over u's subtree.
// When dfs(u) returns, c[pos[u]+d][t] is the number of vertices in u's subtree
// at depth >= d below u whose code b equals t. ans[u] has also been increased
// by 2^(k1+k2) times the number of unordered pairs, within depth K of u, whose
// codes differ in both bits.
// Memory layout: a chain top takes a fresh slice [tot, tot+len[u]) of c.
// The preferred child son[u] reuses its parent's slice shifted by one row, so
// its counts become the parent's depth>=1 counts without copying.
void dfs(int u, int fa)
{
int i, k;
if(lcd.istop[u])
{
// u starts a new chain: reserve rows for depths 0..len[u]-1
pos[u] = tot;
tot += lcd.len[u];
}
c[pos[u]][b[u]]++;  // u itself, at depth 0
if(lcd.son[u])
{
// The preferred child must be processed before any light child is merged.
// It works directly in u's slice, so earlier light-child data would leak
// into the preferred child's own answer.
auto v(lcd.son[u]);
pos[v] = pos[u]+1;
dfs(v,u);
// suffix count at depth 0 = u itself + everything at depth >= 1
rep(k,0,3)c[pos[u]][k]+=c[pos[v]][k];
}
forp(u,G)
{
int v=G.to[p];
if(v==fa or v==lcd.son[u])continue;
dfs(v,u);
// Light child v: add its whole subtree to depth 0, then add its depth-i
// suffix counts to u's depth i+1. len[v] <= len[u]-1, so this stays in u's slice.
rep(k,0,3)c[pos[u]][k]+=c[pos[v]][k];
rep(i,0,lcd.len[v]-1)
rep(k,0,3)c[pos[u]+(i+1)][k] += c[pos[v]+i][k];
}
// Codes differing in both bits pair up as {0,3} (k=0) and {1,2} (k=1).
rep(k,0,1)ans[u]+=(ull)getpre(u,K,k)*getpre(u,K,3^k)<<(k1+k2);
}
int main()
{
int i, fa, j;
n = read(), K=read();
rep(i,1,n)a[i]=read();
rep(i,2,n)
{
fa=read();
G.adde(fa,i);
}
lcd.run(G,1);
rep(k1,0,29)rep(k2,k1+1,29)
{
rep(i,1,n)
{
b[i]=0;
if(a[i]&(1<<k1))b[i]|=1;
if(a[i]&(1<<k2))b[i]|=2;
}
rep(i,0,n)rep(j,0,3)c[i][j]=0;
tot=0;
dfs(1,0);
}
rep(i,1,n)ans[i]<<=1;
rep(k1,0,29)
{
k2=k1;
rep(i,1,n)
{
b[i]=0;
if(a[i]&(1<<k1))b[i]|=1;
if(a[i]&(1<<k2))b[i]|=2;
}
rep(i,0,n)rep(j,0,3)c[i][j]=0;
tot=0;
dfs(1,0);
}
rep(i,1,n)printf("%llu\n",ans[i]);
return 0;
}