Little Pony and Elements of Harmony(CF 453 D)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/zhouyuheng2003/article/details/84791844

1 题目相关

1.1传送门

传送门

1.2 题目大意

给你一个 m m ,现在有一个数组 f f ,已知这 2 m 2^m 个数在 0 0 时刻的值 f [ 0 ] [ i ] ( i [ 0 , 2 m ) ) f[0][i](i\in[0,2^m)) ,已知另一个数组 b [ i ] ( i [ 0 , m ] ) b[i](i\in[0,m])
c n t ( x ) cnt(x) x x 在二进制下1的数量, \oplus 为异或运算
有如下递推式 f [ i ] [ j ] = k = 0 2 m 1 b [ c n t [ j k ] ] f [ i 1 ] [ k ] m o d    p f[i][j]=\sum_{k=0}^{2^m-1}b[cnt[j\oplus k]]*f[i-1][k] \mod p
f [ t ] [ i ] ( i [ 0 , 2 m ) ) f[t][i](i\in[0,2^m))

1.3 数据范围

原题数据范围:
1 m 20 1 t 1 0 18 2 p 1 0 9 1\le m\le20,1\le t\le 10^{18},2\le p\le10^9
1 f [ 0 ] [ i ] , b [ i ] 1 0 9 1\le f[0][i],b[i]\le 10^9
m , t , p , f , b Z m,t,p,f,b\in \mathbb{Z}
时间限制 6 s 6s
为了解释方便,定义 n = 2 m n=2^m

2 算法

2.1 暴力

拿到一道题,首先当然先考虑暴力怎么做,当然是很方便的,首先,我们可以直接考虑模拟,总复杂度 Θ ( n 2 t m ) \Theta(n^2tm)

2.2 稍有优化的暴力

容易发现,那个 c n t [ ] cnt[] 是可以预处理的,所以复杂度变为 Θ ( n 2 t ) \Theta(n^2t)

2.3 换个思路

你发现, t t 实在是太大了,实在为本题的一大瓶颈
考虑矩阵乘法,对于下一个时刻本质上是进行一次矩阵乘法,总复杂度 Θ ( n 3 l o g t ) \Theta(n^3logt)

2.4 你会FWT

前面的那么多算法的复杂度并不优越,一般的出题人对于前面的算法的给分一般都只有一档,实在不是很可观
但是,你会FWT,你看过我的博客,(不会FWT的请点击快速沃尔什变换(FWT))
请仔细观察如下式子
f [ i ] [ j ] = k = 0 2 m 1 b [ c n t [ j k ] ] f [ i 1 ] [ k ] m o d    p f[i][j]=\sum_{k=0}^{2^m-1}b[cnt[j\oplus k]]*f[i-1][k] \mod p
把它做一个转换
f [ i ] [ j ] = a b = j b [ c n t [ a ] ] f [ i 1 ] [ b ] m o d    p f[i][j]=\sum_{a\oplus b=j}b[cnt[a]]*f[i-1][b] \mod p
考虑构造一个数组 g [ i ] = b [ c n t [ i ] ] g[i]=b[cnt[i]] ,下一个时刻就相当于对 g g 做一次异或卷积,好吧,目前的复杂度是 Θ ( t n l o g n ) \Theta(tnlogn)

2.5 预处理

由于每次卷积都是卷同一个东西,所以预处理 g 1 , g 2 , g 3 , g 4 g^1,g^2,g^3,g^4···
即快速幂,复杂度 Θ ( n l o g n l o g t ) \Theta(nlognlogt)

2.6 减少FWT、IDFT次数

我们回到2.4,发现复杂度瓶颈在 F W T FWT I F W T IFWT ,然而并不需要每次 F W T ( ) FWT() 数组对应系数乘后IDFT回来,直接在里面乘就好了,复杂度 Θ ( n l o g n + n t ) \Theta(nlogn+nt)

2.7 快速幂优化

发现2.6里的乘系数每次都是乘同一个数,快速幂一下,复杂度 Θ ( n l o g n + n l o g t ) \Theta(nlogn+nlogt)

