树上距离(树形dp)
题目描述
懒惰的温温今天上班也在偷懒。盯着窗外发呆的温温发现,透过窗户正巧能看到一棵n个节点的树。一棵n个节点的树包含n-1条边,且n个节点是联通的。树上两点之间的距离即两点之间的最短路径包含的边数。
突发奇想的温温想要知道,树上有多少个不同的点对,满足两点之间的距离恰好等于k。
注意:(u, v)和(v, u)视作同一个点对,只计算一次答案。
突发奇想的温温想要知道,树上有多少个不同的点对,满足两点之间的距离恰好等于k。
注意:(u, v)和(v, u)视作同一个点对,只计算一次答案。
输入
第一行两个整数n和k。
接下来n-1行每行两个整数ai, bi,表示节点ai和bi之间存在一条边。
1 ≤ k ≤ 500
2 ≤ n ≤ 500 for 40%
2 ≤ n ≤ 50000 for 100%
接下来n-1行每行两个整数ai, bi,表示节点ai和bi之间存在一条边。
1 ≤ k ≤ 500
2 ≤ n ≤ 500 for 40%
2 ≤ n ≤ 50000 for 100%
输出
输出一个整数,表示满足条件的点对数量。
样例输入 Copy
【样例1】
5 2
1 2
2 3
3 4
2 5
【样例2】
5 3
1 2
2 3
3 4
4 5
样例输出 Copy
【样例1】
4
【样例2】
2
思路:树形dp.dp[i][j]表示第i个点,距离为j的点有多少个。具体看代码
#pragma comment(linker, "/STACK:1024000000,1024000000")
#pragma GCC optimize(3,"Ofast","inline")
#include <bits/stdc++.h>
using namespace std;
#define rep(i , a , b) for(register int i=(a);i<=(b);i++)
#define per(i , a , b) for(register int i=(a);i>=(b);i--)
#define ms(s) memset(s, 0, sizeof(s))
#define squ(x) (x)*(x)
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll , ll> pi;
typedef unordered_map<int,int> un_map;
template<class T>
inline void read (T &x) {
x = 0;
int sign = 1;
char c = getchar ();
while (c < '0' || c > '9') {
if ( c == '-' ) sign = - 1;
c = getchar ();
}
while (c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar ();
}
x = x * sign;
}
const int maxn = 1e5+10;
const int inf = 0x3f3f3f3f;
const ll INF = ll(1e18);
const int mod = 1e9+7;
const double PI = acos(-1);
//#define LOCAL
int n,k;
std::vector<int> e[50005];
int dp[50005][505];
ll ans;
void dfs(int x,int fa) {
dp[x][0]=1;
int siz = e[x].size();
rep(i,0,siz-1) {
int y = e[x][i];
if(y==fa) continue;
dfs(y,x);
rep(j,0,k-1) {
ans=ans+dp[x][j]*dp[y][k-j-1];
if(j>0) dp[x][j]+=dp[y][j-1];
}
}
}
int main(int argc, char * argv[])
{
#ifdef LOCAL
freopen("A.in", "r", stdin);
//freopen("A.out", "w", stdout);
#endif
while (~scanf("%d%d", &n,&k))
{
ms(dp);
rep(i,1,n-1) {
int u,v;
read(u);read(v);
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1,-1);
printf("%lld\n",ans);
rep(i,1,n) e[i].clear();
}
return 0;
}