ICPC2019银川区域赛 E.XOR Tree

链接

点击跳转

题解

先考虑简化版本:假设给定一个序列 $a_1,\dots,a_n$,该如何计算题目中的式子 $\sum_{i<j}(a_i\oplus a_j)^2$ 呢?

$$
\begin{aligned}
&\sum_{i=1}^n\sum_{j=i+1}^n (a_i \oplus a_j)^2 \\
={}&\sum_{i=1}^n \sum_{j=i+1}^n \left((a_{i,0}\oplus a_{j,0})2^0 + (a_{i,1}\oplus a_{j,1})2^1 + \dots + (a_{i,29}\oplus a_{j,29})2^{29}\right)^2 \\
={}&\sum_{i=1}^n \sum_{j=i+1}^n \sum_{k_1=0}^{29} \sum_{k_2=0}^{29} (a_{i,k_1} \oplus a_{j,k_1})(a_{i,k_2} \oplus a_{j,k_2})\,2^{k_1+k_2}
\end{aligned}
$$

上式中 $a_{i,k}$ 表示 $a_i$ 的第 $k$ 个二进制位上的数($0$ 或 $1$)。

交换一下求和顺序,并注意到对单个二进制位有 $x\oplus y=[x\neq y]$,可以得到

$$
\sum_{k_1=0}^{29} \sum_{k_2=0}^{29} 2^{k_1+k_2} \sum_{i=1}^n\sum_{j=i+1}^n [a_{i,k_1}\neq a_{j,k_1}]\,[a_{i,k_2}\neq a_{j,k_2}]
$$

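枚举 $k_1,k_2$ 后,问题就变成了:对树上每个点 $u$,统计其子树中与 $u$ 的深度差不超过 $K$ 的点里,这两位都不相同的点对个数。这用“长链剖分+后缀和”的老套路即可解决。另外,$(k_1,k_2)$ 与 $(k_2,k_1)$ 的贡献相同,所以只需枚举 $k_1<k_2$ 并将结果乘 $2$,再单独加上 $k_1=k_2$ 的部分。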

代码

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
// Competitive-programming shorthands. Of these, the code below relies on
// maxn, maxe, rep, ll and ull; the rest are unused in the code shown here.
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
// Array capacities: maxn sizes node-indexed arrays, maxe sizes edge arrays.
#define maxn 100010
#define maxe 200010
#define cl(x) memset(x,0,sizeof(x))
// Inclusive loops: rep runs i = a..b upward, drep runs i = a..b downward.
// NOTE: the loop variable is not declared by the macro; the caller must declare it.
#define rep(i,a,b) for(i=a;i<=b;i++)
#define drep(i,a,b) for(i=a;i>=b;i--)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
// Debug helper: prints "name = value" to stderr.
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
// Unsigned 64-bit: arithmetic on it wraps modulo 2^64 instead of overflowing.
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
// Reads the next integer from stdin.
// Characters before the first digit are skipped; every '-' seen while
// skipping flips the sign. `x` seeds the accumulator (0 by default).
// If EOF is reached before any digit, returns the (signed) seed instead of
// looping forever as the previous version did.
long long read(long long x=0)
{
    long long f = 1;
    int c = getchar();
    // Skip to the first digit; the EOF check keeps truncated input from hanging.
    while(c != EOF && !isdigit(c))
    {
        if(c == '-') f = -f;
        c = getchar();
    }
    // isdigit(EOF) is false, so this loop also stops at end of input.
    while(isdigit(c))
    {
        x = x*10 + (c - '0');
        c = getchar();
    }
    return f*x;
}
// Static adjacency list stored as singly linked edge lists ("forward star").
// Edge ids start at 1; id 0 terminates a list, so head[v]==0 means v has no
// outgoing edges.
struct Graph
{
    int etot;          // number of edges added so far (== id of the newest edge)
    int head[maxn];    // head[v]: id of the most recently added edge leaving v
    int to[maxe];      // to[e]: target vertex of edge e
    int next[maxe];    // next[e]: id of the previously added edge with the same source
    int w[maxe];       // w[e]: the optional `c` argument given to adde (not read in this file)

    // Empties the edge lists of vertices 1..N and resets the edge counter.
    void clear(int N)
    {
        for(int v=N;v>=1;--v)
            head[v]=0;
        etot=0;
    }

    // Adds the directed edge a -> b (with value c), prepended to a's list.
    void adde(int a, int b, int c=0)
    {
        const int e=++etot;
        to[e]=b;
        w[e]=c;
        next[e]=head[a];
        head[a]=e;
    }

    // Iterates the edge ids `p` leaving vertex `vtx` of graph `g`, newest first.
    #define forp(vtx,g) for(auto p=(g).head[vtx];p;p=(g).next[p])
}G;
// Shared state for the per-bit-pair passes. These are globals so that the
// recursive dfs can use them without passing them as parameters.
//   n, K      - vertex count; K is the deepest offset below u counted by getpre(u,K,...)
//   a[i]      - value of vertex i
//   b[i]      - 2-bit pattern: bit 0 = bit k1 of a[i], bit 1 = bit k2 of a[i]
//   k1, k2    - the bit pair processed by the current pass
//   pos[u]    - offset of u's first slot in c; tot - next free slot offset
//   c[pos[u]+d][t] - once dfs(u) returns: number of vertices in u's subtree at
//                    depth >= d below u whose pattern is t (suffix sums over depth)
int n, pos[maxn], tot, a[maxn], K, b[maxn], k1, k2, c[maxn][4];
// ans[u]: accumulated answer for vertex u (unsigned, so any overflow wraps mod 2^64).
ull ans[maxn];
struct Longest_Chain_Decomposition
{
    int len[maxn], son[maxn], depth[maxn], istop[maxn];
    void dfs(Graph& G, int u, int fa)
    {
        son[u]=0;
        len[u]=1;
        depth[u]=depth[fa]+1;
        istop[u]=false;
        forp(u,G)
        {
            int v(G.to[p]); if(v==fa)continue;
            dfs(G,v,u);
            if(len[v]+1>len[u])len[u]=len[v]+1, son[u]=v;
        }
        forp(u,G)
        {
            int v(G.to[p]); if(v==fa)continue;
            if(v!=son[u])istop[v]=true;
        }
    }
    void run(Graph& G, int root)
    {
        tot=0;
        depth[0]=0, dfs(G,root,0);
        istop[root]=true;
    }
}lcd;
// Number of vertices in u's subtree whose depth below u lies in [0, L] and
// whose 2-bit pattern is k, read from u's suffix-sum slots in c.
int getpre(int u, int L, int k)
{
    const int base=pos[u];
    const int total=c[base][k];
    // Slot base+L+1 exists only if u's chain reaches that deep; otherwise the
    // whole subtree is within the limit.
    return (L+1<lcd.len[u]) ? total-c[base+(L+1)][k] : total;
}
// Builds u's suffix-sum slots for the current bit pair (k1, k2) and adds u's
// contribution to ans[u].
// After the call, c[pos[u]+d][t] is the number of vertices in u's subtree at
// depth >= d below u whose pattern b[] equals t.
void dfs(int u, int fa)
{
    int i, k;
    // A chain head reserves a fresh run of len[u] consecutive slots; any other
    // vertex reuses its parent's run (its pos was set by the parent below).
    if(lcd.istop[u])
    {
        pos[u] = tot;
        tot += lcd.len[u];
    }
    // u itself sits at depth 0.
    c[pos[u]][b[u]]++;
    if(lcd.son[u])
    {
        // The heavy son shares u's run, shifted by one slot, so its suffix sums
        // are already u's slots 1..; only slot 0 (u's total) must be updated.
        auto v(lcd.son[u]);
        pos[v] = pos[u]+1;
        dfs(v,u);
        rep(k,0,3)c[pos[u]][k]+=c[pos[v]][k];
    }
    forp(u,G)
    {
        int v=G.to[p];
        if(v==fa or v==lcd.son[u])continue;
        // A light son owns a separate run; merge it into u's slots one level deeper.
        dfs(v,u);
        rep(k,0,3)c[pos[u]][k]+=c[pos[v]][k];
        rep(i,0,lcd.len[v]-1)
            rep(k,0,3)c[pos[u]+(i+1)][k] += c[pos[v]+i][k];
    }
    // Count unordered pairs within depth K whose bits k1 and k2 both differ:
    // patterns {0,3} (k=0) and {1,2} (k=1); weight each by 2^(k1+k2).
    rep(k,0,1)ans[u]+=(ull)getpre(u,K,k)*getpre(u,K,3^k)<<(k1+k2);
}
int main()
{
    int i, fa, j;
    n = read(), K=read();
    rep(i,1,n)a[i]=read();
    rep(i,2,n)
    {
        fa=read();
        G.adde(fa,i);
    }
    lcd.run(G,1);
    rep(k1,0,29)rep(k2,k1+1,29)
    {
        rep(i,1,n)
        {
            b[i]=0;
            if(a[i]&(1<<k1))b[i]|=1;
            if(a[i]&(1<<k2))b[i]|=2;
        }
        rep(i,0,n)rep(j,0,3)c[i][j]=0;
        tot=0;
        dfs(1,0);
    }
    rep(i,1,n)ans[i]<<=1;
    rep(k1,0,29)
    {
        k2=k1;
        rep(i,1,n)
        {
            b[i]=0;
            if(a[i]&(1<<k1))b[i]|=1;
            if(a[i]&(1<<k2))b[i]|=2;
        }
        rep(i,0,n)rep(j,0,3)c[i][j]=0;
        tot=0;
        dfs(1,0);
    }
    rep(i,1,n)printf("%llu\n",ans[i]);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/FSAHFGSADHSAKNDAS/article/details/106184680