2.8 Tip

发现这题有一个很坑的地方,就是p可能是2的倍数,而FWT的时候要除以n,然后就会挂
例如: n = 2 n=2 , p = 6 p=6 ,现在有个数当前算出来值为 8 8 ,直接除以 2 2 结果是 4 4 ,先模 p p 结果为 8 8%6=2 ,再除以 2 2 为结果为 1 1 ,就挂了
然后就需要将模数扩大 n n
n p n*p 1 0 19 10^{19} 两个数直接乘会爆long long,慢速乘又会带来 Θ ( l o g n ) \Theta(logn) 的复杂度,不是很优越,所以需要一些奇技淫巧优化这个乘法

inline LL msc(LL a,LL b)
{
    LL v=(a*b-(LL)((long double)a/mod*b+1e-8)*mod);
    return v<0?v+mod:v;
}
代码

到这里这个算法已经可以通过此题了,贴出AC代码

#include<cstdio>
#include<cctype>
#include<algorithm>
#define rg register
typedef long long LL;
template <typename T> inline T max(const T a,const T b){return a>b?a:b;}
template <typename T> inline T min(const T a,const T b){return a<b?a:b;}
template <typename T> inline void mind(T&a,const T b){a=a<b?a:b;}
template <typename T> inline void maxd(T&a,const T b){a=a>b?a:b;}
template <typename T> inline T abs(const T a){return a>0?a:-a;}
template <typename T> inline void swap(T&a,T&b){T c=a;a=b;b=c;}
template <typename T> inline T gcd(const T a,const T b){if(!b)return a;return gcd(b,a%b);}
template <typename T> inline T lcm(const T a,const T b){return a/gcd(a,b)*b;}
template <typename T> inline T square(const T x){return x*x;};
template <typename T> inline void read(T&x)
{
    char cu=getchar();x=0;bool fla=0;
    while(!isdigit(cu)){if(cu=='-')fla=1;cu=getchar();}
    while(isdigit(cu))x=x*10+cu-'0',cu=getchar();
    if(fla)x=-x;
}
template <typename T> inline void printe(const T x)
{
    if(x>=10)printe(x/10);
    putchar(x%10+'0');
}
template <typename T> inline void print(const T x)
{
    if(x<0)putchar('-'),printe(-x);
    else printe(x);
}
LL mod=998244353;const int lenth=1048576;
inline LL msc(LL a,LL b)
{
    LL v=(a*b-(LL)((long double)a/mod*b+1e-8)*mod);
    return v<0?v+mod:v;
}
inline LL pow(LL a,LL b)
{
    LL res=1;
    for(;b;a=msc(a,a),b>>=1)if(b&1)res=msc(a,res);
    return res;
}
LL n,m,t,a[lenth],b[lenth],c[21];int cnt[lenth];
inline void FWT(LL*A,const int fla)
{
    for(rg int i=1;i<n;i<<=1)
        for(rg int j=0;j<n;j+=(i<<1))
            for(rg int k=0;k<i;k++)
            {
                const LL x=A[j+k],y=A[j+k+i];
                A[j+k]=(x+y)%mod;
                A[j+k+i]=(x+mod-y)%mod;
            }
    if(fla==-1)
        for(rg int i=0;i<n;i++)
            A[i]/=n;
}
int main()
{
    read(m),read(t),read(mod),n=1<<m,mod*=n;
    for(rg int i=0;i<n;i++)read(a[i]);
    for(rg int i=1;i<n;i++)cnt[i]=cnt[i^(i&-i)]+1;
    for(rg int i=0;i<=m;i++)read(c[i]);
    for(rg int i=0;i<n;i++)b[i]=c[cnt[i]];
    FWT(a,1),FWT(b,1);
    for(rg int i=0;i<n;i++)a[i]=msc(a[i],pow(b[i],t));
    FWT(a,-1);
    for(rg int i=0;i<n;i++)print(a[i]),putchar('\n');
    return 0;
}

