#include<iostream>
#include<vector>
using namespace std;
/// Negacyclic ("negative-wrapped") number-theoretic transform over Z_p.
///
/// Transforms polynomials of length N = 2^(n-1), i.e. arithmetic modulo
/// x^N + 1. Preconditions (not checked):
///   - p is an odd prime: rntt() divides by 2 via "add p if odd, then halve";
///   - w is a primitive 2^n-th root of unity mod p, so w^N == -1 (mod p),
///     e.g. p = 17, w = 2, n = 3 (2 has order 8 modulo 17);
///   - n >= 1, and 1 <= p < 2^31.
/// ntt() takes coefficients in natural order and returns values in
/// bit-reversed order: output i is the polynomial evaluated at
/// w^(2*bitrev(i) + 1). rntt() is the matching inverse, including the 1/N
/// scaling, so rntt(ntt(a)) == a for coefficients already reduced to [0, p).
/// Every product and sum is formed in 64-bit arithmetic, so large NTT primes
/// such as 998244353 do not overflow.
class Ntt {
public:
    int p;                  ///< Modulus.
    int w;                  ///< Primitive 2^n-th root of unity modulo p.
    int n;                  ///< log2(2N); the transform length is N = 2^(n-1).
    std::vector<int> arrw;  ///< arrw[i] = w^i mod p, for 0 <= i < 2^n.
    std::vector<int> revw;  ///< revw[i] = i bit-reversed over n-1 bits, for 0 <= i <= N.

    /// Precomputes the power table of w and the (n-1)-bit reversal table.
    Ntt(int p, int w, int n) : p(p), w(w), n(n) {
        const int twoN = 1 << n;
        arrw.resize(twoN);
        arrw[0] = 1;
        for (int i = 1; i < twoN; i++) {
            // Widen before multiplying: an int product overflows once p > ~46341.
            arrw[i] = static_cast<int>(static_cast<long long>(arrw[i - 1]) * w % p);
        }

        revw.assign(twoN, 0);
        const int bits = n - 1;  // log2(N)
        // With bits == 0 (N == 1) every reversal is 0, and the shift below
        // would be by -1 (undefined behavior), so the loop is skipped.
        if (bits > 0) {
            for (int i = 1; i <= (1 << bits); i++) {
                revw[i] = (revw[i >> 1] >> 1) | ((i & 1) << (bits - 1));
            }
        }
    }

    /// Forward negacyclic NTT using Cooley-Tukey butterflies; output is in
    /// bit-reversed order.
    /// @param a coefficients a[0..N-1], expected in [0, p). A shorter vector
    ///          is zero-padded to N (missing high coefficients are zero).
    ///          Entries past index N-1 are returned untouched.
    /// @return the transformed vector.
    std::vector<int> ntt(std::vector<int> a) {
        const int len = 1 << (n - 1);  // N
        if (static_cast<int>(a.size()) < len) a.resize(len, 0);
        const long long mod = p;
        int k = 0;  // index into the bit-reversed twiddle sequence, runs 1..N-1
        for (int mid = len >> 1; mid > 0; mid >>= 1) {
            for (int start = 0; start < len; start += 2 * mid) {
                const long long wk = arrw[revw[++k]];
                for (int pos = start; pos < start + mid; pos++) {
                    const long long l = a[pos];
                    const long long r = a[pos + mid] * wk % mod;
                    long long sum = (l + r) % mod;
                    long long diff = (l - r) % mod;
                    // C++ '%' keeps the dividend's sign; fold back into [0, p).
                    if (sum < 0) sum += mod;
                    if (diff < 0) diff += mod;
                    a[pos] = static_cast<int>(sum);
                    a[pos + mid] = static_cast<int>(diff);
                }
            }
        }
        return a;
    }

    /// Inverse of ntt() using Gentleman-Sande butterflies. It takes
    /// bit-reversed input and returns coefficients in natural order. The 1/N
    /// factor is applied as a halving at every stage.
    /// @param a transformed values a[0..N-1], expected in [0, p). A shorter
    ///          vector is zero-padded to N.
    /// @return the recovered coefficients.
    std::vector<int> rntt(std::vector<int> a) {
        const int len = 1 << (n - 1);  // N
        if (static_cast<int>(a.size()) < len) a.resize(len, 0);
        const long long mod = p;
        int k = len;  // walks the forward twiddle indices backwards, N-1..1
        for (int mid = 1; mid < len; mid <<= 1) {
            for (int start = len - 2 * mid; start >= 0; start -= 2 * mid) {
                // w^(2N - e) == w^(-e): the inverse of the forward twiddle.
                const long long wk = arrw[(1 << n) - revw[--k]];
                for (int pos = start; pos < start + mid; pos++) {
                    long long sum = (static_cast<long long>(a[pos]) + a[pos + mid]) % mod;
                    long long diff = (static_cast<long long>(a[pos]) - a[pos + mid]) * wk % mod;
                    if (sum < 0) sum += mod;
                    if (diff < 0) diff += mod;
                    a[pos] = halve_mod(sum, mod);
                    a[pos + mid] = halve_mod(diff, mod);
                }
            }
        }
        return a;
    }

private:
    /// Returns x * 2^{-1} mod `mod`, for x in [0, mod) and odd `mod`.
    static int halve_mod(long long x, long long mod) {
        if (x % 2 != 0) x += mod;  // make x even without changing its residue
        return static_cast<int>(x / 2);
    }
};
int main() {
    // Demo: round-trip a length-4 polynomial through the negacyclic NTT mod 17.
    // w = 2 has multiplicative order 8 = 2^3 modulo 17, as the transform needs.
    // Expected output: "1 4 3 13".
    Ntt tmp(17, 2, 3);
    const auto vec = tmp.rntt(tmp.ntt({1, 4, 3, 13}));
    for (const int i : vec) {
        // Map any negative residue into [0, p) using the transform's own
        // modulus rather than repeating the literal 17.
        cout << (i < 0 ? i + tmp.p : i) << ' ';
    }
    cout << '\n';
}
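/*
 * 负折叠卷积ntt变换: negacyclic ("negative-wrapped") convolution NTT.
 * Source: blog.csdn.net/weixin_39057744/article/details/121496064
 * (The trailing "Guess you like" / "Ranking" lines were web-page navigation residue.)
 */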