【题解】 poj 2486 Apple Tree(树形dp)

题目大意:

给你n个结点,最大步数k。接下来n个数字表示每个节点有多少个苹果,然后n-1行每行两个数,代表两个结点之间有边相连。读入多组数据,求在限定步数k内能吃到最多的苹果数。

对于这道题目,由于是个树形结构,又要求限定步数内的最大值,我们可以往dp方向联想,那正解就是树形dp了。我们用邻接表存图,并设back[n][k]代表从n结点出发,走k步所能获得的最大苹果数(回到n结点)。那么答案就是dp[1][k],为了求出dp数组,我们应当使用dfs。dfs(x,fa)代表从x结点出发,其父亲结点为fa。我们可以知道在k步过后最好让终止的结点不是最初的结点,也就是尽量让它留在子树里,所以设dp数组记录与back相同的内容,但它不回到n结点。由于不管走多少步根节点的苹果一定会被吃掉,所以初始化dp[x][k]=back[x][k]=apple[x]。枚举与父亲节点相邻的子节点,如果下一个节点是父亲结点就continue,不是则dfs下去。然后就把这个问题转化为分组背包,设x结点的一个子节点为now,走的步数为p,总步数为l,得到状态转移方程:back[x][l]=max(back[x][l],back[x][l-p-2]+back[now][p]);  注意,这里是l-p-2的原因是因为back数组最后会回到原点,去一次回来一次需要两步。重新枚举与父亲结点相邻的子节点,为了求dp数组我们需要设一个tmpback[k]的一维数组,代表去掉某个节点后,k步内可以走的最大步数。初始化tmpback[k]=apple[x]。我们想去掉y结点来求tmpback,所以再次枚举与x相邻的结点,设一个q结点与x相连且不为y,利用和求back类似的方法,求得tmpback状态转移方程为tmpback[l]=max(tmpback[l],tmpback[l-p-2]+back[q][p])。最后可以表示dp的状态转移方程dp[x][l]=max(dp[x][l],dp[now][l-p-1]+tmpback[p]) 注意这次的dp数组它不会返回父亲结点,所以只需要l-p-1就够了,最后得到正确的结果。



#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<iomanip>
using namespace std;
const int maxn=110;
const int maxk=210;
int head[maxn],nnext[maxn*2],to[maxn*2];
int dp[maxn][maxk];
int tmpback[maxk];
int back[maxn][maxk];
int tot;
int n,k;
int apple[maxn];
void add(int x,int y)
{
	tot++;
	nnext[tot]=head[x];
	head[x]=tot;
	to[tot]=y;
}
void dfs(int x,int fa)
{
	for(int i=0;i<=k;i++) dp[x][i]=back[x][i]=apple[x];
	for(int i=head[x];i;i=nnext[i])  //分组 
	{
		int now=to[i];
		if(now==fa) continue;
		
		dfs(now,x);
		
		for(int l=k;l>=0;l--)        
			for(int p=0;l-p-2>=0;p++)
				back[x][l]=max(back[x][l],back[x][l-p-2]+back[now][p]);
	}
	
	for(int i=head[x];i;i=nnext[i])
	{
		int now=to[i];
		if(now==fa) continue;
		for(int l=0;l<=k;l++) tmpback[l]=apple[x];
		
		for(int y=head[x];y;y=nnext[y])
		{
			int q=to[y];
			if(q==fa||q==now) continue;
			
			for(int l=k;l>=0;l--)        
				for(int p=0;l-p-2>=0;p++)
					tmpback[l]=max(tmpback[l],tmpback[l-p-2]+back[q][p]);
		} 
		
		for(int l=k;l>=0;l--)
			for(int p=0;l-p-1>=0;p++)
				dp[x][l]=max(dp[x][l],dp[now][l-p-1]+tmpback[p]);
	}
}
int main()
{
	while(scanf("%d%d",&n,&k)==2)
	{
		tot=0;
		memset(back,0,sizeof(back));
		memset(dp,0,sizeof(dp));
		memset(head,0,sizeof(tmpback));
		for(int i=1;i<=n;i++)
		{
			cin>>apple[i];
		}
		for(int i=1;i<=n-1;i++)
		{
			int x,y;
			cin>>x>>y;
			add(x,y);
			add(y,x);
		}
		dfs(1,0);
		cout<<dp[1][k]<<endl;
	}
	return 0;
}



猜你喜欢

转载自blog.csdn.net/Rem_Inory/article/details/81054662