@atcoder - AGC034F@ RNG and XOR


@description@

给定一个值域在 [0, 2^N) 的随机数生成器,给定参数 A[0...2^N-1]。
该生成器有 \(\frac{A_i}{\sum A}\) 的概率生成 i,每次生成都是独立的。

现在有一个 X,初始为 0。每次操作生成一个随机数 v 并将 X 异或 v。

对于每一个 i ∈ [0, 2^N),求期望多少次操作 X 第一次等于 i。

原题题面

@solution@

不难想到期望 dp。定义 dp[i] 表示到达 i 的期望次数,则:

\[ dp[0] = 0 \\ dp[i] = (\sum_{j=0}^{2^N - 1}dp[j]\times p[i\oplus j]) + 1 \]

其中 \(p[i] = \frac{A_i}{\sum A}\)

朴素做法是高斯消元。显然过不了。
对于高斯消元的常规优化是利用转移的图结构(比如 DAG,链或者树),但是这个题转移的图是完全图,做不到。

怎么办?观察转移式的结构,发现它其实是异或卷积。于是我们尝试走生成函数那一套。

如果用生成函数的记法,又可以将其记作 \(dp\oplus P + I = dp + k\times T\),其中 \(I[i] = 1, T[i] = [i = 0]\)\(k\) 是一个未知数。
注意当 n = 0 卷积是不成立的,所以需要在末尾填上一项 \(k\times T\)

变一下形得到 \(dp\oplus (T - P) = I - k\times T\),两边同时进行 fwt 得到 \(dp'\times (T - P)' = I' - k\times T'\)

注意到 \((T - P)'\) 的第 0 项始终为 0(根据 fwt 的定义可知),故 \(I' - k\times T'\) 的第 0 项也为 0,由此可以解出 k。

但是这样一来我们又不知道 \(dp'[0]\) 的值为多少,再次设未知数为 q。进行逆变换时把未知数代进去一起运算就可以了。

然后 \(dp\) 数列就可以表示成含 q 的一次函数,而根据 \(dp[0] = 0\) 可以反解出 q,于是 \(dp\) 数列就解出来了。

@accepted code@

#include <cstdio>

const int MOD = 998244353;
const int INV2 = (MOD + 1) >> 1;

int add(int x, int y) {return (x + y >= MOD ? x + y - MOD : x + y);}
int sub(int x, int y) {return (x - y < 0 ? x - y + MOD : x - y);}
int mul(int x, int y) {return 1LL*x*y%MOD;}

int pow_mod(int b, int p) {
    int ret = 1;
    for(int i=p;i;i>>=1,b=mul(b,b))
        if( i & 1 ) ret = mul(ret,b);
    return ret;
}

struct node{
    int k, b;
    node() : k(0), b(0) {}
    node(int _k, int _b) : k(_k), b(_b) {}
    int get(int x) {return add(mul(k, x), b);}
    friend node operator + (node a, node b) {
        return node(add(a.k, b.k), add(a.b, b.b));
    }
    friend node operator - (node a, node b) {
        return node(sub(a.k, b.k), sub(a.b, b.b));
    }
    friend node operator * (node a, int k) {
        return node(mul(a.k, k), mul(a.b, k));
    }
    friend node operator / (node a, int k) {
        return a * pow_mod(k, MOD - 2);
    }
};

void fwt(node *A, int m, int type) {
    int n = (1 << m), f = (type == 1 ? 1 : INV2);
    for(int i=1;i<=m;i++) {
        int s = (1 << i), t = (s >> 1);
        for(int j=0;j<n;j+=s)
            for(int k=0;k<t;k++) {
                node x = A[j+k], y = A[j+k+t];
                A[j+k] = (x + y)*f, A[j+k+t] = (x - y)*f;
            }
    }
}

node A[1<<18], B[1<<18], C[1<<18], f[1<<18];

int main() {
    int N, M, S = 0; scanf("%d", &N), M = (1 << N);
    for(int i=0;i<M;i++) scanf("%d", &A[i].b), S = add(S, A[i].b);
    S = pow_mod(S, MOD - 2);
    for(int i=0;i<M;i++) A[i].b = sub(i == 0 ? 1 : 0, mul(A[i].b, S));
    for(int i=0;i<M;i++) B[i].b = 1;
    C[0].b = MOD - 1;
    fwt(A, N, 1), fwt(B, N, 1), fwt(C, N, 1);
    int tmp = mul(B[0].b, pow_mod(C[0].b, MOD-2));
    for(int i=1;i<M;i++)
        f[i] = (B[i] - C[i]*tmp) / A[i].b;
    f[0].k = 1; fwt(f, N, -1);
    int x = sub(0, mul(pow_mod(f[0].k, MOD-2), f[0].b));
    for(int i=0;i<M;i++) printf("%d\n", f[i].get(x));
}

@details@

感觉我的做法很像是乱搞。。。不过我也不大清楚官方正解是啥子。。。

猜你喜欢

转载自www.cnblogs.com/Tiw-Air-OAO/p/12142184.html
今日推荐