斜率优化dp笔记

瞎扯:算法真的是无止境,从暴力到dp原本以为很神奇了,没想到还能优化dp,而且是把O(n^2)变成O(n),真是无fuck说。

引入:

    我们来分析这么一个问题,给你n个数,要你把他们分成连续的若干块,

使得让他们的每段和的平方加起来最小.正常我们会想到的就是O(n^2)的dp,方程就是:

                     dp[i] = min(dp[j]+(sum[i]-sum[j])^2); (1<=j<i),sum[i]表示1到i的前缀和

让我们看看试着怎么去优化它。设(1<=k<j<i)如果用dp[j]更优于dp[k]那么就有:

dp[j]+(sum[i]-sum[j])^2 < dp[k]+(sum[i]-sum[k])^2

移项合并同类项得dp[j]+sum[j]^2-dp[k]-sum[k]^2 < 2sum[i]*(sum[j]-sum[k])

令f[x] = dp[x] + sum[x]^2,原式变为(f[j]-f[k])/(sum[j]-sum[k]) < 2sum[i].

左边的就变成了斜率似的东西了,而右边则是一个常数。可以以f[j]为纵坐标,sum[j]为横坐标在坐标系上表示出来.

总结:就是将dp方程只与j有关的部分看成y,与i和j有关的部分把j那部分看成x

根据上面的分析得如果满足上面的式子j会比k优,令g(k,j) = (f[j]-f[k])/(sum[j]-sum[k]),

如果有g(k,j) >= g(j,i),k<j<i<t,那么j这个点永远也不肯是最优的点,因为假如g(k,j) >= g(j,i) >= 2*sum[t]

那么此时k是最优的,如果是2*sum[t]>=g(k,j) >= g(j,i),i是最优的,g(k,j) >= 2*sum[t]>=g(j,i)这种情况i,k都比j优,j就跟不用考虑了。

就像这个斜率递增的下凸包:

反之j这个点还是有希望成为最优的,我们用一个单调队列来保存这些点将要插入的点放在队尾,在插入之前看当前队尾元素j如果满足上面的g(k,j) > g(j,i)条件,就将j删除,直到不满足为止将i插入.那么队列就满足从头到尾的g(a,b),g(b,c)...递增.

如果g(a,b)<2*sum[i],说明b比a优,删除队首a,直到g(a,b)>=2*sum[i]为止,这就维护了最优点的选择,就在队首的那个.

如果g(k,j)<g(j,i),且k是当前的最优选,前面说过j在这种情况是有希望的,因为2*sum[i]是递增的当,2*sum[i]>g(k,j)时,j就比k要优了,此时k就会被去掉了。

因为每个点只会被加入一次,删除一次,所以整个的操作就变成了O(n),就是这么神奇。

同理如果是求最大值,就变成了上凸包,大小关系变了,原理是一样的。

hdu 3507很好的练习题:

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<queue>
#define inf 0x3f3f3f3f
using namespace std;
const int mx = 5e5 + 10;
typedef long long ll;
int a[mx],n,m,dp[mx],sum[mx];
int que[mx],head,tail;
bool better(int x,int y,int i)
{
	int a = sum[x]*sum[x]+dp[x]-sum[y]*sum[y]-dp[y];
	int b = 2*sum[i]*(sum[x]-sum[y]);
	return a < b;
}
bool small(int x,int y,int z)
{
	int a = sum[y]*sum[y]+dp[y]-sum[x]*sum[x]-dp[x];
	a *= (sum[z]-sum[y]);
	int b = sum[z]*sum[z]+dp[z]-sum[y]*sum[y]-dp[y];
	b *= (sum[y]-sum[x]);
	return a >= b;
}
int main()
{
	while(~scanf("%d%d",&n,&m)){
		for(int i=1;i<=n;i++) scanf("%d",a+i);
		for(int i=1;i<=n;i++) sum[i] = sum[i-1] + a[i];
		head = tail = 0;
		que[tail++] = 0;
		for(int i=1;i<=n;i++){
			while(head+1!=tail&&better(que[head+1],que[head],i)) 
			head++;
			dp[i]= (sum[i]-sum[que[head]])*(sum[i]-sum[que[head]])+dp[que[head]]+m;
			while(head+1!=tail&&small(que[tail-2],que[tail-1],i))
			tail--;
			que[tail++] = i;
		}
		printf("%d\n",dp[n]);
	}
	return 0;
} 

猜你喜欢

转载自blog.csdn.net/a1214034447/article/details/81071763
今日推荐