洛谷3117 BZOJ4033 树上染色 树形背包

题目链接
题意:
给你一棵n个点的数,边有边权,要在其中选k个点染成黑色,其余点染成白色,求所有相同颜色点之间的路径的权值之和。

题解:
这个题不难想到要树形dp,状态设计也还好,但是要如何设计dp含义、如何统计答案是有难度的。我们不难想到dp有一维应该设计成以x为根的子树的情况,另一维设计成子树内染了i个黑色点,但是如果dp数组的含义设为子树内的权值之和的话似乎好像很难向父节点转移,因为多了一些新的点之后这些新的点也会与原来子树内的点形成路径,对答案产生影响。
所以我们应该设dp数组的含义是以x为根的子树里选了k个点染成黑色的情况下子树的路径对整个答案的贡献。由于我们可以处理出每个子树的大小,所以我们确定了子树黑点个数的情况下就可以相减得到白点个数。
这样我们考虑每条路径对最终答案的贡献,一条路径会被经过的次数可以看作从这条边把树分成两半,那么经过的次数就是左侧白点的数量乘右侧白点的数量加上左侧黑点的数量乘右侧黑点的数量(乘法原理),贡献的话就是次数乘上权值就好了。然后我们考虑如何给子树分配这k个黑点,做法就是树形背包。
值得注意的是这道题的复杂度,这个复杂度看起来是 O ( n 3 ) 的,但是仔细考虑我们是枚举了一棵树的所有子树进行背包,每次背包序列长度是右子树大小影响的,这个可以转化为一棵树的点对个数,因为任意两点只会在他们的LCA处对复杂度产生 O ( 1 ) 的影响。做背包时的每一个i,j可以对应为子树的两个点,所以是在两个点的LCA处产生 O ( 1 ) 的复杂度贡献,一共 n 2 级别的点对很显然吧(每个点都与除了自己外的所有点形成点对, n ( n 1 ) )。
这题代码在BZOJ上T了,但是在洛谷上跑得是很快的,不知道为什么,弄到BZOJ测评数据在本机测也是非常快的。
代码:

#include <bits/stdc++.h>
using namespace std;

int n,hed[2010],cnt,k,sz[2010];
long long dp[2010][2010];
struct node
{
    int to,dis,next;
}a[20010];
inline int read()
{
    int x=0;
    char s=getchar();
    while(s>'9'||s<'0')
    s=getchar();
    while(s>='0'&&s<='9')
    {
        x=(x<<1)+(x<<3)+s-'0';
        s=getchar();
    }
    return x;
}
inline void add(int from,int to,int dis)
{
    a[++cnt].to=to;
    a[cnt].dis=dis;
    a[cnt].next=hed[from];
    hed[from]=cnt;
}
void dfs(int x,int fa)
{
    dp[x][0]=0;
    dp[x][1]=0;
    sz[x]=1;
    for(register int i=hed[x];i;i=a[i].next)
    {
        register int y=a[i].to;
        if(y!=fa)
        {
            dfs(y,x);
            sz[x]+=sz[y];
            for(register int j=min(k,sz[x]);j>=0;--j)
            {
                for(register int p=0;p<=min(k,sz[y]);++p)
                {
                    if(dp[x][j-p]>=0)
                    {
                        register long long ji=(long long)a[i].dis*(k-p)*p+(long long)a[i].dis*(n-k-(sz[y]-p))*(sz[y]-p);
                        dp[x][j]=max(dp[x][j],dp[x][j-p]+dp[y][p]+ji);
                    }
                }
            }
        }
    }

}
int main()
{
    n=read();
    k=read();
    for(register int i=1;i<=n-1;++i)
    {
        register int x,y,z;
        x=read();
        y=read();
        z=read();
        add(x,y,z);
        add(y,x,z);
    }
    memset(dp,-0x3f,sizeof(dp));
    dfs(1,0);
    printf("%lld\n",dp[1][k]);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/forever_shi/article/details/81292753