2018-2019 ACM-ICPC, Asia Xuzhou Regional Contest G. Rikka with Intersections of Paths(树上差分+LCA+容斥)

题目链接:http://codeforces.com/gym/102012/problem/G

题目大意:有一棵n个结点的树,现在给出m条树上的路径。现在要从这m条路径中选出k条路径,使得这k条路径至少有一个公共交点,问你总共有多少种方案数。

题目思路:(今年徐州现场的银牌题,我们队肝到最后也没能肝出来,错失了银牌。。。QAQ,当时忘了一个重要的性质,导致正思路都错了。还是太菜了)

感慨一下,继续分析题目。

解决这个题,需要用到一个重要的性质:一个树上任意两条路径如果有交点的话,那么这些交点中肯定有一个为两条路径中的一条路径两端点的lca

有了这个性质的话,我们可以对通过枚举路径的交点来求答案。

对于每个节点,我们假设通过这个节点的路径有M条,以这个点为LCA且通过这个节点的路径有N条。

那么在这个节点对答案的贡献为:C_{M}^{K}-C_{M-N}^{K}。这个式子计算出来的是,从通过这个节点的路径中选出k条路径,且至少有一条路径的LCA为这个节点的方案数,这样选的话就不会出现重复选的情况了,因为至少有一条路径以该节点为LCA,在以其他点为交点的时候就不会重复计算了。

而通过某个结点的路径数我们可以通过树上差分计算,假设通过u这个节点的路径为sum[u]。那么在更新路径[u,v]的时候,我们就令sum[u]++,sum[v]++,sum[lca(u,v)]--,sum[fa[lca(u,v)]]--。接着再用dfs一遍即可。

具体实现看代码:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define IOS ios::sync_with_stdio(false)
#define fuck(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int>pii;
typedef pair<ll, ll>pll;
const int MX = 3e5 + 5;
const int mod = 1e9 + 7;

int n, m, k;
struct edge {int v, w, nxt;} E[MX << 1];
int head[MX], tot;
int dep[MX], ST[MX][20];
void add_edge(int u, int v) {
    E[tot].v = v; E[tot].nxt = head[u];
    head[u] = tot++;
}
void dfs(int u, int d, int fa) {
    dep[u] = d; ST[u][0] = fa;
    for (int i = head[u]; ~i; i = E[i].nxt) {
        int v = E[i].v;
        if (v == fa) continue;
        dfs(v, d + 1, u);
    }
}
void pre_solve() {
    dfs(1, 0, 1);
    for (int i = 1; i < 20; i++) {
        for (int j = 1; j <= n; j++) {
            ST[j][i] = ST[ST[j][i - 1]][i - 1];
        }
    }
}
int LCA(int u, int v) {
    while (dep[u] != dep[v]) {
        if (dep[u] < dep[v]) swap(u, v);
        int d = dep[u] - dep[v];
        for (int i = 0; i < 20; i++)
            if (d >> i & 1)u = ST[u][i];
    }
    if (u == v) return u;
    for (int i = 19; i >= 0; i--) {
        if (ST[u][i] != ST[v][i]) {
            u = ST[u][i];
            v = ST[v][i];
        }
    }
    return ST[u][0];
}
int sum[MX], lca_sum[MX];
void solve(int u, int fa) {
    for (int i = head[u]; ~i; i = E[i].nxt) {
        int v = E[i].v;
        if (v == fa) continue;
        solve(v, u);
        sum[u] += sum[v];
    }
}

ll f[MX], inv[MX];
ll qpow(ll a, ll b) {
    ll res = 1;
    while (b) {
        if (b & 1) res = (res * a) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return res;
}
void init() {
    f[1] = 1;
    for (int i = 2; i < MX; i++) f[i] = (f[i - 1] * i) % mod;
    inv[MX - 1] = qpow(f[MX - 1], mod - 2);
    for (int i = MX - 2; i >= 1; i--) inv[i] = (inv[i + 1] * (i + 1)) % mod;
}
ll C(int n, int m) {
    if (n < 0 || m < 0 || m > n) return 0;
    if (m == 0 || m == n) return 1;
    return f[n] * inv[n - m] % mod * inv[m] % mod;
}

int main() {
    // FIN;
    init();
    int T; cin >> T;
    while (T--) {
        scanf("%d%d%d", &n, &m, &k);
        for (int i = 1; i <= n; i++) head[i] = -1;
        tot = 0;
        for (int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            add_edge(u, v); add_edge(v, u);
        }
        pre_solve();
        for (int i = 1; i <= m; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            int lca = LCA(u, v); lca_sum[lca]++;
            sum[u]++; sum[v]++;
            sum[lca]--;
            if (lca != 1) sum[ST[lca][0]]--;
        }
        solve(1, 0);
        ll ans = 0;
        for (int i = 1; i <= n; i++)
            ans = (ans % mod + ((C(sum[i], k) - C(sum[i] - lca_sum[i], k)) % mod + mod) % mod) % mod;
        printf("%lld\n", ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/Lee_w_j__/article/details/84780981