treecnt【组合数学】

题目链接 51nod 1677 treecnt


  这道题被划在了树形dp的框架下,但我更认为是一个组合数学的思维好题,比较容易的想到一点是我们需要求每条边的贡献次数,这个看上去比较的好求,最后总的贡献次数求和就是最后的答案了,别忘了取模。

那么,简单的看,每条边的贡献是这样的:

\LARGE \sum_{i=1}^{K - 1} ( C_{size_u} ^{i} * C_{size_v} ^{K - i} )

那么,这条边的两边各自都要取至少一个点,但是这样的做法的时间复杂度是O(N * K)的,要是心大一点,可以试试。

  那么,这道题的突破点在于哪呢?还是在于这个组合数学的求解方式,我们用反向思维,如果任取K个,然后减去不要这条边的贡献,那么是不是就是这条边的贡献了,哇哦,好巧妙,这样我们直接算(任取K个点的贡献 - 不要这条边的贡献)==(这条边产生的贡献)。

#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <bitset>
//#include <unordered_map>
//#include <unordered_set>
#define lowbit(x) ( x&(-x) )
#define pi 3.141592653589793
#define e 2.718281828459045
#define INF 0x3f3f3f3f3f3f3f3f
#define eps 1e-8
#define HalF (l + r)>>1
#define lsn rt<<1
#define rsn rt<<1|1
#define Lson lsn, l, mid
#define Rson rsn, mid+1, r
#define QL Lson, ql, qr
#define QR Rson, ql, qr
#define myself rt, l, r
#define MP(a, b) make_pair(a, b)
using namespace std;
typedef unsigned long long ull;
typedef unsigned int uit;
typedef long long ll;
const int maxN = 1e5 + 7;
const ll mod = 1e9 + 7;
inline ll fast_mi(ll a, ll b = mod - 2LL)
{
    ll ans = 1;
    while(b)
    {
        if(b & 1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1LL;
    }
    return ans;
}
int N, K, head[maxN], cnt;
ll jc[maxN], sum = 0;
struct Eddge
{
    int nex, to;
    Eddge(int a=-1, int b=0):nex(a), to(b) {}
} edge[maxN << 1];
inline void addEddge(int u, int v)
{
    edge[cnt] = Eddge(head[u], v);
    head[u] = cnt++;
}
inline void _add(int u, int v) { addEddge(u, v); addEddge(v, u); }
inline ll Calc(ll a, ll b)
{
    if(a < b) return 0;
    return jc[a] * fast_mi(jc[b]) % mod * fast_mi(jc[a - b]) % mod;
}
int siz[maxN];
void dfs(int u, int fa)
{
    siz[u] = 1;
    for(int i=head[u], v; ~i; i=edge[i].nex)
    {
        v = edge[i].to;
        if(v == fa) continue;
        dfs(v, u);
        siz[u] += siz[v];
        sum = (sum + Calc(N, K) - Calc(siz[v], K) - Calc(N - siz[v], K) + mod + mod) % mod;
    }
}
inline void init()
{
    cnt = 0;
    for(int i=1; i<=N; i++) head[i] = -1;
    jc[0] = 1;
    for(ll i=1; i<=N; i++) jc[i] = jc[i - 1] * i % mod;
}
int main()
{
    scanf("%d%d", &N, &K);
    init();
    for(int i=1, u, v; i<N; i++)
    {
        scanf("%d%d", &u, &v);
        _add(u, v);
    }
    dfs(1, 0);
    printf("%lld\n", sum);
    return 0;
}
发布了891 篇原创文章 · 获赞 1066 · 访问量 12万+

猜你喜欢

转载自blog.csdn.net/qq_41730082/article/details/105303070