【树形DP && 求树上每个点能到达的最远距离】HDU - 2196 Computer

Step1 Problem:

给你一棵n个节点的树,求每个点的最远距离。

Step2 Involving algorithms:

树形DP

Step3 Ideas:

求出一个点,子树方向的最远距离 和 父亲方向的最远距离,取最大值就是该点的最远距离。
子树方向的最远距离,一遍dfs回溯更新就求出来了。
核心是 父亲方向的最远距离
状态 dist[u][0]:代表 u 这个点 子树方向最远距离
状态 dist[u][1]:代表 u 这个点 子树方向次远距离
状态 dist[u][2]:代表 u 这个点 父亲方向最远距离
设 v 是 u 的孩子 dis(u,v)是 u 和 v 之间距离
如果 v 最远距离在 u 的父亲方向
dist[v][2] = dist[u][2] + dis(u, v);
否则 v 最远距离在 u 的子树方向
{
如果 v 在 子树方向最远距离路径上
dist[v][2] = dist[u][1] + dis(u, v);
如果 v 不在 子树方向最远距离路径上
dist[v][2] = dist[u][0] + dis(u, v);
}

Step4 Code:

#include<bits/stdc++.h>
using namespace std;
const int N = 1e4+100;
struct node
{
    int to, w, next;
};
node Map[2*N];
int head[N], cnt;
int dist[N][3], pot[N];
void dfs1(int u, int f)
{
    dist[u][0] = dist[u][1] = dist[u][2] = 0;
    for(int i = head[u]; ~i; i = Map[i].next)
    {
        int to = Map[i].to, w = Map[i].w;
        if(to != f)
        {
            dfs1(to, u);
            if(dist[to][0]+w >= dist[u][0])
            {
                pot[u] = to;
                dist[u][1] = dist[u][0];
                dist[u][0] = dist[to][0]+w;
            }
            else if(dist[to][0]+w > dist[u][1])
                dist[u][1] = dist[to][0]+w;
        }
    }
}
void dfs2(int u, int f)
{
    for(int i = head[u]; ~i; i = Map[i].next)
    {
        int to = Map[i].to, w = Map[i].w;
        if(to != f)
        {
            if(pot[u] == to)
                dist[to][2] = w + max(dist[u][1], dist[u][2]);
            else dist[to][2] = w + max(dist[u][0], dist[u][2]);
            dfs2(to, u);
        }
    }
}
void add(int u, int v, int w)
{
    Map[cnt] = (node){v, w, head[u]};
    head[u] = cnt++;
    Map[cnt] = (node){u, w, head[v]};
    head[v] = cnt++;
}
int main()
{
    int n, v, w;
    while(~scanf("%d", &n))
    {
        memset(head, -1, sizeof(head));
        cnt = 0;
        for(int i = 2; i <= n; i++)
        {
            scanf("%d %d", &v, &w);
            add(i, v, w);
        }
        dfs1(1, -1);
        dfs2(1, -1);
        for(int i = 1; i <= n; i++)
            printf("%d\n", max(dist[i][0], dist[i][2]));
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/bbbbswbq/article/details/80047027