点分治还是很有意思的,不过自己看博客还是比较吃力的,看了好久。
基本思想:
找一颗树上的重心(所有以此点为根的子树中最大子树最小的那个根结点)
分治解决。(这是关键。)
详细解释一下:一开始结点数为n,朴素算法O(
),若找重心分治处理就是O(
)层,每次合并时间是(这里是O(nlog(n)+n) 排序+找最远可行对)不定。
那最终就是O(n*log(n)*log(n)).
最终的决定权基本在合并算法的手里,能合并,就能处理。
需要注意的是
ans-=solve(v,w);
要减掉没有经过根的点
#include<iostream>
#include<stdio.h>
#include<string>
#include<cstring>
#include<vector>
#include<stack>
#include<algorithm>
#include<queue>
using namespace std;
#define clr(a, x) memset(a, x, sizeof(a))
#define mp(x, y) make_pair(x, y)
#define pb(x) push_back(x)
#define X first
#define Y second
#define fastin \
ios_base::sync_with_stdio(0); \
cin.tie(0);
typedef long long LL;
typedef pair<int, int> PII;
typedef vector<int> VI;
const int INF=0x3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-6;
const int N=10010;
const int M=40010;
int head[N];
int top;
struct Edge
{
int val;
int to;
int next;
}edge[M];
void addedge(int a,int b,int c)
{
edge[top]=Edge{c,b,head[a]};
head[a]=top++;
}
void init()
{
top=0;
memset(head,-1,sizeof(head));
}
int k;
int root,sim[N],S,mxson[N];
int MX,vis[N];
void getroot(int u,int fa){
sim[u]=1;mxson[u]=0;
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;
int w=edge[i].val;
if(v==fa||vis[v]) continue;
getroot(v,u);
sim[u]+=sim[v];
mxson[u]=max(mxson[u],sim[v]);
}
mxson[u] = max(mxson[u],S-sim[u]);
if(mxson[u]<MX){
MX=mxson[u];
root=u;
}
}
LL ans;
int id;
int dis[N]; //到重心的距离
void getdis(int u,int fa,int dist) //点u到root的距离
{
dis[id++] = dist;
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;int w=edge[i].val;
if(v==fa||vis[v]) continue;
getdis(v,u,dist+w);
}
}
int solve(int u,int len){ //排序找符合的点
id=0;
clr(dis,0);
getdis(u,0,len);
sort(dis,dis+id);
int L=0,R=id-1;
int tmp=0;
while(L<R){
if(dis[R]+dis[L]<=k){tmp+=R-L;L++;}
else R--;
}
return tmp;
}
void Divide(int u){
ans+=solve(u,0);
vis[u]=true;
for(int i=head[u];~i;i=edge[i].next){
int v=edge[i].to;int w=edge[i].val;
if(vis[v]) continue;
ans-=solve(v,w);
S=sim[v];root=0;
MX=INF;getroot(v,0);
Divide(root);
}
}
int main()
{
int n;
while(scanf("%d%d",&n,&k),n+k){
init();
int a,b,c;
for(int i=1;i<n;i++){
scanf("%d%d%d",&a,&b,&c);
addedge(a,b,c);
addedge(b,a,c);
}
clr(vis,0);
S=n,MX=INF;
root=0;ans=0;
getroot(1,0);
Divide(root);
printf("%lld\n",ans);
}
return 0;
}