多项式(2)---分治FFT,

注意:这是个人学习笔记,如果有人因为某些原因点了进来并且要看一下,请一定谨慎地阅读,因为可能存在各种奇怪的错误(甚至可能概念都有问题...),如果有人发现错误请指出谢谢!


分治FFT

题目:洛谷P4721 【模板】分治 FFT

前置问题:如何对任意已知f,g,l<=r,l1<=r1,对于所有$l1<=k<=r1$,求$\sum_{i=l}^rf_ig_j[i+j=k]$(①)?

答案:
设$f'_i=f_{i+l}$,$k'=k-l$则$l1-l<=k'<=r1-l$
①式$=\sum_{i=0}^{r-l}f'_ig_j[i+j=k']$

设$g'_i=g_{i+l1-r}$,$k''=k'-l1+r$则$r-l<=k''<=r1-l1+r-l$
①式$=\sum_{i=0}^{r-l}f'_ig'_{j-l1+r}[i+j-l1+r=k'']$
$=\sum_{i=0}^{r-l}f'_ig'_j[i+j=k'']$

把$f'$的第r-l之后的项都赋值为0,然后对$f'$和$g'$卷积就可以了,取出结果中第r-l到r1-l1+r-l项作为答案即可。注意到只需要f的第l到r项,g的第l1-r到r1-l项,卷积结果的第r-l到r1-l1+r-l项,因此复杂度O((r1-l1+r-l)log(r1-l1+r-l))

此题:$f_k=\sum_{i=0}^{k-1}f_ig_{k-i}$;$f_0=1$

考虑分治。各个序列中不存在的项全部当成是0

solve(l,r):在l左边的f值,以及对于所有$l<=k<=r$,$\sum_{i=0}^{l-1}f_ig_{k-i}$,都已经正确求出来时,求出f[l]到f[r]的值。

先solve(l,mid),再计算[l,mid]对[mid+1,r]的贡献,再solve(mid+1,r)

计算[l,mid]对[mid+1,r]的贡献,就相当于要对于所有$mid+1<=k<=r$,计算$\sum_{i=l}^{mid}f_ig_j[i+j=k]$

用上面的方法完成即可

附:这题里面,快读快写基本没用;小范围暴力有用,以下代码大概在开O2以后开到r-l<=K,K在100到200左右时进行暴力比较合适(大概是FFT常数真的大吧...)

  1 #pragm\
  2 a GCC optimize(2)
  3 #include<cstdio>
  4 #include<algorithm>
  5 #include<cstring>
  6 #include<vector>
  7 using namespace std;
  8 #define fi first
  9 #define se second
 10 #define mp make_pair
 11 #define pb push_back
 12 typedef long long ll;
 13 typedef unsigned long long ull;
 14 
 15 const ll md=998244353;
 16 ll poww(ll a,ll b)
 17 {
 18     ll base=a,ans=1;
 19     for(;b;b>>=1,base=base*base%md)
 20         if(b&1)
 21             ans=ans*base%md;
 22     return ans;
 23 }
 24 const int N=131073;
 25 int n,n1;ll g[N],f[N],t1[N],t2[N];
 26 int rev[N];
 27 void init(int len)
 28 {
 29     int bit=0,i;
 30     while((1<<(bit+1))<=len)    ++bit;
 31     for(i=0;i<len;++i)
 32         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
 33 }
 34 void dft(ll *a,int len,int idx)
 35 {
 36     int i,j,k;ll wn,wnk,t1,t2;
 37     for(i=0;i<len;++i)
 38         if(i<rev[i])
 39             swap(a[i],a[rev[i]]);
 40     for(i=1;i<len;i<<=1)
 41     {
 42         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
 43         for(j=0;j<len;j+=(i<<1))
 44         {
 45             wnk=1;
 46             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
 47             {
 48                 t1=a[k];t2=a[k+i]*wnk%md;
 49                 a[k]+=t2;a[k+i]=t1-t2;
 50                 (a[k]>=md) && (a[k]-=md);
 51                 (a[k+i]<0) && (a[k+i]+=md);
 52             }
 53         }
 54     }
 55     if(idx==-1)
 56     {
 57         ll ilen=poww(len,md-2);
 58         for(i=0;i<len;++i)
 59             (a[i]*=ilen)%=md;
 60     }
 61 }
 62 void solve(int l,int r)
 63 {
 64     int i,j;
 65     if(r-l<=128)
 66     {
 67         for(i=l;i<=r;++i)
 68         {
 69             for(j=l;j<i;++j)
 70             {
 71                 (f[i]+=f[j]*g[i-j])%=md;
 72             }
 73         }
 74         return;
 75     }
 76     int mid=(l+r)>>1,len=r-l+1;
 77     solve(l,mid);
 78     memcpy(t1,f+l,sizeof(ll)*(mid-l+1));
 79     memcpy(t2,g+1,sizeof(ll)*(r-l));
 80     memset(t1+mid-l+1,0,sizeof(ll)*(r-mid));
 81     init(len);
 82     dft(t1,len,1);
 83     dft(t2,len,1);
 84     for(i=0;i<len;++i)
 85         (t1[i]*=t2[i])%=md;
 86     dft(t1,len,-1);
 87     for(i=mid+1;i<=r;++i)
 88     {
 89         f[i]+=t1[i-1-l];
 90         (f[i]>=md) && (f[i]-=md);
 91     }
 92     solve(mid+1,r);
 93 }
 94 int main()
 95 {
 96     int i,t;
 97     scanf("%d",&n);n1=n;
 98     for(i=1;i<n;++i)
 99         scanf("%lld",g+i);
100     for(t=1;t<n;t<<=1);
101     n=t;
102     f[0]=1;
103     solve(0,n-1);
104     for(i=0;i<=n1-1;++i)
105         printf("%lld ",f[i]);
106     return 0;
107 }
View Code

猜你喜欢

转载自www.cnblogs.com/hehe54321/p/10331983.html