3 优化

然而,这道题可以进行优化

3.1 加强版数据范围(毒瘤

1 t 2 10000 1\le t\le 2^{10000}

3.2 思考

对于之前的那个算法,显然是无法接受这个数据范围的,所以我们现在需要换个思路

新定义:
定义两个位置 i i j j 的距离 d i s ( i , j ) = c n t [ i j ] dis(i,j)=cnt[i\oplus j]

对于一个位置 x x ,每次能向它转移的位置 y y 2 m 2^m 个,所以暴力复杂度非常不优。但是你会发现,转移的位置虽然多,但是系数的种类却不多,对 x x 位置产生贡献的系数按与 x x 位置的距离分只有 m + 1 m+1 种,我们就可以观察是否可以在这里进行突破
当然可以,不然我说这个干什么
容易发现,这 m + 1 m+1 种系数就是输入的那个 b [ ] b[] ,而在 t t 时刻,仍然是有一个 b t [ ] b_t[] 的(在之后,这个下标代表时刻),只是数值上发生了变化
那么如果我们能够快速的求出 t t 时刻的b数组,那么我们剩下要做的就是一遍FWT,复杂度只剩 Θ ( n l o g n ) \Theta(nlogn)

3.3 分析

要计算 b t [ ] b_t[] ,我们要从 b t 1 [ ] b_{t-1}[] 推过来
转移显然是 Θ ( m 2 ) \Theta(m^2) 的,那么转移的系数是什么呢?
考虑 b t 1 [ j ] b_{t-1}[j] b t [ i ] b_t[i] 的贡献

  • step1
    把数值字母化,对于一组 x , y , z x,y,z 满足 d i s ( x , y ) = j , d i s ( x , z ) = i dis(x,y)=j,dis(x,z)=i
    b t 1 [ j ] b_{t-1}[j] 相当于 t 1 t-1 时刻 x x y y 的贡献
    b t [ i ] b_t[i] 相当于 t t 时刻 x x z z 的贡献
  • step2
    已知 d i s ( x , y ) = j dis(x,y)=j z z 中对 d i s dis 贡献为 1 1 的位有 j j 位,为 0 0 的有 m j m-j
    我们枚举 y y 变到 z z 的过程中,贡献 1 1 0 0 的位数 s 1 s_1 和贡献 0 0 1 1 的位数 s 1 s_1
    容易发现 j s 1 + s 0 = i j-s_1+s_0=i ,并且 d i s ( y , z ) = s 0 + s 1 dis(y,z)=s_0+s_1 ,贡献系数为 b [ s 0 + s 1 ] C ( j , s 1 ) C ( m j , s 0 ) b[s_0+s_1]*C(j,s_1)*C(m-j,s_0)
    枚举 s 0 s_0 ,那么 s 1 = j i + s 0 s_1=j-i+s_0 ,把式子算出来即可

现在知道转移的系数了,那么就可以用矩阵乘法优化,复杂度是 Θ ( m 3 l o g t ) \Theta(m^3logt) 的,算法总复杂度为 Θ ( m 3 l o g t + n l o g n ) \Theta(m^3logt+nlogn)

3.4 代码

贴出AC代码

#include<cstdio>
#include<cctype>
#include<algorithm>
#include<cstring>
#define rg register
typedef long long LL;
template <typename T> inline T max(const T a,const T b){return a>b?a:b;}
template <typename T> inline T min(const T a,const T b){return a<b?a:b;}
template <typename T> inline void mind(T&a,const T b){a=a<b?a:b;}
template <typename T> inline void maxd(T&a,const T b){a=a>b?a:b;}
template <typename T> inline T abs(const T a){return a>0?a:-a;}
template <typename T> inline void swap(T&a,T&b){T c=a;a=b;b=c;}
template <typename T> inline T gcd(const T a,const T b){if(!b)return a;return gcd(b,a%b);}
template <typename T> inline T lcm(const T a,const T b){return a/gcd(a,b)*b;}
template <typename T> inline T square(const T x){return x*x;};
template <typename T> inline void read(T&x)
{
    char cu=getchar();x=0;bool fla=0;
    while(!isdigit(cu)){if(cu=='-')fla=1;cu=getchar();}
    while(isdigit(cu))x=x*10+cu-'0',cu=getchar();
    if(fla)x=-x;
}
template <typename T> inline void printe(const T x)
{
    if(x>=10)printe(x/10);
    putchar(x%10+'0');
}
template <typename T> inline void print(const T x)
{
    if(x<0)putchar('-'),printe(-x);
    else printe(x);
}
LL mod=998244353;const int lenth=1048576;
inline LL msc(LL a,LL b)
{
    LL v=(a*b-(LL)((long double)a/mod*b+1e-8)*mod);
    return v<0?v+mod:v;
}
const int maxn=21;
int size=21;
struct Martix
{
    LL a[maxn][maxn];
    Martix(){};
    Martix(const int x)
    {
        if(x==0)memset(a,0,sizeof(a));
        else if(x==1)
        {
            memset(a,0,sizeof(a));
            for(rg int i=0;i<size;i++)a[i][i]=1;
        }
    }
    Martix operator *(const Martix&b)const
    {
        Martix x=0;
        for(rg int i=0;i<size;i++)
            for(rg int j=0;j<size;j++)
                for(rg int k=0;k<size;k++)
                    x.a[i][j]=(x.a[i][j]+msc(a[i][k],b.a[k][j]))%mod;
        return x;
    }
}bz,final;
template <typename T,typename sum>inline T pow(T x,const sum y)
{
    T res=1;
    for(rg sum i=1;i<=y;i<<=1,x=x*x)if(i&y)res=res*x;
    return res;
}
LL n,m,t,a[lenth],b[lenth],c[21];int cnt[lenth];
inline void FWT(LL*A,const int fla)
{
    for(rg int i=1;i<n;i<<=1)
        for(rg int j=0;j<n;j+=(i<<1))
            for(rg int k=0;k<i;k++)
            {
                const LL x=A[j+k],y=A[j+k+i];
                A[j+k]=(x+y)%mod;
                A[j+k+i]=(x+mod-y)%mod;
            }
    if(fla==-1)
        for(rg int i=0;i<n;i++)
            A[i]/=n;
}
LL C[21][21];
int main()
{
    read(m),read(t),read(mod),n=1<<m,mod*=n;
    for(rg int i=0;i<n;i++)read(a[i]);
    for(rg int i=1;i<n;i++)cnt[i]=cnt[i^(i&-i)]+1;
    for(rg int i=0;i<=m;i++)read(c[i]);final.a[0][0]=1;
    for(rg int i=0;i<=m;i++)
    {
        C[i][0]=1;
        for(rg int j=1;j<i;j++)C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
        C[i][i]=1;
    }
    for(rg int j=0;j<=m;j++)
        for(rg int i=0;i<=m;i++)
        {
            for(rg int s0=0;s0<=m;s0++)
            {
                const int s1=j-i+s0;
                if(s0>m-j)continue;
                if(s1>j||s1<0)continue;
                bz.a[i][j]=(bz.a[i][j]+msc(msc(c[s0+s1],C[m-j][s0]),C[j][s1]))%mod;
            }
        }
    final=final*pow(bz,t);
    for(rg int i=0;i<n;i++)b[i]=final.a[0][cnt[i]];
    FWT(a,1),FWT(b,1);
    for(rg int i=0;i<n;i++)a[i]=msc(a[i],b[i]);
    FWT(a,-1);
    for(rg int i=0;i<n;i++)print(a[i]),putchar('\n');
    return 0;
}

4 总结

这是一道非常不错的题,想要AC此题并不是很难,我写这题的博客是因为它的优化非常巧妙,也没有用到很难的知识点,值得深思
还有要把慢速乘背下来

猜你喜欢

转载自blog.csdn.net/zhouyuheng2003/article/details/84791844
今日推荐