P4238 【模板】多项式求逆

P4238 【模板】多项式求逆

链接：https://www.luogu.com.cn/problem/P4238

分析:

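  多项式求逆元：给定 $A(x)$（要求 $A_0 \not\equiv 0$），求 $B(x)$ 使 $A(x)B(x) \equiv 1 \pmod{x^n}$，系数对 $P=998244353$ 取模。

  倍增求解：先取 $B(x) \equiv A_0^{-1}$。若已求得 $B_0(x)$ 满足 $A(x)B_0(x) \equiv 1 \pmod{x^{m/2}}$，则目标 $B(x)$ 满足 $B - B_0 \equiv 0 \pmod{x^{m/2}}$，平方得 $(B - B_0)^2 \equiv 0 \pmod{x^m}$；展开后两边乘 $A$ 并利用 $AB \equiv 1$，得到
  $$B(x) \equiv 2B_0(x) - A(x)B_0(x)^2 \pmod{x^m}.$$
  每轮用长度 $2m$ 的 NTT 计算 $A B_0^2$，总复杂度 $O(n\log n)$。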

代码:700ms

 1 #include<cstdio>
 2 #include<algorithm>
 3 #include<cstring>
 4 #include<cmath>
 5 #include<iostream>
 6 
 7 using namespace std;
 8 
 // Numeric shorthand (kept for compatibility; not used by the listing below).
using LL = long long;

// Capacity of every work array; must exceed twice the padded length used in main.
constexpr int N = 2100000;
// NTT-friendly prime 119 * 2^23 + 1, its primitive root, and that root's inverse.
constexpr int P = 998244353;
constexpr int G = 3;
constexpr int Gi = 332748118;  // G * Gi == 1 (mod P)

// A: input coefficients, B: running inverse, TA/TB: NTT scratch buffers.
int A[N], B[N], TA[N], TB[N];
16 
// Reads one decimal integer from stdin.
// Every character before the first digit is skipped; if any of them is '-',
// the result is negated. Note: if input ends before a digit appears, the
// skip loop never terminates (getchar keeps returning EOF).
inline int read() {
    int sign = 1;
    char c = getchar();
    while (!isdigit(c)) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    int value = 0;
    while (isdigit(c)) {
        value = value * 10 + c - '0';
        c = getchar();
    }
    return value * sign;
}
23 int ksm(int a,int b) {
24     int ans = 1;
25     while (b) {
26         if (b & 1) ans = (1ll * ans * a) % P;
27         a = (1ll * a * a) % P;
28         b >>= 1;
29     }
30     return ans % P;
31 }
32 void NTT(int *a,int n,int ty) {
33     for (int i=0,j=0; i<n; ++i) {
34         if (i < j) swap(a[i],a[j]);
35         for (int k=(n>>1); (j^=k)<k; k>>=1);
36     }
37     for (int w1,w,m=2; m<=n; m<<=1) {
38         if (ty==1) w1 = ksm(G,(P-1)/m);
39         else w1 = ksm(Gi,(P-1)/m);
40         for (int i=0; i<n; i+=m) {
41             w = 1;
42             for (int k=0; k<(m>>1); ++k) {
43                 int u = a[i+k],t = 1ll * w * a[i+k+(m>>1)] % P;
44                 a[i+k] = (u + t) % P;
45                 a[i+k+(m>>1)] = (u - t + P) % P;
46                 w = 1ll * w * w1 % P;
47             }
48         }
49     }
50     if (ty==-1) {
51         int inv = ksm(n,P-2);
52         for (int i=0; i<n; ++i) a[i] = 1ll * a[i] * inv % P;
53     }
54 }
55 int main() {
56     int n = read(),len = 1;
57     for (int i=0; i<n; ++i) A[i] = read();
58     
59     while (len <= n) len <<= 1;
60     
61     B[0] = ksm(A[0],P-2);
62     for (int m=2; m<=len; m<<=1) {
63         for (int i=0; i<m; ++i) TA[i] = A[i],TB[i] = B[i];
64         NTT(TA,m<<1,1);
65         NTT(TB,m<<1,1);
66         for (int i=0; i<(m<<1); ++i) TA[i] = 1ll*TA[i]*TB[i]%P*TB[i]%P; // A * B * B
67         NTT(TA,m<<1,-1);
68         for (int i=0; i<m; ++i) B[i] = (1ll*2*B[i]%P-TA[i]+P)%P; // 多项式减法 
69     }
70     for (int i=0; i<n; ++i) printf("%d ",B[i]);    
71     return 0;
72 }
View Code

感觉已优化到极限的代码（预处理单位根表 KSMG/KSMGI、位逆序表 rev、fread 快读与 fwrite 快写）：520ms

 1 #include<cstdio>
 2 #include<algorithm>
 3 #include<cctype>
 4 
 // Primitive root of P and its inverse (G * Gi == 1 mod P).
#define G 3
#define Gi 332748118
// Capacity of every work array; must exceed twice the padded length used in main.
#define N 270000
// NTT-friendly prime 119 * 2^23 + 1.
#define P 998244353
#define LL long long
// `register` was removed in C++17 (clang rejects it), so `rg` is kept only as a no-op hint.
#define rg
// Modular add/sub with a single conditional correction: results lie in [0, P)
// when both operands do. Arguments are parenthesised so expressions expand as units.
#define add(a, b) ((a) + (b) >= P ? (a) + (b) - P : (a) + (b))
#define dec(a, b) ((a) - (b) < 0 ? (a) - (b) + P : (a) - (b))
// Buffered getchar(): when buf is exhausted, refill it with up to 100000 bytes
// via fread; yields EOF once fread returns nothing.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2) ? EOF :*p1++)

using namespace std;

// A: input, B: running inverse, TA/TB: NTT scratch, rev: bit-reversal table,
// KSMG/KSMGI: root of unity of order i stored at index i (forward / inverse).
int A[N], B[N], TA[N], TB[N], rev[N], KSMG[N], KSMGI[N];

// Input buffer window [p1, p2) used by the getchar() macro.
char ch, buf[100000], *p1 = buf, *p2 = buf;
// Parses one decimal integer from the buffered getchar() stream.
// Characters before the first digit are skipped; a '-' among them makes the
// result negative. If input ends before any digit, the skip loop never ends.
inline int read() {
    char c = getchar();
    int negative = 0;
    while (!isdigit(c)) {
        if (c == '-') negative = 1;
        c = getchar();
    }
    int acc = 0;
    while (isdigit(c)) {
        acc = acc * 10 + c - '0';
        c = getchar();
    }
    return acc * (negative ? -1 : 1);
}
// Output buffer; print() appends to it, and the contents are written to stdout
// with a single fwrite at the end of main.
char obuf[1 << 24], *O = obuf;

// Appends the decimal digits of x at O and advances O (no separator).
// Intended for x >= 0; any x <= 9 is emitted as the single char x % 10 + '0'.
void print(int x) {
    char pending[10];  // low-order digits, emitted in reverse at the end
    int count = 0;
    while (x > 9) {
        pending[count++] = x % 10 + '0';
        x /= 10;
    }
    *O++ = x % 10 + '0';
    while (count > 0) *O++ = pending[--count];
}
31 inline int ksm(int a,int b) {
32     int ans = 1;
33     while (b) {
34         if (b & 1) ans = (1ll * ans * a) % P;
35         a = (1ll * a * a) % P;
36         b >>= 1;
37     }
38     return ans % P;
39 }
40 void NTT(int *a,int n,int ty,int L) {
41     for(rg int i=1; i<n; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<L-1);
42     for(rg int i=1; i<n; ++i) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
43     for (rg int w1,w,m=2; m<=n; m<<=1) {
44         if (ty==1) w1 = KSMG[m];else w1 = KSMGI[m];
45         for (int i=0; i<n; i+=m) {
46             w = 1;
47             for (rg int k=0; k<(m>>1); ++k) {
48                 int u = a[i+k],t = 1ll * w * a[i+k+(m>>1)] % P;
49                 a[i+k] = add(u, t);
50                 a[i+k+(m>>1)] = dec(u, t);
51                 w = 1ll * w * w1 % P;
52             }
53         }
54     }
55     if (ty==-1) {
56         int inv = ksm(n,P-2);
57         for (rg int i=0; i<n; ++i) a[i] = 1ll * a[i] * inv % P;
58     }
59 }
60 int main() {
61     int n = read(),len = 1;
62     for (rg int i=0; i<n; ++i) A[i] = read();
63     
64     while (len <= n) len <<= 1;
65     int tmp = len << 1;
66     for (rg int i=1; i<=tmp; i<<=1)    
67         KSMG[i] = ksm(G,(P-1)/i),KSMGI[i] = ksm(Gi,(P-1)/i);
68         
69     B[0] = ksm(A[0],P-2);
70     int t = 1;
71     for (rg int m=2; m<=len; m<<=1) { // 求长度为m的逆元 
72         t ++;
73         for (rg int i=0; i<m; ++i) TA[i] = A[i],TB[i] = B[i];
74         NTT(TA,m<<1,1,t);
75         NTT(TB,m<<1,1,t);
76         for (rg int i=0; i<(m<<1); ++i) TA[i] = 1ll*TA[i]*TB[i]%P*TB[i]%P; // A * B * B
77         NTT(TA,m<<1,-1,t);
78         for (rg int i=0; i<m; ++i) B[i] = (1ll*2*B[i]%P-TA[i]+P)%P; // 多项式减法 
79     }
80     for(rg int i = 0; i < n; i++) print(B[i]), *O++ = ' ';
81     fwrite(obuf, O-obuf, 1 , stdout);
82     return 0;
83 }
View Code

猜你喜欢

转载自www.cnblogs.com/mjtcn/p/9155806.html