牛客网暑期ACM多校训练营(第二场)-H(树形dp)

题意:让你找三条链权值总和最大

题解:我们可以考虑一下在树上做dp用三维dp[u][i][j],u表示结点编号,i表示已经选取了i条链,j = 0表示没有选取该结点,j = 1表示选取了i条链里面包含u结点1条残链(即只有一头),j = 2表示选了i条链里面有一条包含i结点的完整链。

现在我们开始考虑合并:

如果不选择u结点那么我们就是选取所有儿子结点,链数总和为i的各种情况下最大值即可

如果选择u结点包含残链的话,那么就是我这个结点选取了不包含u结点的各种情况总和加上本身这个点再加上儿子各种情况,最后链条总和为i个情况,或者我们这个结点残链已经选好加上儿子的各种情况

如果选择u结点包含完整链的话,那么就是我本身这个完整链已经选好加上儿子结点的各种情况,或者我这里选了一条残链加上儿子选择了一条残链的情况

#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#include<cstdio>
#include<cmath>
#include<set>
#include<map>
#include<cstdlib>
#include<ctime>
#include<assert.h>
#include<stack>
#include<bitset>
using namespace std;
#define mes(a,b) memset(a,b,sizeof(a))
#define rep(i,a,b) for(int i = a; i <= b; i++)
#define dec(i,a,b) for(int i = b; i >= a; i--)
#define pb push_back
#define mk make_pair
#define fi first
#define se second
#define ls rt<<1
#define rs rt<<1|1
#define lson ls,L,mid
#define rson rs,mid+1,R
#define lowbit(x) x&(-x)
typedef double db;
typedef long long int ll;
typedef pair<int,int> pii;
typedef unsigned long long ull;
const ll inf = 0x3f3f3f3f;
const int mx = 4e5+5;
const int mod = 1e9+7;
const int x_move[] = {1,-1,0,0,1,1,-1,-1};
const int y_move[] = {0,0,1,-1,1,-1,1,-1};
int n,m;
ll a[mx];
ll dp[mx][4][3];
vector<int>g[mx];
void dfs(int u,int fa){
	ll tmp[4][3];
	mes(tmp,0);
	for(int i = 1; i < 4; i++)
		for(int j = 1; j < 3; j++)
			tmp[i][j] = a[u];
	for(auto v: g[u]){
		if(v==fa) continue;
		dfs(v,u);
		for(int i = 3; i >= 1; i--)
			for(int j = 0; j <= i; j++){
				if(j){
					tmp[i][2] = max(tmp[i][2],tmp[j][2]+dp[v][i-j][2]);
					tmp[i][2] = max(tmp[i][2],tmp[j][2]+dp[v][i-j][1]);
					tmp[i][2] = max(tmp[i][2],tmp[j][2]+dp[v][i-j][0]);
					tmp[i][2] = max(tmp[i][2],tmp[j][1]+dp[v][i-j+1][1]);
				}
				if(j!=i)
					tmp[i][2] = max(tmp[i][2],tmp[j+1][1]+dp[v][i-j][1]);

			}
		for(int i = 3; i >= 1; i--)
			for(int j = 0; j < i; j++){
				tmp[i][1] = max(tmp[i][1],dp[v][i-j][1]+tmp[j][0]+a[u]);
				if(j){
					tmp[i][1] = max(tmp[i][1],dp[v][i-j][1]+tmp[j][1]);
					tmp[i][1] = max(tmp[i][1],dp[v][i-j][0]+tmp[j][1]);
					tmp[i][1] = max(tmp[i][1],dp[v][i-j][2]+tmp[j][1]);
				}
			}
		for(int i = 3; i >= 1; i--)
			for(int j = 0; j < i; j++){
				tmp[i][0] = max(tmp[i][0],tmp[j][0]+dp[v][i-j][0]);
				tmp[i][0] = max(tmp[i][0],tmp[j][0]+dp[v][i-j][1]);
				tmp[i][0] = max(tmp[i][0],tmp[j][0]+dp[v][i-j][2]);
			}

	}
	for(int i = 1; i < 4; i++)
		for(int j = 0; j < 3; j++){
			dp[u][i][j] = tmp[i][j];
		}
}
int main(){
	//freopen("test.in","r",stdin);
	//freopen("test.out","w",stdout);
	int t,q,ca = 1;
	scanf("%d",&n);
	for(int i = 1; i <= n; i++)
		scanf("%lld",&a[i]);
	for(int i = 2; i <= n; i++){
		int u,v;
		scanf("%d%d",&u,&v);
		g[u].pb(v);
		g[v].pb(u);
	}
	dfs(1,1);
	ll ans = 0;
	for(int i = 0; i < 4; i++)
		for(int j = 0; j < 3; j++)
			ans = max(ans,dp[1][i][j]);
	printf("%lld\n",ans);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/a1325136367/article/details/81252181
今日推荐