牛客网暑期ACM多校训练营(第二场) H.travel (树形DP)

题目链接

时间限制:C/C++ 2秒,其他语言4秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld

题目描述

White Cloud has a tree with n nodes.The root is a node with number 1. Each node has a value.
White Rabbit wants to travel in the tree 3 times. In Each travel it will go through a path in the tree.
White Rabbit can't pass a node more than one time during the 3 travels. It wants to know the maximum sum value of all nodes it passes through.

输入描述:

The first line of input contains an integer n(3 <= n <= 400001)
In the next line there are n integers in range [0,1000000] denoting the value of each node.
For the next n-1 lines, each line contains two integers denoting the edge of this tree.

输出描述:

Print one integer denoting the answer.

示例1

输入

13
10 10 10 10 10 1 10 10 10 1 10 10 10
1 2
2 3
3 4
4 5
2 6
6 7
7 8
7 9
6 10
10 11
11 12
11 13

输出

110

题意:给出一个n,表示1~n个点,随后一行给出n个值,表示这n个点对应的点权值,再给出n-1条无向边,即n-1行u和v.求从这颗树上任意取3条节点不可共享的链所能获得的最大权值是多少?

题解:这题明显就是典型的树形DP题,不过需要用到些背包的贪心思想来实现取三条不想交且权值和最大的链.

        那么我们先定义一下:

w[ ];              //点权值
f[ ][ ];            //f[u][j]表示当前点u取j条链的最佳状态
g[ ][ ];           //g[u][j]表示当前点u取j条链的最佳状态,且再取多当前节点其上升链(具体指该节点所在的链还能向上延伸)

        之后我们只需要O(n)时间复杂度内完成对树的遍历获得f[起点][3]即可得到ans.(起点任意,我选择的是1号点)

        在每次dfs到当前节点u时,我们需要再定义三个数组:

pre_son[4][3] = { 0 };  //保存当前节点u已遍历过的子节点的情况,[i][j]表示子节点取i条链的情况,j=0,1,2分别表示状态ff,fg,gg
now_son[4][3];           //同上,不过保存的是当前子节点的情况
tmp[4][3];                   //保存中间变量(背包思想:当前节点u的背包情况)

        以上说的ff,fg和gg状态都是对于当前节点的子节点而言的,即取子节点的f数组数据和g数组数据来进行统计;

        而且以上三个数组的 j 都可以值得大小可以直接的当做选取当前节点u的子节点g链的条数;

        那么对于四个for循环部分可以这样理解其背包思想,对于当前节点的子节点直接没有直接联系,因此其内获取的链最优情况是可以直接相加的,即f[v1][1]+f[v2][1]=f[u][2](不过这里还不能同真正意义上的表示为f[u][2],因为还有当前节点的权值没有加)

       对于tmp的第二维是当前节点g链的条数,最多只能有2条,在加上当前节点后会得到如下图情况:(这是gg的情况变化,即2个g)

=>

 (需要注意的上面这2个图是更新 f 数组的其中一种情况)

那么下面就是如何更新当前节点的 f 数组和 g 数组了:

对于f的更新如下:

       我们单纯就只需要判断把当前节点加入其子树统计来的贪心结果进行更新就可以;

对于g的更新:

         子树更新0条g链合并到节点u上并形成当前节点的g链如下:

=>

         子树更新1条g链合并到节点u上并形成当前节点的g链如下:

=>

剩下的就要靠读者自己理解代码了,这里解释的应该足够详细了,至少说明白了 f 数组和 g 数组还有递归函数中三个数组的定义,虽然对 f 和 g 的合并只是简单的提点,之后只需要通过这些就很好理解代码了(相信你会变得更加强大~)

代码如下:

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#include<string>
#include<cstring>
#include<vector>
using namespace std;
#define ll long long
const int maxn = 4e5 + 10;
ll w[maxn];                   //点权值
ll f[maxn][4];                //f[u][j]表示当前点u取j条链的最佳状态
ll g[maxn][4];                //g[u][j]表示当前点u取j条链的最佳状态,且再取多当前节点其上升链(具体指该节点所在的链还能向上延伸)
vector<int>e[maxn];           //图
void dfs(int u, int fa) {
	ll pre_son[4][3] = { 0 }; //保存当前节点u已遍历过的子节点的情况,[i][j]表示子节点取i条链的情况,j=0,1,2分别表示状态ff,fg,gg
	ll now_son[4][3];         //同上,不过保存的是当前子节点的情况
	ll tmp[4][3];             //保存中间变量(背包思想:当前节点u的背包情况)
	for (int x = 0; x < e[u].size(); x++) {
		int v = e[u][x];
		if (v == fa) continue;
		dfs(v, u);

		memset(now_son, 0, sizeof(now_son));
		memset(tmp, 0, sizeof(tmp));

		//对于当前子节点v的结果保存到now_son中,其贡献只有ff和fg状态
		for (int i = 0; i <= 3; i++) {      
			now_son[i][0] = f[v][i];
			now_son[i][1] = g[v][i];
		}

		//对于儿子中存在链的条数可以直接相加,但总的条数小于等于3
		//对于儿子中取g得条数也可以直接相加,当总的条数小于等于2
		//使用背包的方式得到当前tmp数组的最优结果
		for (int i = 0; i <= 3; i++)        
			for (int j = 0; i + j <= 3; j++)
				for (int p = 0; p <= 2; p++)
					for (int q = 0; p + q <= 2; q++)
						tmp[i + j][p + q] = max(tmp[i + j][p + q], pre_son[i][p] + now_son[j][q]);
		memcpy(pre_son, tmp, sizeof(pre_son));
	}

	//先初始化当前点的f数组
	for (int i = 0; i <= 3; i++)                            
		f[u][i] = pre_son[i][0];

	//对于f数组的更新:
	//对于j=0时即是ff状态,加上当前节点,当前节点作为一条单独的链
	//j=1时即是fg状态,那么加上当前节点作为g链上一点
	//j=2时即是gg状态,那么加上当前节点作为gg链的最近公共祖先,从而形成一条链
	//剩下的就是本身在子节点那么背包对比来取0~3条链的最优情况,实现f数组的更新
	for (int i = 1; i <= 3; i++)             
		for (int j = 0; j <= 2; j++)                         
			f[u][i] = max(f[u][i], pre_son[i - 1][j] + w[u]);

	//对于g数组的更新:
	//因为需要更新当前节点为根节点的链条数加上一条上升链
	//那么结果只能是当前节点其儿子节点最优情况中取i=0~3条链的情况
	//且需保证其儿子只能是ff或fg状态,给其加上当前节点权值即可完成更新
	for (int i = 0; i <= 3; i++)
		for (int j = 0; j <= 1; j++)
			g[u][i] = max(g[u][i], pre_son[i][j] + w[u]);    
}
int main(){
	int n, u, v;
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) 
		scanf("%lld", &w[i]);
	for (int i = 1; i < n; i++) {
		scanf("%d%d", &u, &v);
		e[u].push_back(v);
		e[v].push_back(u);
	}
	dfs(1, -1);
	printf("%lld\n", f[1][3]);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_41156591/article/details/81200753
今日推荐