2019 ICPC 徐州网络赛 J.Random Access Iterator

2019 ICPC 徐州网络赛 J.Random Access Iterator


题目大意:给你n个点和n-1条边(树形结构),保证1为根节点,通过以下方式dfs遍历:
在这里插入图片描述
询问dfs到最深节点的概率(有多个最深节点则任意一个即可),答案对1e9+7取模。

解法:比赛的时候最后一个小时开的这道概率题,最后10分钟AC了。看上去有些困难,其实就是一个dp的过程。先一遍dfs找到每个节点的深度和孩子数,然后令最深节点的dp值为1.(dp[i]表示dfs到i节点之后成功dfs到最深点的概率)接下来从最下面开始更新dp: 对于任意一个节点u,设它有x个孩子,分别的dp值是dp[v1],dp[v2],dp[v3]…dp[vx];有规则可知dfs到u节点后,会重复进行x次dfs,直接求x次后成功的概率比较困难,考虑反面,求x次全部失败的概率。x次事件独立,所以求一次失败概率再x次方即可。一次失败很好求,就是(1-dp[v1])*(1-dp[v2])…(1-dp[vx])/x。

下面是AC代码:

#include <bits/stdc++.h>
using namespace std;
using namespace chrono;
const int N = 2000005;
const int M = 1000000007;
const int INF = 0x3f3f3f3f;
const double PI = acos(-1);
const double eps = 1e-8;
#define ms(x, y) memset((x), (y), sizeof(x))
#define mc(x, y) memcpy((x), (y), sizeof(y))
typedef long long ll;
typedef unsigned long long ull;
#define fi first
#define se second
#define mp make_pair
typedef pair<int, int> pii;
typedef pair<ll, int> pli;
#define bg begin
#define ed end
#define pb push_back
#define al(x) (x).bg(), (x).ed()
#define st(x) sort(al(x))
#define un(x) (x).erase(unique(al(x)), (x).ed())
#define fd(x, y) (lower_bound(al(x), (y)) - (x).bg() + 1)
#define ls(x) ((x) << 1)
#define rs(x) (ls(x) | 1)
template <class T>
bool read(T & x) {
    
    
    char c;
    while (!isdigit(c = getchar()) && c != '-' && c != EOF);
    if (c == EOF) return false;
    T flag = 1;
    if (c == '-') {
    
     flag = -1; x = 0; } else x = c - '0';
    while (isdigit(c = getchar())) x = x * 10 + c - '0';
    x *= flag;
    return true;
}
template <class T, class ...R>
bool read(T & a, R & ...b) {
    
    
    if (!read(a)) return false;
    return read(b...);
}
mt19937 gen(steady_clock::now().time_since_epoch().count());
struct edge {
    
     int to, next; } e[N];
int head[N], cnt = 0, sz[N], dep[N], leaf[N];
ll dp[N];
ll qpow(ll a, ll n, ll p) {
    
    
    ll r = 1;
    for (a %= p; n; n >>= 1, (a *= a) %= p)
        if (n & 1) (r *= a) %= p;
    return r;
}
void add(int u, int v) {
    
    
    e[cnt] = {
    
    v, head[u]};
    head[u] = cnt++;
}
void dfs(int d, int u, int p) {
    
    
    dep[u] = d;
    for (int i = head[u]; ~i; i = e[i].next) {
    
    
        int v = e[i].to;
        if (v == p) continue;
        sz[u]++;
        dfs(d + 1, v, u);
    }
    if (sz[u] == 0) leaf[u] = 1;
}
void dfs2(int u, int p) {
    
    
    if (leaf[u]) return;
    for (int i = head[u]; ~i; i = e[i].next) {
    
    
        int v = e[i].to;
        if (v == p) continue;
        dfs2(v, u);
        dp[u] = (dp[u] + (1 - dp[v] + M) % M) % M;
    }
    dp[u] = dp[u] * qpow(sz[u], M - 2, M) % M;
    dp[u] = qpow(dp[u], sz[u], M);
    // if (u == 1)
    //     cout << "check: " << dp[u] << ' ' << sz[u] << endl;
    dp[u] = (1 - dp[u] + M) % M;
    // if (u == 1)
    //     cout << "check: " << dp[u] << ' ' << sz[u] << endl;
}
int main()
{
    
    
    time_point<steady_clock> start = steady_clock::now();

    int size = 128 << 20;
    char * p = (char *)malloc(size) + size;
    #if (defined _WIN64) or (defined __unix)
        __asm__("movq %0, %%rsp\n" :: "r"(p));
    #else
        __asm__("movl %0, %%esp\n" :: "r"(p));
    #endif

    // cout << 39 * qpow(64, M - 2, M) % M << endl;
    // cout << 25 * qpow(64, M - 2, M) % M << endl;

    int n, u, v;
    read(n);
    ms(head, -1);
    for (int i = 1; i < n; i++) {
    
    
        read(u, v);
        add(u, v);
        add(v, u);
    }
    dfs(1, 1, 0);
    int mx = 0;
    for (int i = 1; i <= n; i++)
        if (leaf[i]) mx = max(mx, dep[i]);
    for (int i = 1; i <= n; i++) {
    
    
        if (leaf[i]) {
    
    
            if (dep[i] == mx) {
    
    
                dp[i] = 1;
            } else dp[i] = 0;
        }
    }
    dfs2(1, 0);
    // cout << "------------------" << endl;
    // for (int i = 1; i <= n; i++) cout << i << ' ' << dp[i] << endl;
    // cout << "------------------" << endl;
    printf("%lld\n", dp[1]);

    cerr << endl << "------------------------------" << endl << "Time: "
         << duration<double, milli>(steady_clock::now() - start).count()
         << " ms." << endl;

    exit(0);
}

猜你喜欢

转载自blog.csdn.net/yzsjwd/article/details/100608726