题意:
给你一棵n个节点的树,每一个节点有一个权值,问你去掉至多k条边之后,
任意两个可以互相到达的点间的权值的差的最大值最小是多少。
解析:
这道题我一开始反着用贪心做,后来发现这道题根本不能从局部最优得到全局最优。
例如
4 1
20 11 9 0
1 2
2 3
3 4
这组样例k=1,k=2所删的边是完全不一样的。
直接用dp也不行,因为dp求答案的过程不满足树自底向上的性质。
求一条边能不能删,关乎以这条边两个端点为根的子树的差的最大值。
所以翻了题解,用二分答案+dp来做。
二分答案就是二分[0,max],max是原来树中最大值和最小值的差。
首先二分出以一个答案mid后,我们验证它就通过在差值最大为mid的条件下,
这棵树最少能被分成几个部分,即最少需要砍几刀使得每一个部分的最大值-最小值的差值(下面简称极差)都<=mid
对于这个问题,就是用dp来解的。这个dp我感觉又奇怪,但又很巧妙。
我自己也不是完全理解它的原理以及是如何想到的。
dp[x][i]表示在以x为根的子树(包括x)中,节点x所在部分的最小值是a[i]的条件下对于这棵子树最少需要砍几刀。
注意这个a[i]不一定是x的子树中的节点!
那么对于dp[x][i],如果a[x]-a[i]>mid||a[x]<a[i],那么这个就是非法的情况,就把dp[x][i]=INF,否则赋值为0
那么对于转移方程,对于一个节点x,我们只需要考虑他的儿子节点y的dp值就可以了。
定义 f[y]=min(dp[y][j]) j=1...n
如果dp[y][i]<f[y]+1,那么说明以y为根结点的子树,最优解就是把y节点归入最小值是a[i]的部分。
那么直接把x加入到这个部分就可以了,所以dp[x][i]+=dp[y][i]
否则,说明y的子树的最优解不是在y所在部分的最小值是a[i]的情况下,那么x和y就是两个部分的点,
对于不同的部分,我们就应该把这条边砍掉dp[x][i]+=f[y]+1
最后在mid条件下的最小砍的数量就是f[1]
最后总结一下,这道题的dp按我理解就是从终态出发,每一个点最后肯定是在一个部分里面的,并且这个部分也一定
是有最小值的。那么我们就用dp去枚举这个状态,然后这个dp用恰好满足树形的结构自底向上递推,
一条边的两个端点在同一部分就不用删,不在同一部分就删除。然后二分也挺巧妙地,二分答案,
用最多删除的边的数量k去验证答案。
最后我觉得奇怪的原因是,这个dp会把一些不存在的情况算出值
7 3
3 8 7 4 2 3 3
1 2
2 3
2 4
4 5
1 6
1 7
对于上面的情况最后dp[1][5]=2,但是我们画出图可以直到,1和5是根本不可能分到同一个部分的....
但是这些值好像又不会对正确答案产生影响.....我大概猜了一下,可能是切完2刀之后再把1的部分和5的部分
接起来,这样1部分的最小值就是a[5],但是改切的刀又是一定会切的并不会影响dp值,切出来合在一起也不会
违反极差<=mid的性质。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll INF = 0X3F3F3F3F3F3F3F3F;
const int MAXN = 1e3+100;
ll a[MAXN];
ll mim;
int dp[MAXN][MAXN];
int f[MAXN];
int n,k;
int mp[MAXN][MAXN];
void dfs(int x,int fa)
{
for(int i=1;i<=n;i++)
{
dp[x][i]=a[x]>=a[i]&&a[x]-a[i]<=mim?0:INF;
}
//dp[x][x]=0;
for(int i=1;i<=n;i++)
{
if(i==fa) continue;
if(mp[x][i])
{
int v=i;
dfs(v,x);
for(int j=1;j<=n;j++)
{
if(a[x]-a[j]<=mim&&a[x]>=a[j])
{
dp[x][j]+=min(dp[v][j],f[v]+1);
}
}
}
}
f[x]=INF;
for(int i=1;i<=n;i++) f[x]=min(f[x],dp[x][i]);
}
bool check(ll x)
{
mim=x;
dfs(1,0);
if(f[1]>k) return false;
return true;
}
int main()
{
scanf("%d%d",&n,&k);
ll mx,mi;
mx=-INF;
mi=INF;
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
mx=max(a[i],mx);
mi=min(a[i],mi);
}
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
mp[u][v]=mp[v][u]=1;
}
ll l=0;
ll r=mx-mi;
ll ans=INF;
while(l<r)
{
ll mid=(l+r)>>1;
if(check(mid)) ans=mid,r=mid;
else l=mid+1;
}
if(check(l)) ans=l;
printf("%lld\n",ans);
}