Wannafly Winter Camp Day5 Div1 E题 Fast Kronecker Transform 转化为NTT或FFT

目录

(有任何问题欢迎留言或私聊 && 欢迎交流讨论哦

Catalog

@

Problem:传送门

 原题目描述在最下面。
 对给定的式子算解。
\(0\leq k\leq n+m,c_k=(\sum_{i+j=k}i\times j\times \sigma_{a_i,b_j}) mod\;998244353\),其中\(当且仅当a=b时,\sigma_{a_i,b_j}=1。\)

Solution:

 我们发现只有当\(a_i\)\(b_j\)相等时才会对答案造成贡献。

 一共\(2e^5\)个数字,我们枚举每一个数字算贡献,同时分情况讨论:当这个
数字出现次数小于阀值\(T\)时,我们\(O(n^2)\)暴力算;当出现次数大于\(T\)时,我们用\(FFTorNTT\)计算。

 听说这题有人\(FFT\)丢精度就\(wa\)了,我就干脆用\(NTT\)写了,刚好这个模数也是费马质数嘛,接下来看怎么把上述奇怪的卷积转化成一个可以用\(FFT\)\(NTT\)计算的卷积。

 我们把\(a\)数组和\(b\)数组中数字\(x\)所有出现的位置提出来放到新数组里面去,比如\(x\)出现在\(a\)\(1,3,4\)位置,也出现在\(b\)\(2,4,5,6\)位置。

 当\(k\)等于\(5\)时,数字\(x\)产生的贡献\(1*4,2*3\)
 当\(k\)等于\(6\)时,数字\(x\)产生的贡献\(1*5\)

 我们知道一般的卷积式子是长这样的:\(C_k=\sum_{i+j=k}A_i\times B_j\)

 哇哦,感觉这个式子和那个贡献好相似啊。

 我们就构造这样两个数组\(A_{a_i}=a_i,B_{b_j}=b_j\)可以得到:
\(A[]=\{0,1,0,3,4\},B[]=\{0,0,2,0,4,5,6\}\)

 我们对这两个数组求他们多项式乘法的结果,然后将结果的贡献累计到答案数组中就行啦。

 然后当阀值\(T\)\(10000\)时,复杂度就可以承受了。

AC_Code:

感谢日天的板子
这题复杂度的分析dls讲题时分析的很清楚啦,一位群友的笔记截图在下面

#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
namespace lh {
#define o2(x) (x)*(x)
    using namespace std;
    typedef long long LL;
    typedef unsigned long long uLL;
    typedef pair<int, LL> pii;
}

using namespace lh;
const int MX = 2e5 + 5;
//const int P = (479 << 21) + 1;
const int P = 998244353;
const int MOD = 998244353;
const int G = 3;
const int NUM = 20;
struct my_NTT {
    LL wn[NUM];
    LL a[MX << 1], b[MX << 1];
    LL pow (LL a, LL x, LL mod) {
        LL ans = 1;
        a %= mod;
        while (x) {
            if (x & 1) ans = ans * a % mod;
            x >>= 1;
            a = a * a % mod;
        }
        return ans;
    }
    //在程序的开头就要放
    void init() {
        for (int i = 0; i < NUM; i++) {
            int t = 1 << i;
            wn[i] = pow (G, (P - 1) / t, P);
        }
    }
    void Rader (LL F[], int len) {
        int j = len >> 1;
        for (int i = 1; i < len - 1; i++) {
            if (i < j) swap (F[i], F[j]);
            int k = len >> 1;
            while (j >= k) j -= k, k >>= 1;
            if (j < k) j += k;
        }
    }
    void NTT (LL F[], int len, int t) {
        Rader (F, len);
        int id = 0;
        for (int h = 2; h <= len; h <<= 1) {
            id++;
            for (int j = 0; j < len; j += h) {
                LL E = 1;
                for (int k = j; k < j + h / 2; k++) {
                    LL u = F[k];
                    LL v = E * F[k + h / 2] % P;
                    F[k] = (u + v) % P;
                    F[k + h / 2] = (u - v + P) % P;
                    E = E * wn[id] % P;
                }
            }
        }
        if (t == -1) {
            for (int i = 1; i < len / 2; i++) swap (F[i], F[len - i]);
            LL inv = pow (len, P - 2, P);
            for (int i = 0; i < len; i++) F[i] = F[i] * inv % P;
        }
    }
    void Conv (LL a[], LL b[], int len) {
        NTT (a, len, 1);
        NTT (b, len, 1);
        for (int i = 0; i < len; i++) a[i] = a[i] * b[i] % P;
        NTT (a, len, -1);
    }
    void gao (LL A[], LL B[], int n, int m, LL ans[]) {//0~n-1
        int len = 1;
        while (len < n + m) len <<= 1;
        for (int i = 0; i < n; i++) a[i] = A[i];
        for (int i = 0; i < m; i++) b[i] = B[i];
        for (int i = n; i < len; i++) a[i] = 0;
        for (int i = m; i < len; i++) b[i] = 0;
        Conv (a, b, len);
        for (int i = 0; i < len; i++) ans[i] = (ans[i]+a[i])%MOD;
    }
}ntt;
const int MXN = 2e5 + 5;
int n, m;
int ar[MXN], br[MXN];
LL A[MXN], B[MXN];
std::vector<int> all[MXN], bll[MXN];
LL ans[MXN];
void solve1(int id) {
    for(int i = 0; i < all[id].size(); ++i) {
        for(int j = 0; j < bll[id].size(); ++j) {
            ans[all[id][i]+bll[id][j]] += (LL)all[id][i] * bll[id][j];
            ans[all[id][i]+bll[id][j]] %= MOD;
        }
    }
}
void solve2(int id) {
    for(int i = 0; i <= n+m; ++i) A[i] = B[i] = 0;
    for(int i = 0; i < all[id].size(); ++i) A[all[id][i]] = all[id][i];
    for(int i = 0; i < bll[id].size(); ++i) B[bll[id][i]] = bll[id][i];
    ntt.gao(A, B, all[id].back()+1, bll[id].back()+1, ans);
}
int main(int argc, char const *argv[]) {
    scanf("%d%d", &n, &m); ++n, ++m;
    ntt.init();
    std::vector<int> vs;
    for(int i = 0; i < n; ++i) scanf("%d", &ar[i]), vs.push_back(ar[i]);
    for(int i = 0; i < m; ++i) scanf("%d", &br[i]), vs.push_back(br[i]);
    sort(vs.begin(), vs.end());
    vs.erase(unique(vs.begin(), vs.end()), vs.end());
    for(int i = 0, tmp; i < n; ++i) {
        tmp = lower_bound(vs.begin(), vs.end(), ar[i]) - vs.begin();
        all[tmp].push_back(i);
    }
    for(int i = 0, tmp; i < m; ++i) {
        tmp = lower_bound(vs.begin(), vs.end(), br[i]) - vs.begin();
        bll[tmp].push_back(i);
    }
    for(int i = 0; i < vs.size(); ++i) {
        if(all[i].size() + bll[i].size() <= 10000) solve1(i);
        else solve2(i);
    }
    for(int i = 0; i <= n + m-2; ++i) printf(i!=n+m-2?"%lld ":"%lld\n", ans[i]);
    return 0;
}


Problem Description:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

猜你喜欢

转载自www.cnblogs.com/Cwolf9/p/10322102.html