luogu题解 P1099 【树网的核】树的直径变式+数据结构维护

  • 题目链接:

https://www.luogu.org/problemnew/show/P1099

https://www.lydsy.com/JudgeOnline/problem.php?id=1999 (加强版)

  • 分析:

首先你需要\(O(N)\)求树的直径的前置技能,其实很简单,先随便找个根找到树上距离它最远的顶点\(S\),然后以\(S\)为根找树上距离\(S\)最远的顶点\(E\),\(S,E\)之间的路径就是树的直径

  1. 暴力枚举

    按照题目要求说的去做就好了,求出树的直径,然后根据贪心找长度小于等于s的路径搜一遍就好了,时间复杂度\(O(N^2)\) 代码难度小 可过NOIP数据

  2. 预处理扫描

    还是先求出树的直径,然后通过仔细分析偏心距怎么求,发现无非三种情况

    1.直径最左端顶点到路径最左端顶点距离

    2.路径上各点不经过直径上其他点到达的最大距离

    3.直径上最右端顶点到路径最右端顶点的距离

    我相信2理解不难,现在简单证明1情况(3同理)的正确性:假设路径最左端有一条向左延伸的路径其终点\(E'\)不是直径左端点而距离却更大,则违反直径定义,因为这样\(E'\)与路径左端点连接能形成一条更长的直径

    于是我们需要四遍DFS(也许你写的好的话并不需要这么多遍)

    前两遍就是求树的直径,用一个\(dmet[]\)将直径各点按顺序记录方便区间处理,\(diameter\)为直径长度

    第三遍求直径上各点不经过直径上其他点达到的最大距离,用\(d[]\)记录

    第四遍求直径左右段点到直径各点距离,用\(ld[],rd[]\)记录

    然后我们扫描长度小于等于s的路径(可能是一个点),找上述三种情况的最大值,找第二种情况最大值你可以用滑动窗口,ST Table,线段树在区间查找最大值,这里采用线段树

    时间复杂度:\(O(N\log N)\)

    可以过比较大的数据,这是加强版

    https://www.lydsy.com/JudgeOnline/problem.php?id=1999

  • 代码

写得比较丑,见谅

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cctype>
#include <cstring>
#include <queue>
#define ll long long 
#define ri register int 
using namespace std;
const int maxn=500005;
const int maxm=1000005;
const int inf=0x7fffffff;
template <class T>inline void read(T &x){
    x=0;int ne=0;char c;
    while(!isdigit(c=getchar()))ne=c=='-';
    x=c-48;
    while(isdigit(c=getchar()))x=(x<<3)+(x<<1)+c-48;
    x=ne?-x:x;
    return ;
}
struct Edge{
    int ne,to,dis;
}edge[maxm];
int h[maxn],num_edge=0,mx=-inf;
bool vis[maxn];
void add(int f,int to,int dis){
    edge[++num_edge].ne=h[f];
    edge[num_edge].to=to;
    edge[num_edge].dis=dis;
    h[f]=num_edge;
}
int n,s,head,tail;
void dfs_1(int fa,int cur,int cnt){
    for(ri i=h[cur];i;i=edge[i].ne){
        if(edge[i].to!=fa)dfs_1(cur,edge[i].to,cnt+edge[i].dis);
    }
    if(cnt>mx){
        mx=cnt,head=cur;
    }
    return ;
}
int pre[maxn],dmet[maxn],d[maxn],diameter;
void dfs_2(int fa,int cur,int cnt){
    for(ri i=h[cur];i;i=edge[i].ne){
        if(edge[i].to!=fa){
            pre[edge[i].to]=cur;
            dfs_2(cur,edge[i].to,cnt+edge[i].dis);
        }
    }
    if(cnt>mx){
        mx=cnt,tail=cur;
    }
    return ;
}
int root;
void dfs_3(int fa,int cur,int cnt){
    for(ri i=h[cur];i;i=edge[i].ne){
        int v=edge[i].to;
        if(!vis[v]&&v!=fa){
            dfs_3(cur,v,cnt+edge[i].dis);
        }
    }
    if(cnt>mx){
        mx=cnt,d[root]=cnt;
    }
    return ;
}
int ld[maxn],rd[maxn],fun[maxn];
void dfs_4(int fa,int cur,int cnt){
    for(ri i=h[cur];i;i=edge[i].ne){
        int v=edge[i].to;
        if(vis[v]&&v!=fa){
            dfs_4(cur,v,cnt+edge[i].dis);
        }
    }
    ld[fun[cur]]=cnt,rd[fun[cur]]=diameter-cnt;
    return ;
}
int maxx[maxn<<2],L,R;
void build(int now,int l,int r){
    if(l==r){
        maxx[now]=d[dmet[l]];
        return;
    }
    int mid=(l+r)>>1;
    build(now<<1,l,mid);
    build(now<<1|1,mid+1,r);
    maxx[now]=max(maxx[now<<1],maxx[now<<1|1]);
    return ;
}
int query(int now,int l,int r){
    if(L<=l&&r<=R){
        return maxx[now];
    }
    int mid=(l+r)>>1,ans=-inf;
    if(L<=mid)ans=max(ans,query(now<<1,l,mid));
    if(mid<R)ans=max(ans,query(now<<1|1,mid+1,r));
    return ans;
}
int main(){
    int x,y,z;
    read(n),read(s);
    for(ri i=1;i<n;i++){
        read(x),read(y),read(z);
        add(x,y,z);
        add(y,x,z);
    }
    mx=-inf;
    dfs_1(0,1,0);
    mx=-inf;
    dfs_2(0,head,0);
    diameter=mx;  //两次DFS找直径 
/*------------------*/
    int tmp=tail,tot=0;
    while(tmp!=head){
        dmet[++tot]=tmp;
        fun[tmp]=tot;
        vis[tmp]=1;
        tmp=pre[tmp];
    }
    dmet[++tot]=head;
    fun[head]=tot;
    vis[head]=1;  //记录直径 
/*------------------*/
    for(ri i=1;i<=tot;i++){
        root=dmet[i];
        mx=-inf;
        dfs_3(0,dmet[i],0);
    }             //找从直径各点不经过其他直径上的点走过的最大长度 
    dfs_4(0,tail,0);//找直径两端点到直径各点距离 
    build(1,1,tot);//建线段树,不会ST Table的我 
/*-----------------*/
    int ans=inf,l=1,r=1;
    while(ld[r+1]-ld[l]<=s&&r!=tot)r++;
    //cout<<tot<<'*'<<endl;
    while(r<=tot){//cout<<l<<' '<<r<<endl;
        L=l,R=r;
        ans=min(ans,max(query(1,1,tot),max(ld[l],rd[r])));
        if(r==tot)break;
        l++;
        while(ld[r+1]-ld[l]<=s&&r!=tot)r++;
    }
    printf("%d\n",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Rye-Catcher/p/9248789.html
今日推荐