题意:给你n个点,n-1条边的树。每条边有一个权值w。给你一个值p。
1号节点为根节点。求1号点到所有节点的路径中 的 最小权值 的最大值。
权值计算:相当于把这条路径划分成若干段,每一段的权值为这一段的所有边的权值之和的平方。
每一段段尾如果不是目标城市,则需要支付p的费用。
思路(参考大佬博客):容易看出来是一个树形dp,并且有一个非常显然的状态转移方程:
,其中v是树上从u到根节点路径上的点。
但是显然这样的时间复杂度在树退化成链的时候会达到,需要想办法来进行优化。尝试进行变形:
如果状态v和w都可以转移到状态u,那么在这种情况下,从状态v转移会更优:
我们发现上式变成了一个斜率的形式。考虑将(dis[i], f[i])的点绘制出来,如果出现了下面的情况:
,那么通过枚举各种情况,我们可以分析出来 j 处必不可能是较优的点。
也就是说,有可能作为最优解进行转移的状态,它们的点必然是在一个下凸壳上的。
每次在得到一个新的状态的时候,由于dis[]的单调性,它的位置必然是在这个半凸壳的右端处。由于dis[]的单调性,f[]也是满足单调递增,这样就可以用一个单调队列来维护半凸壳上的点。
对于每个新的状态u,具体的维护方法为:
1. 检查队头的两个元素q[l]和q[l+1],通过上面的斜率检查,如果q[l+1]比q[l]更优,那么就把q[l]出队。
2. 直接取队头的元素为目标状态,进行状态转移,计算出f[u]。
3. 将u插入队尾。插入之前需要检查三个状态q[r-1], q[r], u是否满足斜率单调递增,若不满足则将q[r]出队。
这样就将整个DP的时间复杂度优化到了。
需要注意的是,由于每个节点可能有多个子节点,因此每次转移之后要将队尾恢复为原来的元素。
代码:
#include<bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f3f3f3f3fLL
using namespace std;
const int maxn=200010;
int n,m,k,x,y,s;
ll ans,tmp,cnt,p,aa;
ll zt[maxn],l,r;
struct node
{
int to,nex;
ll w;
}a[maxn];
int he[maxn],tot,q[maxn];
ll dp[maxn],dis[maxn];
void add(int u,int v,ll w)
{
a[tot].to=v;
a[tot].w=w;
a[tot].nex=he[u];
he[u]=tot++;
}
void init()
{
tot=r=0;l=1;
memset(he,-1,sizeof(he));
memset(dis,0,sizeof(dis));
ans=0;dp[1]=q[0]=0;
}
ll gety(int u,int v)
{
return dp[u]+dis[u]*dis[u]-dp[v]-dis[v]*dis[v];
}
ll getx(int u,int v){return dis[u]-dis[v];}
ll getdp(int u,int v){return dp[v]+p+(dis[u]-dis[v])*(dis[u]-dis[v]);}
void getpre(int u,int fa)
{
for(int i=he[u];i!=-1;i=a[i].nex)
{
int v=a[i].to;
if(v==fa) continue;
dis[v]=dis[u]+a[i].w;
// cout<<v<<" "<<dis[v]<<endl;
getpre(v,u);
}
}
void dfs(int u,int fa,int l,int r)
{
int pre=-1;
while(l<r&&gety(q[l+1],q[l])<=2*dis[u]*getx(q[l+1],q[l])) l++;
//cout<<dp[u]<<endl;
dp[u]=min(dp[u],getdp(u,q[l]));
while(l<r&&getx(u,q[r])*gety(q[r],q[r-1])>=gety(u,q[r])*getx(q[r],q[r-1])) r--;
pre=q[++r];q[r]=u;
ans=max(ans,dp[u]);
for(int i=he[u];i!=-1;i=a[i].nex)
{
int v=a[i].to;
if(v==fa) continue;
dfs(v,u,l,r);
}
if(pre!=-1) q[r]=pre;//恢复队尾
}
int main()
{
int T,cas=1;
scanf("%d",&T);
while(T--)
{
scanf("%d%lld",&n,&p);
init();
for(int i=0;i<n-1;i++)
{
scanf("%d%d%lld",&x,&y,&aa);
add(x,y,aa);
add(y,x,aa);
}
getpre(1,-1);
for(int i=1;i<=n;i++)
dp[i]=dis[i]*dis[i];
dfs(1,-1,1,0);
printf("%lld\n",ans);
// if(flag) puts("Yes"); else puts("No");
}
return 0;
}