Codecraft-18 and Codeforces Round #458 (Div. 1 + Div. 2, combined) G. Sum the Fibonacci

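题意:给一个长度为 n 的数组 s,对所有满足 \(s_a \& s_b = 0\) 且 \((s_a | s_b) \& s_c \& (s_d \oplus s_e) = 2^{i}\)(i 为某个非负整数)的下标五元组 \((a,b,c,d,e)\)(\(1 \le a,b,c,d,e \le n\)),求 \(\sum f(s_a | s_b) \cdot f(s_c) \cdot f(s_d \oplus s_e)\) 对 \(10^9+7\) 取模的结果,其中 f 是斐波那契数列(\(f(0)=0, f(1)=1\)).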
题解:记 \(cnt_x\) 为数值 x 在 s 中出现的次数.先求 \(A_k = f(k) \cdot \sum_{i | j = k,\ i \& j = 0} cnt_i \cdot cnt_j\),这显然是子集卷积;再求出 \(B_k = f(k) \cdot cnt_k\) 和 \(C_k = f(k) \cdot \sum_{i \oplus j = k} cnt_i \cdot cnt_j\),C 显然是异或(xor)卷积,用 FWT 即可.
最后求 \(D_l = \sum_{i \& j \& k = l} A_i \cdot B_j \cdot C_k\),D 显然是与(and)卷积,同样用 FWT 计算.答案就是 \(\sum_{i=0}^{16} D_{2^{i}}\)(所有数值都在 \([0, 2^{17})\) 内).
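子集卷积可以直接枚举子集,在 \(O(3^{17})\) 时间内算出;也可以用 FMT(快速莫比乌斯变换)求出:按集合大小(popcount)分层,dp[i][j] 只保留 popcount(j) = i 的项,对每一层 dp[i] 单独做 FMT,层与层之间按 \(i = j + (i - j)\) 逐点相乘累加后再逆变换,最后对每个 j 取第 popcount(j) 层的值.时间复杂度 \(O(17^{2} \cdot 2^{17})\).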

//#pragma GCC optimize(2)
//#pragma GCC optimize(3)
//#pragma GCC optimize(4)
//#pragma GCC optimize("unroll-loops")
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast,no-stack-protector")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include<bits/stdc++.h>
#define fi first
#define se second
#define db double
#define mp make_pair
#define pb push_back
#define pi acos(-1.0)
#define ll long long
#define vi vector<int>
#define mod 1000000007
#define ld long double
//#define C 0.5772156649
//#define ls l,m,rt<<1
//#define rs m+1,r,rt<<1|1
#define pll pair<ll,ll>
#define pil pair<int,ll>
#define pli pair<ll,int>
#define pii pair<int,int>
#define ull unsigned long long
//#define base 1000000000000000000
#define fin freopen("a.txt","r",stdin)
#define fout freopen("a.txt","w",stdout)
#define fio ios::sync_with_stdio(false);cin.tie(0)
// Greatest common divisor by iterative Euclid; gcd(x, 0) == x.
// Signs follow C++ '%' semantics, so negative inputs may give a negative result.
inline long long gcd(long long a, long long b)
{
    while (b != 0)
    {
        const long long rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
inline void sub(ll &a,ll b){a-=b;if(a<0)a+=mod;}
inline void add(ll &a,ll b){a+=b;if(a>=mod)a-=mod;}
// Returns a reference to the larger argument (compared with operator>);
// when neither is greater, the second argument is returned.
template <typename T>
inline const T &MAX(const T &lhs, const T &rhs)
{
    if (!(lhs > rhs))
        return rhs;
    return lhs;
}
// Returns a reference to the smaller argument (compared with operator<);
// when neither is smaller, the second argument is returned.
template <typename T>
inline const T &MIN(const T &lhs, const T &rhs)
{
    if (!(lhs < rhs))
        return rhs;
    return lhs;
}
inline ll qp(ll a,ll b){ll ans=1;while(b){if(b&1)ans=ans*a%mod;a=a*a%mod,b>>=1;}return ans;}
// Modular exponentiation by squaring: returns a^b mod c, in [0, c).
// The base is first reduced into [0, c), so negative or huge bases are handled.
// The accumulator starts at 1 % c, so x^0 mod 1 correctly yields 0.
// Preconditions: b >= 0 (a negative b returns 1 % c instead of hanging), and
// 1 <= c <= 3037000499, so that (c-1)^2 still fits in a long long.
inline long long qp(long long a, long long b, long long c)
{
    a %= c;
    if (a < 0) a += c;
    long long ans = 1 % c;
    while (b > 0)
    {
        if (b & 1) ans = ans * a % c;
        a = a * a % c;
        b >>= 1;
    }
    return ans;
}

using namespace std;

// Template constants; ba, eps, INF, maxn and inf are not referenced by this solution.
constexpr ull ba = 233;
constexpr db eps = 1e-5;
constexpr ll INF = 0x3f3f3f3f3f3f3f3f;
// N = 2^17 plus a little slack: every array below is indexed by a 17-bit value.
constexpr int N = (1 << 17) + 10;
constexpr int maxn = 1000000 + 10;
constexpr int inf = 0x3f3f3f3f;

// Zero-initialised scratch arrays used by main():
//   a, b, c : per-value tables that are transformed and multiplied pointwise
//   d, dp   : popcount-layered tables for the subset convolution (layers 0..19)
//   f       : Fibonacci numbers modulo `mod`
int a[N], b[N], c[N];
int d[20][N], dp[20][N];
int f[N];
// Modular inverse of 2 via Fermat's little theorem (2^(mod-2) mod `mod`).
int inv2 = qp(2, mod - 2);
// In-place OR transform of a[0..n) modulo `mod`; n is expected to be a power of two.
// dft == 1 : zeta transform, a[m] becomes the sum of a[s] over all subsets s of m.
// otherwise: the inverse (Moebius) transform.
// Entries are expected to lie in [0, mod).
void fwt_or(int *a,int n,int dft)
{
    const bool forward = (dft == 1);
    for (int half = 1; half < n; half <<= 1)
    {
        const int step = half << 1;
        for (int base = 0; base < n; base += step)
        {
            // lo[t] has the current bit clear, hi[t] is the same index with it set.
            int *lo = a + base;
            int *hi = lo + half;
            for (int t = 0; t < half; ++t)
                hi[t] = forward ? (hi[t] + lo[t]) % mod
                                : (hi[t] - lo[t] + mod) % mod;
        }
    }
}

// In-place AND transform of a[0..n) modulo `mod`; n is expected to be a power of two.
// dft == 1 : a[m] becomes the sum of a[s] over all supersets s of m.
// otherwise: the inverse transform.
// Entries are expected to lie in [0, mod).
void fwt_and(int *a,int n,int dft)
{
    for (int len = 1; len < n; len <<= 1)
        for (int start = 0; start < n; start += 2 * len)
            for (int pos = start; pos < start + len; ++pos)
            {
                // a[pos] has the current bit clear; a[pos + len] is its superset with the bit set.
                int &lowSlot = a[pos];
                const int highVal = a[pos + len];
                if (dft == 1)
                    lowSlot = (lowSlot + highVal) % mod;
                else
                    lowSlot = (lowSlot - highVal + mod) % mod;
            }
}
// In-place Walsh-Hadamard (XOR) transform of a[0..n) modulo `mod`; n is expected to be a power of two.
// dft == -1: inverse transform. Each butterfly output is halved (times inv2), so the
//            overall 1/n scaling is applied level by level.
// any other value of dft: forward transform.
// Entries are expected to lie in [0, mod).
void fwt_xor(int *a,int n,int dft)
{
    const bool inverse = (dft == -1);
    for (int span = 1; span < n; span *= 2)
        for (int blk = 0; blk < n; blk += span * 2)
            for (int p = blk; p < blk + span; ++p)
            {
                const int u = a[p];
                const int v = a[p + span];
                int sum = (u + v) % mod;
                int diff = (u - v + mod) % mod;
                if (inverse)
                {
                    sum = 1ll * sum * inv2 % mod;
                    diff = 1ll * diff * inv2 % mod;
                }
                a[p] = sum;
                a[p + span] = diff;
            }
}
// Codeforces 914G "Sum the Fibonacci".
// Reads n, then n integers s_1..s_n, each required to lie in [0, 2^LOG).
// Prints the sum, modulo `mod`, of f(s_a|s_b) * f(s_c) * f(s_d^s_e).
// The sum runs over all index 5-tuples (a,b,c,d,e) with s_a & s_b == 0 and
// (s_a|s_b) & s_c & (s_d^s_e) a power of two. f is the Fibonacci sequence.
// Uses the zero-initialised global scratch arrays a, b, c, d, dp, f.
// On malformed input, or an out-of-range value, it prints nothing and returns 1.
int main()
{
    const int LOG=17;          // bit width of every input value
    const int FULL=1<<LOG;     // size of every transform domain
    static_assert(FULL<=N,"work arrays are too small for 2^LOG entries");
    static_assert(LOG<static_cast<int>(sizeof(d)/sizeof(d[0])),"d/dp need LOG+1 popcount layers");

    // f[i] = i-th Fibonacci number mod `mod`, with f[0]=0 and f[1]=1.
    f[0]=0,f[1]=1;
    for(int i=2;i<N;i++)
    {
        f[i]=f[i-1]+f[i-2];
        if(f[i]>=mod)f[i]-=mod;
    }

    // a[x] and c[x] both start as cnt_x, the number of occurrences of x.
    int n=0;
    if(scanf("%d",&n)!=1)return 1;
    for(int i=1;i<=n;i++)
    {
        int x=0;
        // Reject malformed or out-of-range values instead of writing past the arrays.
        if(scanf("%d",&x)!=1||x<0||x>=FULL)return 1;
        a[x]++,c[x]++;
    }

    // Subset convolution of cnt with itself, computed by ranked zeta transforms:
    // b[k] = sum of cnt_i*cnt_j over i|j==k with i&j==0.
    // Layer r of d only holds the values whose popcount is exactly r.
    for(int k=0;k<FULL;k++)d[__builtin_popcount(k)][k]=a[k];
    for(int r=0;r<=LOG;r++)fwt_or(d[r],FULL,1);
    for(int r=0;r<=LOG;r++)
        for(int j=0;j<=r;j++)
            for(int k=0;k<FULL;k++)
            {
                dp[r][k]+=1ll*d[j][k]*d[r-j][k]%mod;
                if(dp[r][k]>=mod)dp[r][k]-=mod;
            }
    for(int r=0;r<=LOG;r++)fwt_or(dp[r],FULL,-1);
    // Disjoint pairs are exactly those whose popcounts add up to popcount(k).
    for(int k=0;k<FULL;k++)b[k]=dp[__builtin_popcount(k)][k];

    // XOR self-convolution: c[k] = sum of cnt_i*cnt_j over i^j==k.
    fwt_xor(c,FULL,1);
    for(int k=0;k<FULL;k++)c[k]=1ll*c[k]*c[k]%mod;
    fwt_xor(c,FULL,-1);

    // Weight every table by f of its index:
    // a -> f(s_c) term, b -> f(s_a|s_b) term, c -> f(s_d^s_e) term.
    for(int k=0;k<FULL;k++)
    {
        a[k]=1ll*a[k]*f[k]%mod;
        b[k]=1ll*b[k]*f[k]%mod;
        c[k]=1ll*c[k]*f[k]%mod;
    }

    // AND convolution of the three tables: a[l] = sum of a_i*b_j*c_k over i&j&k==l.
    fwt_and(a,FULL,1);fwt_and(b,FULL,1);fwt_and(c,FULL,1);
    for(int k=0;k<FULL;k++)a[k]=1ll*a[k]*b[k]%mod*c[k]%mod;
    fwt_and(a,FULL,-1);

    // The answer collects every single-bit result 2^0 .. 2^(LOG-1).
    int ans=0;
    for(int i=0;i<LOG;i++)
    {
        ans+=a[1<<i];
        if(ans>=mod)ans-=mod;
    }
    printf("%d\n",ans);
    return 0;
}
/********************

********************/

猜你喜欢

转载自www.cnblogs.com/acjiumeng/p/10649424.html