【LOJ 565】mathematican 的二进制(分治 + FFT)

题目链接:【LOJ 565】mathematican 的二进制

题目大意:有一个 n 位的二进制数, m 个操作。第 i 个操作是将这个二进制串的数值加上 2 a i ,有 p i 的几率被执行。每次操作的代价是这次操作改变的位的数量。求代价的期望值 mod 998244353 的结果。 n , m 2 × 10 5

我们发现:

  • 最终的答案与操作顺序无关,只与哪些操作被执行过有关。
  • 因为每次进位总会让 1 的总个数减少 1 ,总代价就是所有被执行的操作的总次数的两倍减去最终剩下的数中 1 的个数。即: 2 m b i t c o u n t ( a )

于是可以列出递推式: f ( i , j ) 表示从后往前第 i 位总共被改变 j 次的概率,那么我们有两种转移:

  • 进位: f ( i 1 , j ) f ( i , j 2 )
  • 操作:对于第 i 位每个概率为 p 的操作, ( 1 p ) f ( i , j ) + p f ( i 1 , j ) f ( i , j )

发现进位可以直接转移,操作可以用分治 N T T 转移。时间复杂度 O ( m log 2 m )

注意:二进制数的最大值不是 2 n ,而是 m 2 n

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int maxn = 200025;
const int maxm = 1 << 19 | 5;
int n, m, r[maxm], cur;
ll temp[2][maxm];
vector<ll> f, g, v[maxn], p[maxm];
ll mpow(ll a, ll b, ll c) {
    if ((b %= mod - 1) < 0) {
        b += mod - 1;
    }
    ll d = 1;
    for (; b; b >>= 1, a = a * a % c) {
        if (b & 1) {
            d = d * a % c;
        }
    }
    return d;
}
void ntt(ll *a, int n, int opt) {
    for (int i = 0; i < n; i++) {
        if (i < r[i]) {
            swap(a[i], a[r[i]]);
        }
    }
    for (int k = 1; k < n; k <<= 1) {
        ll v = mpow(3, (mod - 1) / (k << 1) * opt, mod);
        for (int i = 0; i < n; i += k << 1) {
            ll w = 1;
            for (int j = i; j < i + k; j++, w = w * v % mod) {
                ll x = a[j], y = w * a[j + k] % mod;
                a[j] = (x + y) % mod, a[j + k] = (x - y) % mod;
            } 
        }
    }
    if (opt == -1) {
        ll v = -(mod - 1) / n;
        for (int i = 0; i < n; i++) {
            a[i] = v * a[i] % mod;
        }
    }
}
void mult(const vector<ll> &f, const vector<ll> &g, vector<ll> &h) {
    int lim, bit = 0, len = f.size() + g.size() - 1;
    for (lim = 1; lim < len; lim <<= 1) bit++;
    for (int i = 0; i < lim; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }
    for (int i = 0; i < f.size(); i++) {
        temp[0][i] = f[i];
    }
    for (int i = f.size(); i < lim; i++) {
        temp[0][i] = 0;
    }
    for (int i = 0; i < g.size(); i++) {
        temp[1][i] = g[i];
    }
    for (int i = g.size(); i < lim; i++) {
        temp[1][i] = 0;
    }
    ntt(temp[0], lim, 1);
    ntt(temp[1], lim, 1);
    for (int i = 0; i < lim; i++) {
        temp[0][i] = temp[0][i] * temp[1][i] % mod;
    }
    ntt(temp[0], lim, -1);
    h.clear();
    for (int i = 0; i < len; i++) {
        h.push_back(temp[0][i]);
    }
}
void solve(const vector<ll> &v, int l, int r, vector<ll> &u) {
    if (l == r) {
        u.clear();
        u.push_back(1 - v[l]);
        u.push_back(v[l]);
        return;
    }
    int x = cur++, y = cur++, md = (l + r) >> 1;
    solve(v, l, md, p[x]);
    solve(v, md + 1, r, p[y]);
    mult(p[x], p[y], u);
}
void calc(const vector<ll> &v, vector<ll> &u) {
    if (!v.size()) {
        u.clear();
        u.push_back(1);
    } else {
        solve(v, 0, v.size() - 1, u);
    }
}
int main() {
    scanf("%d %d", &n, &m);
    ll sum = 0, a, p, q;
    for (int i = 1; i <= m; i++) {
        scanf("%lld %lld %lld", &a, &p, &q);
        p = p * mpow(q, mod - 2, mod) % mod;
        v[a].push_back(p), sum = (sum + p) % mod;
    }
    f.push_back(1);
    sum = 2 * sum % mod;
    for (int i = 0; i <= n + 20; i++) {
        g.resize((f.size() + 1) >> 1);
        fill(g.begin(), g.end(), 0);
        for (int i = 0; i < f.size(); i++) {
            g[i >> 1] += f[i];
        }
        calc(v[i], f);
        mult(f, g, f);
        for (int i = 1; i < f.size(); i += 2) {
            sum = (sum - f[i]) % mod;
        }
    }
    printf("%lld\n", (sum + mod) % mod);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_42068627/article/details/81145955