楼教主 男人八题之一 Poj 1741 Tree

link
树上距离小于等于k点对对数
点分治模板
首先考虑经过某一点zx的对数,那么我们可以通过dfs处理出所有点到zx的距离后求dis [a]+dis[b]<=k的个数(排序后双指针),但是这样算的时候还有可能两点没过zx,所以要容斥一下(减去所有以儿子节点为根的数量即为不经过zx的数量)。
如果随便枚举zx的话,最坏复杂度会达到O(n^2)
所以我们让每次枚举的点为重心,复杂度为O(nlog n) (每次都会减少大于s/2个节点)

#include<bits/stdc++.h>
#define ls rt<<1
#define rs rt<<1|1
#define fi first
#define se second
#define pb push_back
using namespace std;
typedef long long ll;
typedef vector<int> VI;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
const  ll inf  = 0x3f3f3f3f3f3f3f3f;
const int mod  = 998244353;
const int maxn = 1e6 + 4;
const int N    = 5e3 + 5;
const double eps = 1e-6;
const double pi = acos(-1.0);
ll qpow(ll x,ll y,ll mod){
    
    ll ans=1;x%=mod;while(y){
    
     if(y&1) ans=ans*x%mod; x=x*x%mod; y>>=1;}return ans;}
//#define LOCAL
struct node{
    
    int v,w; };
int n,k,ans,zx,mn,tot,vis[maxn],sz[maxn],dis[maxn];
vector<node>G[maxn];
void dfs1(int u,int fa) {
    
    
    sz[u]=1;
    for(auto v:G[u]) if(v.v^fa&&!vis[v.v]) {
    
    
        dfs1(v.v,u);sz[u]+=sz[v.v];
    }
}
void dfs2(int u,int fa,int nn) {
    
    
    int mx=nn-sz[u];
    for(auto v:G[u]) if(v.v^fa&&!vis[v.v]) {
    
    
        dfs2(v.v,u,nn);mx=max(mx,sz[v.v]);
    }
    if(mx<mn) mn=mx,zx=u;
}
void dfs3(int u,int fa,int d) {
    
    
    dis[++tot]=d;
    for(auto v:G[u]) if(v.v^fa&&!vis[v.v]) {
    
    
        dfs3(v.v,u,d+v.w);
    }
}
int kk(int u,int fa,int d) {
    
    
    tot=0;dfs3(u,fa,d);
    sort(dis+1,dis+tot+1);
    int ans=0,i=1,j=tot;
    while(i<j) {
    
    
        if(dis[i]+dis[j]<=k) ans+=j-i,i++;
        else j--;
    }
    return ans;
}
void dfs(int u ,int fa) {
    
    
    mn=INT_MAX;
    dfs1(u,fa);dfs2(u,fa,sz[u]);
    vis[zx]=1;ans+=kk(zx,0,0);
    int t=zx;
    for(auto v:G[t]) if(v.v^fa&&!vis[v.v]){
    
    
        ans-=kk(v.v,t,v.w);
        dfs(v.v,t);
    }
}
void init(int n) {
    
    
    for(int i=0;i<=n;i++) G[i].clear(),vis[i]=0;
}
int main() {
    
    
#ifdef LOCAL
    freopen("RACE input.in","r",stdin);
#endif
    while(scanf("%d%d",&n,&k)&&n) {
    
    
        init(n);
        for(int i=1;i<n;i++) {
    
    
            int u,v,w;scanf("%d%d%d",&u,&v,&w);
            G[u].pb({
    
    v,w});G[v].pb({
    
    u,w});
        }
        ans=0;
        dfs(1,0);
        printf("%d\n",ans);
    }
}

猜你喜欢

转载自blog.csdn.net/qq_43914084/article/details/106804937