AcWing 252. 树|点分治

传送门

题目描述

给定一个有N个点(编号0,1,…,N-1)的树,每条边都有一个权值(不超过1000)。

树上两个节点x与y之间的路径长度就是路径上各条边的权值之和。

求长度不超过K的路径有多少条。

输入格式

输入包含多组测试用例。

每组测试用例的第一行包含两个整数N和K。

接下来N-1行,每行包含三个整数u,v,l,表示节点u与v之间存在一条边,且边的权值为l。

当输入用例N=0,K=0时,表示输入终止,且该用例无需处理。

输出格式

每个测试用例输出一个结果。

每个结果占一行。

数据范围

N10000N≤10000

输入样例:

5 4
0 1 3
0 2 1
0 3 2
2 4 1
0 0

输出样例:

8

题解:点分治

  参考博客:https://www.cnblogs.com/PinkRabbit/p/8593080.html

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e4 + 10;
const int INF = 1<<30;
int n,k,Root,cnt,ans;
int Tsize;       //当前处理的这棵树的节点数
int maxson[N];   //以i为根的最大子树大小
int sz[N];       //以i为根的树的大小
int len[N];      //i到根的边权和
bool vis[N];
vector<pair<int,int> >v[N];
void init() {
    ans = 0;
    for (int i = 0; i < N; ++i) {
        v[i].clear();
        vis[i] = false;
    }
}
//找重心作为根结点
void GetRoot(int u,int fa)  {
    sz[u] = 1;maxson[u] = 0;
    for (int i = 0; i < v[u].size(); i++) {
        pair<int,int> p = v[u][i];
        if (p.first == fa || vis[p.first]) continue;
        GetRoot(p.first,u);
        sz[u]+=sz[p.first];
        maxson[u] = max(maxson[u],sz[p.first]);
    }
    maxson[u] = max(maxson[u],Tsize-sz[u]);
    if (maxson[Root] > maxson[u]) Root = u;
}
void dfs(int u,int fa,int w) {
    len[++cnt] = w;
    for (int i = 0; i < v[u].size(); i++) {
        pair<int,int> p = v[u][i];
        if (p.first == fa || vis[p.first]) continue;
        dfs(p.first,u,w+p.second);
    }
}
int calc(int u,int w) {
    cnt = 0; dfs(u,0,w);
    int l = 1,r = cnt,sum = 0;
    sort(len+1,len+1+cnt);
    for (;;l++) {
        while (r&&len[l]+len[r]>k) r--;
        if (r<l) break;
        sum+=r-l+1;
    }
    return sum;
}
void work(int u) {
    ans+=calc(u,0);vis[u] = 1;
    for (int i = 0; i < v[u].size();++i) {
        pair<int,int> p = v[u][i];
        if ( vis[p.first]) continue;
        ans-=calc(p.first,p.second);
        Root = 0,Tsize = sz[p.first];
        GetRoot(p.first,0);
        work(Root);
    }
}
int main(){
    while (~scanf("%d%d",&n,&k)&&(n||k)) {
        init();
        for (int i = 1; i < n; i++) {
            int x,y,z;
            scanf("%d%d%d",&x,&y,&z);
            x++,y++;
            v[x].push_back(make_pair(y,z));
            v[y].push_back(make_pair(x,z));
        }
        Tsize = n;
        Root = 0;
        maxson[0] = INF;
        GetRoot(1,0);
        work(Root);
        printf("%d\n",ans-n);
    }
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/l999q/p/11374040.html