题意:让你找三条链权值总和最大
题解:我们可以考虑一下在树上做dp用三维dp[u][i][j],u表示结点编号,i表示已经选取了i条链,j = 0表示没有选取该结点,j = 1表示选取了i条链里面包含u结点1条残链(即只有一头),j = 2表示选了i条链里面有一条包含i结点的完整链。
现在我们开始考虑合并:
如果不选择u结点那么我们就是选取所有儿子结点,链数总和为i的各种情况下最大值即可
如果选择u结点包含残链的话,那么就是我这个结点选取了不包含u结点的各种情况总和加上本身这个点再加上儿子各种情况,最后链条总和为i个情况,或者我们这个结点残链已经选好加上儿子的各种情况
如果选择u结点包含完整链的话,那么就是我本身这个完整链已经选好加上儿子结点的各种情况,或者我这里选了一条残链加上儿子选择了一条残链的情况
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#include<cstdio>
#include<cmath>
#include<set>
#include<map>
#include<cstdlib>
#include<ctime>
#include<assert.h>
#include<stack>
#include<bitset>
using namespace std;
#define mes(a,b) memset(a,b,sizeof(a))
#define rep(i,a,b) for(int i = a; i <= b; i++)
#define dec(i,a,b) for(int i = b; i >= a; i--)
#define pb push_back
#define mk make_pair
#define fi first
#define se second
#define ls rt<<1
#define rs rt<<1|1
#define lson ls,L,mid
#define rson rs,mid+1,R
#define lowbit(x) x&(-x)
typedef double db;
typedef long long int ll;
typedef pair<int,int> pii;
typedef unsigned long long ull;
const ll inf = 0x3f3f3f3f;
const int mx = 4e5+5;
const int mod = 1e9+7;
const int x_move[] = {1,-1,0,0,1,1,-1,-1};
const int y_move[] = {0,0,1,-1,1,-1,1,-1};
int n,m;
ll a[mx];
ll dp[mx][4][3];
vector<int>g[mx];
void dfs(int u,int fa){
ll tmp[4][3];
mes(tmp,0);
for(int i = 1; i < 4; i++)
for(int j = 1; j < 3; j++)
tmp[i][j] = a[u];
for(auto v: g[u]){
if(v==fa) continue;
dfs(v,u);
for(int i = 3; i >= 1; i--)
for(int j = 0; j <= i; j++){
if(j){
tmp[i][2] = max(tmp[i][2],tmp[j][2]+dp[v][i-j][2]);
tmp[i][2] = max(tmp[i][2],tmp[j][2]+dp[v][i-j][1]);
tmp[i][2] = max(tmp[i][2],tmp[j][2]+dp[v][i-j][0]);
tmp[i][2] = max(tmp[i][2],tmp[j][1]+dp[v][i-j+1][1]);
}
if(j!=i)
tmp[i][2] = max(tmp[i][2],tmp[j+1][1]+dp[v][i-j][1]);
}
for(int i = 3; i >= 1; i--)
for(int j = 0; j < i; j++){
tmp[i][1] = max(tmp[i][1],dp[v][i-j][1]+tmp[j][0]+a[u]);
if(j){
tmp[i][1] = max(tmp[i][1],dp[v][i-j][1]+tmp[j][1]);
tmp[i][1] = max(tmp[i][1],dp[v][i-j][0]+tmp[j][1]);
tmp[i][1] = max(tmp[i][1],dp[v][i-j][2]+tmp[j][1]);
}
}
for(int i = 3; i >= 1; i--)
for(int j = 0; j < i; j++){
tmp[i][0] = max(tmp[i][0],tmp[j][0]+dp[v][i-j][0]);
tmp[i][0] = max(tmp[i][0],tmp[j][0]+dp[v][i-j][1]);
tmp[i][0] = max(tmp[i][0],tmp[j][0]+dp[v][i-j][2]);
}
}
for(int i = 1; i < 4; i++)
for(int j = 0; j < 3; j++){
dp[u][i][j] = tmp[i][j];
}
}
int main(){
//freopen("test.in","r",stdin);
//freopen("test.out","w",stdout);
int t,q,ca = 1;
scanf("%d",&n);
for(int i = 1; i <= n; i++)
scanf("%lld",&a[i]);
for(int i = 2; i <= n; i++){
int u,v;
scanf("%d%d",&u,&v);
g[u].pb(v);
g[v].pb(u);
}
dfs(1,1);
ll ans = 0;
for(int i = 0; i < 4; i++)
for(int j = 0; j < 3; j++)
ans = max(ans,dp[1][i][j]);
printf("%lld\n",ans);
return 0;
}