<hdu5834 Magic boy Bi Luo with his excited tree> (树形DP)

题意:一棵树有点权和边权 从每个点出发 走过一条边要花费边权同时可以获得点权 边走几次就算几次花费 点权最多算一次

   问每个点能获得的最大价值

题解:好吧 这才叫树形DP入门题

   dp[i][0]表示从i节点的儿子中走又回到i的最大值 dp[i][1]表示不回到i的最大值 dp[i][2]表示不回到i的次大值

   同时需要记录不回到i最大值的方向id[x]

   很显然 第一遍dfs可以预处理每个节点往下的值 然后关键的就是每个节点从父亲这个方向的值怎么处理

   有个很显然的结论就是 不回来是肯定比回来更优的 所以重点就是在处理不回来的这个支路在哪

   如果对于x节点其父亲的id[fa] = x 那么显然x,fa不回来的最大值是同一个支路 这个时候就可以更新两种答案

   在x下面的儿子中不回来 dp[x][1] = dp[x][1] += max(0, dp[fa][0] - cost[x][fa] * 2)

   在fa中的其他儿子中不回来就用到了次大 dp[x][1] = dp[x][0] + dp[fa][2] - cost[x][fa]

   如果id[fa] != x  dp[x][1] = dp[x][0] + dp[fa][1] - cost[x][fa]

   最后再更新dp[x][0] = dp[x][0] + max(0, dp[fa][0] - cost[x][fa] * 2) 同时转移的时候次大 以及最大的方向都要更新

   不过这里的dp值显然都是要减去重复计算的部分 具体代码见

#include <stdio.h>
#include <algorithm>
#include <iostream>
#include <string.h>
using namespace std;

int n, cnt;
int q[100005];
int dp[100005][3];
int du[100005];
int head[100005];
int id[100005];

struct node
{
    int no, to, nex, val;
}E[200005];

void dfs1(int x, int fa)
{
    dp[x][0] = q[x];
    dp[x][1] = q[x];
    dp[x][2] = q[x];
    int c = head[x];
    for(int i = c; i; i = E[i].nex)
    {
        int v = E[i].to;
        if(v == fa) continue;

        dfs1(v, x);
        if(E[i].val * 2 < dp[v][0]) dp[x][0] += dp[v][0] - E[i].val * 2;
    }

    for(int i = c; i; i = E[i].nex)
    {
        int v = E[i].to;
        if(v == fa) continue;

        int tmp = dp[x][0];
        if(E[i].val * 2 < dp[v][0]) tmp += E[i].val * 2 - dp[v][0];

        if(tmp + dp[v][1] - E[i].val >= dp[x][1])
        {
            dp[x][2] = dp[x][1];
            dp[x][1] = tmp + dp[v][1] - E[i].val;
            id[x] = v;
        }
        else if(tmp + dp[v][1] - E[i].val > dp[x][2]) dp[x][2] = tmp + dp[v][1] - E[i].val;

        dp[x][2] = max(dp[x][2], dp[x][0]);
    }
}

void dfs2(int x, int fa)
{
    int c = head[x];
    for(int i = c; i; i = E[i].nex)
    {
        int v = E[i].to;
        if(v != fa) continue;

        int tmp0 = dp[fa][0];
        int tmp1 = dp[fa][1];
        int tmp2 = dp[fa][2];
        if(E[i].val * 2 < dp[x][0])
        {
            tmp0 += E[i].val * 2 - dp[x][0];
            if(id[fa] != x) tmp1 += E[i].val * 2 - dp[x][0];
            else tmp2 += E[i].val * 2 - dp[x][0];
        }
        tmp0 = max(tmp0, 0); tmp1 = max(tmp1, 0); tmp2 = max(tmp2, 0);

        //dp[x][0] = max(dp[x][0], dp[x][0] - E[i].val * 2 + tmp0); 因为下面的转移用到了dp[x][0] 写在这里就不对
        if(tmp0 - E[i].val * 2 > 0)
        {
            dp[x][2] += tmp0 - E[i].val * 2;
            dp[x][1] += tmp0 - E[i].val * 2;
        }

        if(id[fa] == x)
        {
            if(dp[x][0] - E[i].val + tmp2 >= dp[x][1])
            {
                dp[x][2] = dp[x][1];
                dp[x][1] = dp[x][0] - E[i].val + tmp2;
                id[x] = fa;
            }
            else if(dp[x][0] - E[i].val + tmp2 > dp[x][2]) dp[x][2] = dp[x][0] - E[i].val + tmp2;
        }
        else
        {
            if(dp[x][0] - E[i].val + tmp1 >= dp[x][1])
            {
                dp[x][2] = dp[x][1];
                dp[x][1] = dp[x][0] - E[i].val + tmp1;
                id[x] = fa;
            }
            else if(dp[x][0] - E[i].val + tmp1 > dp[x][2]) dp[x][2] = dp[x][0] - E[i].val + tmp1;
        }
        dp[x][0] = max(dp[x][0], dp[x][0] - E[i].val * 2 + tmp0);
    }

    for(int i = c; i; i = E[i].nex)
    {
        int v = E[i].to;
        if(v == fa) continue;
        dfs2(v, x);
    }
}

int main()
{
    int T;
    scanf("%d", &T);
    int t = 0;

    while(T--)
    {
        t++;
        cnt = 0;
        scanf("%d", &n);
        memset(id, 0, sizeof(id));
        memset(head, 0, sizeof(head));
        memset(dp, 0, sizeof(dp));
        memset(du, 0, sizeof(du));
        for(int i = 1; i <= n; i++) scanf("%d", &q[i]);

        for(int i = 1; i < n; i++)
        {
            int u, v, w; scanf("%d%d%d", &u, &v, &w); du[u]++; du[v]++;
            E[++cnt].no = u, E[cnt].to = v, E[cnt].nex = head[u], head[u] = cnt, E[cnt].val = w;
            E[++cnt].no = v, E[cnt].to = u, E[cnt].nex = head[v], head[v] = cnt, E[cnt].val = w;
        }

        int rt;
        for(int i = 1; i <= n; i++)
            if(du[i] == 1)
            {
                rt = i;
                break;
            }

        dfs1(rt, -1);
        dfs2(rt, -1);
        printf("Case #%d:\n", t);
        for(int i = 1; i <= n; i++) printf("%d\n", max(dp[i][0], dp[i][1]));
    }
    return 0;
}
View Code

  

猜你喜欢

转载自www.cnblogs.com/lwqq3/p/9021086.html