Codechef Prime Distance On Tree

【传送门】

FFT第四题!

暑假的时候只会点分治,路径合并用的是暴力合并……勉强水过去了。

其实两条路径长度的合并就是卷积:把以分治中心为端点的各深度出现次数看作多项式系数,两条路径拼接后的长度等于两者深度之和,所以每次统计完当前连通块的所有路径后,对深度计数做一次自卷积即可。

刚开始卷积时固定了值域(FFT 长度始终取上限),结果 T 了。后来不再偷懒,每次取当前连通块的最大深度乘二来确定 FFT 的值域。

#include <bits/stdc++.h>

// Pi to full double precision (acos(-1) is exactly representable input).
const double pi = std::acos(-1.0);

/// Minimal complex number with double-precision parts, used by FFT.
struct Complex {
    double r, i;  // real part, imaginary part

    /// Builds r + i*j; both parts default to zero.
    Complex(double r = 0, double i = 0) : r(r), i(i) {}

    /// Resets both parts to zero.
    void clear() {
        r = 0.0;
        i = 0.0;
    }

    Complex operator+(const Complex &p) const {
        const double re = r + p.r;
        const double im = i + p.i;
        return Complex(re, im);
    }

    Complex operator-(const Complex &p) const {
        const double re = r - p.r;
        const double im = i - p.i;
        return Complex(re, im);
    }

    /// (a + bj)(c + dj) = (ac - bd) + (ad + bc)j
    Complex operator*(const Complex &p) const {
        const double re = r * p.r - i * p.i;
        const double im = r * p.i + i * p.r;
        return Complex(re, im);
    }
};

/// In-place iterative radix-2 FFT.
///
/// a  : n coefficients; n is expected to be a power of two (cal passes limit = 2^l).
/// pd : +1 for the forward transform, -1 for the inverse (result divided by n).
/// r  : bit-reversal permutation of [0, n), precomputed by the caller.
void FFT(Complex *a, int n, int pd, int *r) {
    // Move every element to its bit-reversed slot; the r[i] > i test makes
    // each pair swap exactly once.
    for (int i = 0; i < n; ++i) {
        if (r[i] > i) std::swap(a[i], a[r[i]]);
    }

    // Butterfly passes: blocks of length 2*half are merged from two halves.
    for (int half = 1; half < n; half *= 2) {
        const double angle = pi / half;
        const Complex step(std::cos(angle), pd * std::sin(angle));
        const int span = half * 2;
        for (int base = 0; base < n; base += span) {
            Complex tw(1.0, 0.0);  // running twiddle factor step^k
            Complex *lo = a + base;
            Complex *hi = lo + half;
            for (int k = 0; k < half; ++k) {
                const Complex even = lo[k];
                const Complex odd = tw * hi[k];
                lo[k] = even + odd;
                hi[k] = even - odd;
                tw = tw * step;
            }
        }
    }

    if (pd >= 0) return;
    // Inverse transform: normalise by the length.
    for (int i = 0; i < n; ++i) a[i] = Complex(a[i].r / n, a[i].i / n);
}

using ll = long long;

const int N = 2e5 + 7;

// Tree and centroid-decomposition state.
int n;                    // number of vertices (ids 1..n)
int sz[N];                // subtree sizes from the latest getroot pass
int maxsz[N];             // largest piece left if the vertex were removed (getroot)
int root;                 // centroid chosen by getroot; 0 is a sentinel
int totsz;                // component size getroot measures against
std::vector<int> vec[N];  // adjacency lists

// Linear-sieve output: prime[1..prin] are the primes below N; is[x] is set for composite x.
int prime[N];
int prin;
bool vis[N];  // set once a vertex has been used as a centroid by solve
bool is[N];

// cnt[d]: accumulated count of ordered vertex pairs at distance d.
// ccnt: scratch depth histogram, all zero between calls to cal.
ll cnt[N];
ll ccnt[N];
int dis[N];  // depth of a vertex from the start vertex passed to cal
int r[N];    // bit-reversal permutation used by FFT

// FFT work buffer and its size. A, r, ccnt and cnt are indexed below `limit`,
// the smallest power of two above twice the largest depth; with N = 200007
// that stays in bounds while n <= 65536.
Complex A[N];
int limit;
int l;

void init() {
    for (int i = 2; i < N; i++) {
        if (!is[i]) prime[++prin] = i;
        for (int j = 1; j <= prin && i * prime[j] < N; j++) {
            is[i * prime[j]] = 1;
            if (i % prime[j] == 0) break;
        }
    }
}

/// Raises a to b when b is larger; returns whether a changed.
inline bool chkmax(int &a, int b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}

