题目
给你一棵n(n<=3000)个点的树,树上每个点对(u,v)对答案的贡献是mex(u,v),
mex就是sg函数里的那个mex
每个点对只被统计一次,令所有点对的贡献和最大,输出贡献和
思路来源
https://www.bilibili.com/video/av84326197?p=5
题解
现在看不懂的题解都得搜b站了 orz橙名玩家在线讲题
首先,只有一条链最后对答案有贡献,值域只有连续才有答案,在这条长为len链上填0 1 2 … len-1
所以,枚举这条链的两个端点,剩下的用记忆化搜索,复杂度O(n²)
对于一条链的情况,就类似区间dp,先确定一个0,再补1的时候带来的贡献就是链两端的sz
只考虑这个补增量的过程,dp[x][y]是可以从dp[x+1][y]或dp[x][y-1]转移而来的,
相当于补了一条x和x+1之间的边,或者y-1和y之间的边,
扫描二维码关注公众号,回复:
8871470 查看本文章
补了这个边之后,对答案增量带来的贡献是x及其左边的size乘y及其右边的size,
那么就定义三个东西,dp[x][y]代表x-y这条链的答案最大值,
sz[x][y]代表以x为根时y的子树的大小,par[x][y]代表以x为根时y的父亲
注意递归到x、y仅一条边相连时,dp[x][y]=sz[x][y]*sz[y][x]+max(sz[x][x],sz[y][y])
而符合题意的显然是dp[x][y]=sz[x][y]*sz[y][x],所以x==y时递归终点返回0即可
树上的情形,实际dp[x][y]是由dp[par[y][x]][y]或dp[x][par[x][y]]转移过来的,转移多添一条边答案贡献+=sz[y][x]*sz[x][y]
说是树形dp,其实用枚举点对,降成了枚举每一条链的线性dp
代码
#include<bits/stdc++.h>
using namespace std;
const int N=3e3+10;
typedef long long ll;
int n,u,v;
ll dp[N][N],ans;
int sz[N][N],par[N][N];
vector<int>e[N];
//dp[x][y]代表x-y链上的最大值
//sz[x][y]代表以x为根时y的子树大小
//fa[x][y]代表以x为根时y的父亲
void dfs(int u,int fa,int rt)
{
sz[rt][u]=1;
par[rt][u]=fa;
for(int i=0;i<e[u].size();++i)
{
int v=e[u][i];
if(v==fa)continue;
dfs(v,u,rt);
sz[rt][u]+=sz[rt][v];
}
}
ll solve(int x,int y)
{
if(x==y)return 0;
if(~dp[x][y])return dp[x][y];
dp[x][y]=sz[x][y]*sz[y][x]+max(solve(x,par[x][y]),solve(par[y][x],y));
return dp[x][y];
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;++i)
{
scanf("%d%d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
for(int i=1;i<=n;++i)
{
dfs(i,-1,i);
}
memset(dp,-1,sizeof dp);
for(int i=1;i<=n;++i)
{
for(int j=1;j<=n;++j)
{
ans=max(ans,solve(i,j));
}
}
printf("%lld\n",ans);
return 0;
}