牛客网多校2 travel(树形dp)

题目:给你一棵你个点的树,每个点有一个价值,选出三条不相交的链使得最后的总价值最大。

显然是树形dp,我定义了一个三维的dp[i][j][k](i为子树id,j为选择了几条链,k为选择的j条链中是否包含了经过i点的直链)(0<=j<=3,0<=k<=1)。

昨晚写这个题的时候一直就是过48%的数据,真是烦啊,今天发现定义一个全局变量的数组,里面递归时出错,真是无语了。改过来后发现是过了84%的数据,然后又陷入了怎么改都是84%的错误里。吃完饭后又Wa了几次发现掉了一个转移情况,我真的是太垃圾了。最后说一句,好题!

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int maxn=4e5+10;
int n,t;
ll val[maxn],dp[maxn][4][2];
vector<int>a[maxn];
void dfs(int u,int fa)
{
    ll tp[4][3]={0},tmp[4][3];//用tp来存已经被选进去的子树的最优值,tmp存两部分合并的最优值。tp数组不要定义在全局变量啊!!!
    for(int i=0;i<a[u].size();i++)
    {
        int v=a[u][i];
        if(v==fa) continue;
        dfs(v,u);
        memset(tmp,0,sizeof tmp);

        for(int j=0;j<=3;j++)
        for(int k=0;k+j<=3;k++)
        {
            tmp[j+k][0]=max(tmp[j+k][0],tp[j][0]+dp[v][k][0]);//不包含直链的情况
        }

        for(int j=0;j<=3;j++)
        for(int k=0;k+j<=3;k++)//包含一条直链,看是tp还是v贡献的最优的
        {
            if(j>0) tmp[j+k][1]=max(tmp[j+k][1],tp[j][1]+dp[v][k][0]);
            if(k>0) tmp[j+k][1]=max(tmp[j+k][1],tp[j][0]+dp[v][k][1]);
        }

        for(int j=0;j<=3;j++)
        for(int k=0;k+j<=3;k++)//之前选进来的部分中根已经参与了构成一条完整链的情况
            tmp[j+k][2]=max(tmp[j+k][2],tp[j][2]+dp[v][k][0]);

        for(int j=1;j<=3;j++)
        for(int k=1;k<=3;k++)
        if(j+k<=4)
            tmp[j+k-1][2]=max(tmp[j+k-1][2],tp[j][1]+dp[v][k][1]+val[u]);//根参与构成新的完整的链,两部分都要贡献出一条直链

        for(int j=0;j<=3;j++)
            for(int k=0;k<=2;k++)
            tp[j][k]=tmp[j][k];
    }
    for(int i=1;i<=3;i++)
    {
        dp[u][i][0]=max(dp[u][i][0],tp[i][0]);//根没有参与构成完整链
        dp[u][i][0]=max(dp[u][i][0],tp[i][1]+val[u]);//直链当完整链来算来更新一次
        dp[u][i][0]=max(dp[u][i][0],tp[i][2]);//根参与构成完整链
    }//i条链中没有直链

    for(int j=1;j<=3;j++)
        dp[u][j][1]=tp[j][1]+val[u];//j条链中有直链
}
int main()
{
    ///freopen("in.txt","r",stdin);
    while(~scanf("%d",&n))
    {
        memset(dp,0,sizeof dp);
        for(int i=1;i<=n;i++)
        {
            scanf("%lld",&val[i]);
            a[i].clear();
        }
        for(int i=1;i<n;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            a[x].push_back(y);
            a[y].push_back(x);
        }
        dfs(1,-1);
        ll ans=0;
        for(int i=1;i<=3;i++)
        {
            ans=max(ans,dp[1][i][0]);
            ans=max(ans,dp[1][i][1]);
        }
        printf("%lld\n",ans);
    }
    return 0;
}
/*
13
10 10 10 10 10 1 10 10 10 1 10 10 10
1 2
2 3
3 4
4 5
2 6
6 7
7 8
7 9
6 10
10 11
11 12
11 13

4
10 10 10 10
1 2
1 3
1 4

3
1 2 0
1 2
2 3
*/

猜你喜欢

转载自blog.csdn.net/dllpxfire/article/details/81153615