/// Post-order DFS over the unvisited component containing u (fa = parent).
/// Fills sz[] and maxsz[] for each vertex, and moves `root` to the first vertex
/// (in post-order) whose largest remaining piece beats maxsz[root].
/// The piece above u is measured as totsz - sz[u].
void getroot(int u, int fa) {
    sz[u] = 1;
    int heaviest = 0;  // largest child subtree seen so far
    for (const int v : vec[u]) {
        if (v != fa && !vis[v]) {
            getroot(v, u);
            sz[u] += sz[v];
            heaviest = std::max(heaviest, sz[v]);
        }
    }
    maxsz[u] = std::max(heaviest, totsz - sz[u]);
    if (maxsz[u] < maxsz[root]) root = u;
}

int f[N];  // depths collected by getdis, stored at f[1..tto]
int tto;   // number of entries written to f
int val;   // largest depth seen by getdis

/// Pre-order DFS from u (fa = parent) that skips visited vertices. It appends
/// dis[u] to f[], tracks the maximum in val, and sets each child's dis to one more.
void getdis(int u, int fa) {
    const int depth = dis[u];
    f[++tto] = depth;
    if (depth > val) val = depth;
    for (const int v : vec[u]) {
        if (vis[v] || v == fa) continue;
        dis[v] = depth + 1;
        getdis(v, u);
    }
}

/// Adds (opt = 1) or subtracts (opt = -1) into cnt[] the number of ordered
/// vertex pairs (x, y) in u's unvisited component, indexed by dis[x] + dis[y].
/// Depths are measured from u, starting at d.
///
/// solve() calls cal(c, 0, 1) at a centroid c, and then cal(v, 1, -1) for each
/// unvisited neighbour v to cancel pairs lying inside the same subtree.
void cal(int u, int d, int opt) {
    // Collect every depth in the component and the largest one (val).
    tto = 0;
    dis[u] = d;
    val = 0;
    getdis(u, 0);
    for (int i = 1; i <= tto; i++)
        ccnt[f[i]]++;

    // The transform length must exceed 2 * val so the cyclic self-convolution
    // does not wrap around.
    limit = 1, l = 0;
    while (limit <= 2 * val)
        limit <<= 1, l++;

    // Bit-reversal permutation. The top bit is written as limit >> 1 instead of
    // 1 << (l - 1): for a single-vertex component limit == 1 and l == 0, and a
    // shift by -1 is undefined behaviour.
    const int top_bit = limit >> 1;
    for (int i = 0; i < limit; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) ? top_bit : 0);

    // Square the depth histogram: coefficient k becomes the number of ordered
    // pairs whose depths sum to k.
    for (int i = 0; i < limit; i++)
        A[i] = Complex(static_cast<double>(ccnt[i]), 0.0);
    FFT(A, limit, 1, r);
    for (int i = 0; i < limit; i++)
        A[i] = A[i] * A[i];
    FFT(A, limit, -1, r);

    // Round back to integers. Index 0 can only hold the pair (u, u), so it is skipped.
    for (int i = 1; i < limit; i++)
        cnt[i] += opt * std::llround(A[i].r);

    // Restore ccnt to all zeros for the next call.
    for (int i = 1; i <= tto; i++)
        ccnt[f[i]]--;
}

/// Centroid-decomposition driver, called with the vertex getroot selected.
/// It marks u as used, counts all pairs whose path passes through u, and
/// cancels pairs confined to one neighbour's subtree. It then recurses on the
/// centroid of each remaining component.
void solve(int u) {
    vis[u] = 1;
    cal(u, 0, 1);  // every ordered pair in u's component, by depth sum from u
    for (const int v : vec[u]) {
        if (!vis[v]) {
            // NOTE(review): sz[v] comes from the getroot pass that selected u.
            // For the neighbour that was u's DFS parent in that pass, it is larger
            // than the real component. That only biases which centroid is
            // chosen; the counts are unaffected.
            root = 0;
            totsz = sz[v];
            // Pairs entirely inside v's subtree were added with a wrong length.
            cal(v, 1, -1);
            getroot(v, 0);
            solve(root);
        }
    }
}

int main() {
    init();
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        vec[u].push_back(v);
        vec[v].push_back(u);
    }
    maxsz[root = 0] = n;
    totsz = n;
    getroot(1, 0);
    solve(root);
    ll ans = 0;
    for (int i = 1; i <= prin; i++) {
        ans += cnt[prime[i]];
    }
    ll sum = 1LL * n * (n - 1);
    printf("%.7f\n", 1.0 * ans / sum);
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/Mrzdtz220/p/11886655.html
今日推荐