UOJ#34. Polynomial Multiplication (NTT)

This is a template problem: it only asks for a direct implementation of a standard algorithm.

Given two polynomials, output their product.

Input format

The first line contains two integers $n$ and $m$: the degrees of the first and the second polynomial, respectively.

The second line contains $n+1$ integers: the coefficients of the first polynomial, from degree $0$ up to degree $n$.

The third line contains $m+1$ integers: the coefficients of the second polynomial, from degree $0$ up to degree $m$.

Output format

A single line of $n+m+1$ integers: the coefficients of the product polynomial, from degree $0$ up to degree $n+m$.

Example 1

input

1 2
1 2
1 2 1

output

1 4 5 2

explanation

$(1+2x)\cdot(1+2x+x^2) = 1+4x+5x^2+2x^3$.

Restrictions and Conventions

$0 \le n, m \le 10^5$. Every coefficient in the input is guaranteed to be an integer between $0$ and $9$ inclusive.

Time limit: 1 s

Memory limit: 256 MB

 

Shock!

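A whole morning of TLE (Time Limit Exceeded) verdicts turned out to be caused by one thing: the prime modulus and the primitive root were not declared `const`! With `const`, the compiler knows the modulus at compile time and can replace the expensive `% P` divisions with cheaper multiplications.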

This is a standard NTT template problem.

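Take the FFT and replace the complex roots of unity with powers of a primitive root modulo the prime (here $g = 3$ modulo $998244353$); everything else stays the same.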

#include<algorithm>
#include<cmath>
#include<cstdio>
#include<utility>
// 64-bit integer type for all modular arithmetic: the product of two residues
// below P (~2^30) needs more than 32 bits.
// Note: there is deliberately no `swap` macro. A function-like macro named
// `swap` would break every later `std::swap(...)` / `x.swap(y)` call, and an
// XOR swap zeroes a value swapped with itself. Plain `swap` resolves to
// std::swap (<utility>) through the using-directive below.
typedef long long LL;
using namespace std;
// Capacity of the coefficient arrays; it must be at least the NTT length
// `limit` (smallest power of two greater than N + M). Integer literal instead
// of the floating-point expression 3 * 1e6 + 10.
const int MAXN = 3000010;
// Reads the next integer from `in` (stdin by default).
// Skips every character up to the first digit; any '-' seen while skipping
// makes the result negative. Consumes one character past the last digit.
// Returns 0 if EOF is reached before any digit is found. The character is
// kept in an int so EOF is recognised, instead of spinning forever on
// truncated input as a `char` would.
inline int read(FILE *in = stdin) {
    int x = 0, f = 1;
    int ch = getc(in);
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getc(in);
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getc(in);
    }
    return x * f;
}
int N;          // degree of the first polynomial
int M;          // degree of the second polynomial
int limit = 1;  // NTT length: a power of two, grown in main()
int L;          // log2(limit)
const int P = 998244353;   // prime modulus, 119 * 2^23 + 1
const int G = 3;           // primitive root of P
const int Gi = 332748118;  // inverse of G modulo P (3 * Gi == P + 1)
LL a[MAXN];   // coefficients of the first polynomial, transformed in place
LL b[MAXN];   // coefficients of the second polynomial, transformed in place
int r[MAXN];  // bit-reversal permutation table used by NTT
// Returns a^k mod P by binary exponentiation, as a value in [0, P).
// `a` may be any LL (negative or >= P). It is reduced into [0, P) first so
// that `a * a` cannot overflow 64 bits.
// `k` must be non-negative. A negative k is treated like 0 instead of looping
// forever (an arithmetic right shift keeps -1 at -1).
inline LL fastpow(LL a, LL k) {
    a %= P;
    if (a < 0) a += P;  // '%' keeps the sign of the dividend in C++
    LL result = 1;
    while (k > 0) {
        if (k & 1) result = (result * a) % P;
        a = (a * a) % P;
        k >>= 1;
    }
    return result;
}
// In-place iterative number-theoretic transform of A[0 .. limit-1] modulo P.
//   type == 1 : forward transform, using powers of G.
//   otherwise : inverse transform, using powers of Gi (G^{-1} mod P); the
//               caller must still multiply every entry by limit^{-1} mod P.
// Reads the globals `limit` (a power of two) and the bit-reversal table `r`.
// Assumes every A[i] is already reduced into [0, P).
inline void NTT(LL *A, int type) {
    // Reorder into bit-reversed index order so the butterflies run in place.
    for (int i = 0; i < limit; i++) {
        if (i < r[i]) {  // swap each pair exactly once
            LL tmp = A[i];
            A[i] = A[r[i]];
            A[r[i]] = tmp;
        }
    }
    for (int mid = 1; mid < limit; mid <<= 1) {
        // Root of unity of order 2*mid modulo P (given G is a primitive root).
        LL Wn = fastpow(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int j = 0; j < limit; j += (mid << 1)) {
            LL w = 1;
            for (int k = 0; k < mid; k++, w = (w * Wn) % P) {
                // Keep both butterfly operands 64-bit: no silent LL -> int narrowing.
                LL x = A[j + k];
                LL y = w * A[j + k + mid] % P;
                A[j + k] = (x + y) % P;
                A[j + k + mid] = (x - y + P) % P;
            }
        }
    }
}
// Reads two polynomials (degrees N and M, coefficients from degree 0 upward),
// multiplies them with NTT, and prints the N + M + 1 product coefficients mod P.
int main() {
    #ifdef LOCAL
    // Local debugging only. This is guarded by a project-specific macro rather
    // than WIN32, which some Windows toolchains predefine; that would make a
    // Windows judge read "a.in" instead of stdin.
    freopen("a.in", "r", stdin);
    #endif
    N = read(); M = read();
    for (int i = 0; i <= N; i++) a[i] = (read() + P) % P;
    for (int i = 0; i <= M; i++) b[i] = (read() + P) % P;
    // limit = smallest power of two strictly greater than N + M; L = log2(limit).
    while (limit <= N + M) limit <<= 1, L++;
    // Bit-reversal table. When N == M == 0, L == 0 and limit == 1; the only
    // entry is r[0] = 0. Skip the shift by L - 1 == -1, which is undefined behavior.
    for (int i = 0; i < limit; i++)
        r[i] = (r[i >> 1] >> 1) | (L > 0 ? (i & 1) << (L - 1) : 0);
    NTT(a, 1); NTT(b, 1);
    for (int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % P;  // pointwise product
    NTT(a, -1);
    LL inv = fastpow(limit, P - 2);  // limit^{-1} mod P (Fermat, P is prime)
    for (int i = 0; i <= N + M; i++)
        // The value is < P < 2^31, so it fits in int. Passing a long long to %d
        // was undefined behavior.
        printf("%d ", static_cast<int>((a[i] * inv) % P));
    putchar('\n');
    return 0;
}

 

You may also like

Source: http://43.154.161.224:23101/article/api/json?id=325133641&siteId=291194637