BZOJ4033 [HAOI2015]树上染色 [树形DP]

BZOJ4033 [HAOI2015]树上染色 [树形DP]

Description

有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。

问收益最大值是多少。

Input

第一行两个整数N,K。

接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。

输入保证所有点之间是联通的。

N<=2000,0<=K<=N

Output

输出一个正整数,表示收益的最大值。

题解

树形DP:设 f [ i ] [ j ] 表示以 i 为根的子树中有 j 个黑点。有坑点,详见代码注释

一条链的收益和,容易想到的是暴力,就是把每条链找出来,然后加入Ans。但这道题光点就有2000,所以要考虑更优秀的算法,即不能从找链入手

可以考虑从线段被使用了多少次,即线段对答案的总贡献入手,这样就不需要找出链具体是哪些。而现在已知以 i 为根的子树内有 j 个黑点,那么这棵子树以外的黑点数也是已知的,为 k j 。那么对于边 e ,它对答案的贡献是:

L e n [ e ] + L e n [ e ]

状态……看代码吧……

经验

  • 分解步骤。一条链对答案的贡献可以看做每一个线段对答案的贡献;一个数对答案的贡献可以看做每一个二进制数数位对答案的贡献。
  • 减少枚举上限。防止时间复杂度退化。
  • 子树两两合并。

代码

#include<cstdio>
#include<iostream>
#define NN 2100
#define ll long long
using namespace std;
ll N,K;
ll Size[NN],f[NN][NN];
ll End[NN<<1],Last[NN<<1],Next[NN<<1],Len[NN<<1],cnt;
void DFS(ll u,ll fa){
    Size[u]=1;
    for(ll i=Last[u];i;i=Next[i]){
        ll v=End[i];
        if(v==fa)continue;
        DFS(v,u);
        //注意此处不要从n开始枚举,否则时间复杂度会充O(n^2)退化成O(n^3)
        for(ll j=min(Size[u],K);j>=0;j--)
            //j表示的是以u为根的子树中有多少黑点(注意:不包含以v为根的子树!)
            //而且目前的统计还不包含没讨论过的子树,即v和v以前的子树(这一点从Size的更新就可以看出)

            //一定要倒着枚举,因为这样才不会在这一轮更新中,让先更新的答案影响到后更新的答案
            //如果非要正着枚举,也可以开一个临时数组
            //总的来说,就是要防止重复更新
            //这是因为f数组的特性:它在更新完以v为根的这棵子树前,答案都是v以后的子树贡献的
            //因此要把以v为根的子树更新完之后才能最终更新f数组的答案(或者在不影响其它答案的情况下更新f数组)
            for(ll k=min(Size[v],K);k>=0;k--){
                //k表示的是以v为根的子树中有多少黑点
                if(j+k>K)continue;
                ll t=Len[i]*(k*(K-k)+(Size[v]-k)*(N-K-Size[v]+k));
                //一条边对答案的贡献(详见“题解”)
                f[u][j+k]=max(f[u][j+k],f[u][j]+f[v][k]+t);
            }
        Size[u]+=Size[v];
    }
}
void Ins(ll x,ll y,ll w){
    End[++cnt]=y,Len[cnt]=w;
    Next[cnt]=Last[x],Last[x]=cnt;
}
int main(){
    scanf("%lld%lld",&N,&K);
    for(ll i=1;i<N;i++){
        ll u,v,w;scanf("%lld%lld%lld",&u,&v,&w);
        Ins(u,v,w),Ins(v,u,w);
    }
    DFS(1,0);
    printf("%lld",f[1][K]);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/ArliaStark/article/details/81321820