[HAOI2015]树上染色 树形背包

版权声明:https://blog.csdn.net/huashuimu2003 https://blog.csdn.net/huashuimu2003/article/details/88934726

title

LUOGU 3177
题目描述

有一棵点数为 N 的树,树边有边权。给你一个在 0~ N 之内的正整数 K ,你要在这棵树中选择 K个点,将其染成黑色,并将其他 的N-K个点染成白色 。 将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。

输入输出格式

输入格式:
第一行包含两个整数 N, K 。接下来 N-1 行每行三个正整数 fr, to, dis , 表示该树中存在一条长度为 dis 的边 (fr, to) 。输入保证所有点之间是联通的。

输出格式:

输出一个正整数,表示收益的最大值。

输入输出样例
输入样例#1:

3 1
1 2 1
1 3 2

输出样例#1:

3

说明

对于 100% 的数据, 0<=K<=N <=2000

analysis

应该很容易想到,dp是可做的

状态很容易想到, f [ x ] [ i ] f[x][i] 表示以 x x 为跟的子树中,选择i个黑节点,的最大值

然后我就不会做了, 去网上看了wmdcstdio神犇的题解

发现我这个状态定义是错误的,正确的状态应该是, f [ x ] [ i ] f[x][i] 表示以 x x 为根的子树中,选择i个黑节点,对答案有多少贡献

为什么是说“对答案有多少贡献呢”?

主要是想到一点,即分别考虑每条边对答案的贡献

即,边一侧的黑节点数 * 另一侧的黑节点数 * 边权+一侧的白节点数 * 另一侧的白节点数 * 边权

这点很容易证明,但是不容易想到(原因是我太弱了)

然后情况就明了了,整个问题成了一个树形背包,考虑每个子节点分配多少个黑色节点(体积),然后算出这条边对答案的贡献(价值)

这里再一次强调“贡献”,是因为这个贡献不只是在当前子树内,而是对于整棵树来说的

转移方程为 f [ x ] [ i ] = m a x ( f [ x ] [ i ] , f [ x ] [ i j ] + f [ y ] [ j ] + v a l ) f[x][i] = max( f[x][i], f[x][i-j] + f[y][j] + val )

其中 y y x x 的子节点, j j 为在这个子节点中选择的黑色点的个数, v a l val 为这条边的贡献

v a l = j ( k j ) z + ( s i z [ y ] j ) ( n k + j s i z [ y ] ) z val = j*(k-j)*z + (siz[y]-j)*(n-k+j-siz[y])*z

其中 z z 为这条边的边权, n n 为总的节点数, k k 为总的需要选择的黑色节点数, s i z [ y ] siz[y] 为以 y y 为根的子树的节点数量——转自mlystdcall

code

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=2010;
template<typename T>inline void read(T &x)
{
	x=0;
	T f=1, ch=getchar();
	while (!isdigit(ch) && ch^'-') ch=getchar();
	if (ch=='-') f=-1, ch=getchar();
	while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48), ch=getchar();
	x*=f;
}
int ver[maxn<<1],edge[maxn<<1],Next[maxn<<1],head[maxn],len;
inline void add(int x,int y,int z)
{
	ver[++len]=y,edge[len]=z,Next[len]=head[x],head[x]=len;
}
ll f[maxn][maxn];
int siz[maxn],n,k;
inline void dfs(int x,int fa)
{
	siz[x]=1;
	memset(f[x],-1,sizeof(f[x]));
	f[x][0]=f[x][1]=0;
	for (int i=head[x]; i; i=Next[i])
	{
		int y=ver[i];
		if (y==fa) continue;
		dfs(y,x);
		siz[x]+=siz[y];
	}
	for (int e=head[x]; e; e=Next[e])
	{
		int y=ver[e],z=edge[e];
		if (y==fa) continue;
		for (int i=min(siz[x],k); i>=0; --i)
			for (int j=0; j<=min(siz[y],i); ++j)
				if (~f[x][i-j])
				{
					ll val=(ll)j*(k-j)*z+(ll)(siz[y]-j)*(n-k+j-siz[y])*z;
					f[x][i]=max(f[x][i],f[x][i-j]+f[y][j]+val);
				}
	}
}
int main()
{
	read(n);read(k);
	for (int i=1; i<n; ++i)
	{
		int x,y,z;
		read(x);read(y);read(z);
		add(x,y,z);add(y,x,z);
	}
	dfs(1,0);
	printf("%lld\n",f[1][k]);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/huashuimu2003/article/details/88934726