Tree 点分治

题目描述

给你一棵TREE,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于K

输入输出格式

输入格式:

N(n<=40000) 接下来n-1行边描述管道,按照题目中写的输入 接下来是k

输出格式:

一行,有多少对点之间的距离小于等于k

输入输出样例

输入样例#1: 
7
1 6 13 
6 3 9 
3 5 7 
4 1 3 
2 4 20 
4 7 2 
10
输出样例#1: 
5





#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
inline int read() {
    int res=0;char ch=getchar();
    while(!isdigit(ch)) ch=getchar();
    while(isdigit(ch)) res=(res<<3)+(res<<1)+(ch^48),ch=getchar();
    return res;
}
#define reg register
#define N 200005
int n, tot, k;
long long ans;

int head[N], cnt = 1;
struct edge {
    int nxt, to, val;
}ed[N*2];
inline void add(int x, int y, int z)
{
    ed[++cnt] = (edge){head[x], y, z};
    head[x] = cnt;
}

bool cut[N*2];


int siz[N], root, mrt = 1e9;
void dfs(int x, int fa)
{
    siz[x] = 1;
    for (reg int i = head[x] ; i ; i = ed[i].nxt)
    {
        int to = ed[i].to;
        if (to == fa or cut[i]) continue;
        dfs(to, x);
        siz[x] += siz[to];
    }    
}    
void efs(int x, int fa)
{
    int tmp = tot - siz[x];
    for (reg int i = head[x] ; i ; i = ed[i].nxt)
    {
        int to = ed[i].to;
        if (to == fa or cut[i]) continue;
        efs(to, x);
        tmp = max(tmp, siz[to]);
    }
    if (tmp < mrt) mrt = tmp, root = x;
}    
inline int FindRoot(int x)
{
    return x;
    dfs(x, 0);
    mrt = 1e9;
    tot = siz[x];
    root = n;
    efs(x, 0);
    return root;
}


int a[N];
int top;
void Work(int x, int fa, int d)
{
    a[++top] = d;
    for (reg int i = head[x] ; i ; i = ed[i].nxt)
    {
        int to = ed[i].to;
        if (to == fa or cut[i]) continue;
        Work(to, x, d + ed[i].val);
    }
}
inline int Calc(int x, int d)
{
    int res = 0;
    top = 0;
    Work(x, 0, d);
    sort (a + 1, a + 1 + top);
    int l = 1, r = top;
    while (l < r)
    {
        while(a[l] + a[r] > k and l < r) r--;
        res += r - l, l ++;
    }        
    return res;
}
void solve(int rt)
{
    root = FindRoot(rt);
    ans += Calc(root, 0);
    for (reg int i = head[root] ; i ; i = ed[i].nxt)
    {
        int to = ed[i].to;
        if (cut[i]) continue;
        cut[i] = cut[i ^ 1] = 1;
        ans -= Calc(to, ed[i].val);
        solve(to);
    }
}
    

int main()
{
    n = read();
    for (reg int i = 1 ; i < n ; i ++)
    {
        int x = read(), y = read(), z = read();
        add(x, y, z), add(y, x, z);
    } k = read();
    solve(1);
    printf("%d\n", ans);
    return 0;
}


猜你喜欢

转载自www.cnblogs.com/BriMon/p/9479963.html
今日推荐