Codeforces 1551F 枚举 + DP

题意

传送门 Codeforces 1551F Equidistant Vertices

题解

K = 2 K=2 K=2,则任一一对节点都满足条件,答案为 N ( N − 1 ) / 2 N(N-1)/2 N(N1)/2

K ≥ 3 K\geq 3 K3,容易证明,若节点 u u u 被选择,那么其余节点只可能同时处于 u u u 某条连边的某一侧。那么,被选择的节点 u u u 不可能位于任一对被选择的节点 v , w , v ≠ u , w ≠ u v,w,v\neq u,w\neq u v,w,v=u,w=u 间的路径上。容易证明,被选择的节点集 S S S 所构成的生成树,只可能存在一个度大于 2 2 2 的节点。那么枚举这样的节点 u u u,此时,对于与 u u u 连边的任一节点 v v v,以 v v v 为根的子树上至多选择一个节点。

那么可以以当前枚举节点 u u u 为根进行 D F S DFS DFS,预处理出各子树上距离 u u u d d d 的节点数量。枚举路径长度,问题转化为,已知集合 S 0 , S 2 ⋯   , S m − 1 S_0,S_2\cdots,S_{m-1} S0,S2,Sm1,从任一集合至多选择一个元素,总共选择 K K K 个元素的方案数为多少。可以使用 D P DP DP 求解。 d p [ i + 1 ] [ j + 1 ] dp[i+1][j+1] dp[i+1][j+1] 代表从 0 ⋯ i 0\cdots i 0i i + 1 i+1 i+1 个集合中选取 j + 1 j+1 j+1 个元素的方案数,那么有
d p [ i + 1 ] [ j + 1 ] = d p [ i ] [ j + 1 ] + d p [ i ] [ j ] ∗ ∣ S i ∣ dp[i+1][j+1]=dp[i][j+1]+dp[i][j]*\lvert S_i\rvert dp[i+1][j+1]=dp[i][j+1]+dp[i][j]Si 总时间复杂度 O ( T N 2 K ) O(TN^2K) O(TN2K)

#include <bits/stdc++.h>
using namespace std;
#define rep(i, l, r) for (int i = l, _ = r; i < _; ++i)
typedef long long ll;
const int maxn = 105, mod = 1000000007;
int T, N, K, Res, dp[maxn][maxn];
int cnt[maxn], mem[maxn];
vector<int> G[maxn], A[maxn];

void dfs(int u, int p, int d)
{
    
    
    ++cnt[d];
    for (auto &v : G[u])
        if (v != p)
            dfs(v, u, d + 1);
}

void add(int u)
{
    
    
    for (auto &v : G[u])
    {
    
    
        memcpy(mem, cnt, sizeof(cnt));
        dfs(v, u, 1);
        rep(i, 1, N)
        {
    
    
            if (mem[i] == cnt[i])
                break;
            A[i].push_back(cnt[i] - mem[i]);
        }
    }
    rep(x, 0, N)
    {
    
    
        if ((int)A[x].size() < K)
            continue;
        int n = A[x].size();
        rep(j, 0, K + 1) dp[0][j] = 0;
        rep(i, 0, n + 1) dp[i][0] = 1;
        rep(i, 0, n) rep(j, 0, K) dp[i + 1][j + 1] = (dp[i][j + 1] + (ll)dp[i][j] * A[x][i] % mod) % mod;
        Res = (Res + dp[n][K]) % mod;
    }
}

int main()
{
    
    
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> T;
    while (T--)
    {
    
    
        cin >> N >> K;
        rep(i, 0, N) G[i].clear();
        rep(i, 1, N)
        {
    
    
            int u, v;
            cin >> u >> v;
            --u, --v;
            G[u].push_back(v), G[v].push_back(u);
        }
        if (K == 2)
        {
    
    
            cout << N * (N - 1) / 2 << '\n';
            continue;
        }
        Res = 0;
        rep(i, 0, N)
        {
    
    
            memset(cnt, 0, sizeof(cnt));
            rep(j, 0, N) A[j].clear();
            add(i);
        }
        cout << Res << '\n';
    }
    return 0;
}

おすすめ

転載: blog.csdn.net/neweryyy/article/details/120399175