Magic boy Bi Luo with his excited tree HDU - 5834(树形DP)

Bi Luo is a magic boy, he also has a migic tree, the tree has N nodes , in each node , there is a treasure, it’s value is V[i], and for each edge, there is a cost C[i], which means every time you pass the edge i , you need to pay C[i].

You may attention that every V[i] can be taken only once, but for some C[i] , you may cost severial times.

Now, Bi Luo define ans[i] as the most value can Bi Luo gets if Bi Luo starts at node i.

Bi Luo is also an excited boy, now he wants to know every ans[i], can you help him?
Input
First line is a positive integer T(T≤104) , represents there are T test cases.

Four each test:

The first line contain an integer N(N≤105).

The next line contains N integers V[i], which means the treasure’s value of node i(1≤V[i]≤104).

For the next N−1 lines, each contains three integers u,v,c , which means node u and node v are connected by an edge, it’s cost is c(1≤c≤104).

You can assume that the sum of N will not exceed 106.
Output
For the i-th test case , first output Case #i: in a single line , then output N lines , for the i-th line , output ans[i] in a single line.
Sample Input
1
5
4 1 7 7 7
1 2 6
1 3 1
2 4 8
3 5 2
Sample Output
Case #1:
15
10
14
9
15

题意:
经过一个点可以获得其价值,只能获得一次。经过边会消耗价值,且可以消耗多次。计算从任意一个节点出发的最大获得价值。

思路:
本题相当于是多个根的树形dp,很容易想到是需要换根。那么我们需要两次DP,一次自底向上,先规定某一个节点为根时候的情况。第二次自顶向下,此时有两个方向的信息,一个来自父节点,一个来自之前计算出的子树信息。

定义1为根节点,定义dp[0/1][u]代表以u出发遍历子树的最大获得,0代表返回u点,1代表不返回u点。同时定义一个最优子节点id[u] = v代表u进入哪一个子节点可以获得最大价值。

第一次递归计算出dp数组。
第二次递归则需要从父节点来的信息sum1,sum2,这次递归实际也是计算子节点的这两个数值,这里实际就是换根的操作。

sum1代表遍历父节点并回到u时的最大价值,sum1代表遍历父节点且不回到u的最大价值。
那么 a n s [ u ] = m a x ( d p [ 0 ] [ u ] + s u m 2 , d p [ 1 ] [ u ] + s u m 1 ) ans[u] = max(dp[0][u] + sum2,dp[1][u] + sum1)

但是需要注意的是,我们需要讨论当前的子节点是否是最优子节点。
如果不是的话那就直接换根,否则我们还需要计算出次大的最优子节点。

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int maxn = 1e5 + 7;

int head[maxn],nex[maxn * 2],to[maxn * 2],val[maxn * 2],tot;
int a[maxn],dp[2][maxn],id[maxn],ans[maxn];//dp[0][i]代表回到i的最大值,dp[1][i]代表不回到i的最大值

void add(int x,int y,int z)
{
    to[++tot] = y;
    nex[tot] = head[x];
    val[tot] = z;
    head[x] = tot;
}

void DP1(int u,int fa)
{
    id[u] = -1;
    for(int i = head[u];i;i = nex[i])
    {
        int v = to[i],w = val[i];
        if(v == fa)continue;
        DP1(v,u);
        int t1 = max(0,dp[0][v] - 2 * w);
        int t2 = dp[0][u] + max(0,dp[1][v] - w);
        dp[1][u] += t1;
        if(dp[1][u] < t2)
        {
            dp[1][u] = t2;
            id[u] = v;
        }
        dp[0][u] += t1;
    }
}

//sum1代表经过父节点回到u的最大值,sum2代表经过父节点后不回来的最大值
//当然要注意sum1和sum2都要不小于0,如果小于0那就不走父节点了
void DP2(int u,int fa,int sum1,int sum2)
{
    ans[u] = max(dp[0][u] + sum2,dp[1][u] + sum1);
    int D1 = dp[0][u];
    int D2 = dp[1][u];
    int ID = id[u];
    D2 += sum1;
    if(D2 <= D1 + sum2)
    {
        ID = fa;
        D2 = D1 + sum2;
    }
    D1 += sum1;

    for(int i = head[u];i;i = nex[i])
    {
        int v = to[i],w = val[i];
        if(v == fa)continue;
        if(v == ID)//之所以要判断这个,是因为sum2的时候我们要知道从父节点出去后走到哪里了
        {
            int t1 = sum1 + a[u],t2 = sum2 + a[u];
            for(int j = head[u];j;j = nex[j])
            {
                int vv = to[j],ww = val[j];
                if(vv == fa || vv == v)continue;
                int tmp = max(0,dp[0][vv] - ww * 2);
                int ttmp = t1 + max(0,dp[1][vv] - ww);//相当于寻找u的次大dp[1][u]
                t2 += tmp;
                t2 = max(ttmp,t2);
                t1 += tmp;
            }
            t1 = max(0,t1 - 2 * w);
            t2 = max(0,t2 - w);
            DP2(v,u,t1,t2);
        }
        else
        {
            int tmp = max(0,dp[0][v] - 2 * w);
            int t1 = max(0,D1 - tmp - 2 * w);
            int t2 = max(0,D2 - tmp - w);
            DP2(v,u,t1,t2);
        }
    }
}

int main()
{
    int T;scanf("%d",&T);
    int kase = 0;
    while(T--)
    {
        memset(head,0,sizeof(head));
        tot = 0;
        int n;scanf("%d",&n);
        for(int i = 1;i <= n;i++)
        {
            scanf("%d",&a[i]);dp[0][i] = dp[1][i] = a[i];
        }
        for(int i = 1;i < n;i++)
        {
            int x,y,z;scanf("%d%d%d",&x,&y,&z);
            add(x,y,z);add(y,x,z);
        }
        DP1(1,-1);
        DP2(1,-1,0,0);
        printf("Case #%d:\n",++kase);
        for(int i = 1;i <= n;i++)
        {
            printf("%d\n",ans[i]);
        }
    }
    return 0;
}

发布了756 篇原创文章 · 获赞 27 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/tomjobs/article/details/104616453