【JZOJ6385】【NOIP2019模拟2019.10.23】B

题目大意

给出一棵树,其中\(1\)为根。之后每个点向父亲的父亲再连一条边,求得到的图中,每个点走到\(1\)的期望步数(等概率向相邻点走去)。
保证\(i\)的父亲\(fa_i<i\)
\(n\leq 2000\)

Solution

首先列方程,设\(f_i\)表示\(i\)走向\(1\)的期望步数,有:

\[f_x=1+\frac{1}{d_x}\sum f_y\]

其中,\(d_x\)\(x\)的度数,\(y\)是与\(x\)相邻的所有点。

直接高斯消元解方程,复杂度\(O(n^3)\),过不了。

观察这个系数矩阵的特点,第\(x\)行的系数只会在父亲,父亲的父亲,儿子,儿子的儿子处有值。如果我们从儿子往根消元,每次用第\(x\)行消去\(fa_x\)行和\(fa_{fa_x}\)行的第\(x\)列,那么最后会得到一个下三角矩阵。这时第\(x\)行的系数只会在父亲,父亲的父亲处有值,我们从根往儿子消元,就能得到对角线矩阵了。

复杂度\(O(n^2)\)。这个做法利用了系数矩阵的特点,减少消元次数,真是妙不可言~~~

Code

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 2007, P = 998244353;

int n, fa[N], d[N], c[N][N], b[N];

int tot, st[N], nx[N << 2], to[N << 2];
void add(int u, int v) {
    to[++tot] = v, nx[tot] = st[u], st[u] = tot;
    to[++tot] = u, nx[tot] = st[v], st[v] = tot;
    ++d[u], ++d[v];
}

int pow(int a, int b) {
    int ret = 1;
    for (; b; a = 1ll * a * a % P, b >>= 1) if (b & 1) ret = 1ll * ret * a % P;
    return ret;
}

int main() {
    //freopen("in", "r", stdin);
    freopen("b.in", "r", stdin);
    freopen("b.out", "w", stdout);
    scanf("%d", &n);
    for (int i = 2; i <= n; ++i) scanf("%d", &fa[i]), add(i, fa[i]);
    for (int i = 1; i <= n; ++i) if (fa[fa[i]]) add(i, fa[fa[i]]);
    c[1][1] = 1;
    for (int i = 2; i <= n; ++i) {
        c[i][i] = 1;
        for (int j = st[i]; j; j = nx[j]) c[i][to[j]] = P - pow(d[i], P - 2);
        b[i] = 1;
    }
    for (int i = n; i >= 1; --i) {
        if (fa[i]) {
            int j = fa[i], res = 1ll * c[j][i] * pow(c[i][i], P - 2) % P;
            for (int k = 1; k <= n; ++k) c[j][k] = (c[j][k] - 1ll * c[i][k] * res % P + P) % P;
            b[j] = (b[j] - 1ll * b[i] * res % P + P) % P;
        }
        if (fa[fa[i]]) {
            int j = fa[fa[i]], res = 1ll * c[j][i] * pow(c[i][i], P - 2) % P;
            for (int k = 1; k <= n; ++k) c[j][k] = (c[j][k] - 1ll * c[i][k] * res % P + P) % P;
            b[j] = (b[j] - 1ll * b[i] * res % P + P) % P;
        }
    }
    for (int i = 1; i <= n; ++i) {
        for (int l = st[i]; l; l = nx[l]) if (to[l] != fa[i] && to[l] != fa[fa[i]]) {
            int j = to[l], res = 1ll * c[j][i] * pow(c[i][i], P - 2) % P;
            for (int k = 1; k <= n; ++k) c[j][k] = (c[j][k] - 1ll * c[i][k] * res % P + P) % P;
            b[j] = (b[j] - 1ll * b[i] * res % P + P) % P;
        }
    }
    for (int i = 1; i <= n; ++i) printf("%d\n", 1ll * b[i] * pow(c[i][i], P - 2) % P);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/zjlcnblogs/p/11733396.html
今日推荐