瞎扯:算法真的是无止境,从暴力到dp原本以为很神奇了,没想到还能优化dp,而且是把O(n^2)变成O(n),真是无fuck说。
引入:
我们来分析这么一个问题,给你n个数,要你把他们分成连续的若干块,
使得让他们的每段和的平方加起来最小.正常我们会想到的就是O(n^2)的dp,方程就是:
dp[i] = min(dp[j]+(sum[i]-sum[j])^2); (1<=j<i),sum[i]表示1到i的前缀和
让我们看看试着怎么去优化它。设(1<=k<j<i)如果用dp[j]更优于dp[k]那么就有:
dp[j]+(sum[i]-sum[j])^2 < dp[k]+(sum[i]-sum[k])^2
移项合并同类项得dp[j]+sum[j]^2-dp[k]-sum[k]^2 < 2sum[i]*(sum[j]-sum[k])
令f[x] = dp[x] + sum[x]^2,原式变为(f[j]-f[k])/(sum[j]-sum[k]) < 2sum[i].
左边的就变成了斜率似的东西了,而右边则是一个常数。可以以f[j]为纵坐标,sum[j]为横坐标在坐标系上表示出来.
总结:就是将dp方程只与j有关的部分看成y,与i和j有关的部分把j那部分看成x
根据上面的分析得如果满足上面的式子j会比k优,令g(k,j) = (f[j]-f[k])/(sum[j]-sum[k]),
如果有g(k,j) >= g(j,i),k<j<i<t,那么j这个点永远也不肯是最优的点,因为假如g(k,j) >= g(j,i) >= 2*sum[t]
那么此时k是最优的,如果是2*sum[t]>=g(k,j) >= g(j,i),i是最优的,g(k,j) >= 2*sum[t]>=g(j,i)这种情况i,k都比j优,j就跟不用考虑了。
就像这个斜率递增的下凸包:
反之j这个点还是有希望成为最优的,我们用一个单调队列来保存这些点将要插入的点放在队尾,在插入之前看当前队尾元素j如果满足上面的g(k,j) > g(j,i)条件,就将j删除,直到不满足为止将i插入.那么队列就满足从头到尾的g(a,b),g(b,c)...递增.
如果g(a,b)<2*sum[i],说明b比a优,删除队首a,直到g(a,b)>=2*sum[i]为止,这就维护了最优点的选择,就在队首的那个.
如果g(k,j)<g(j,i),且k是当前的最优选,前面说过j在这种情况是有希望的,因为2*sum[i]是递增的当,2*sum[i]>g(k,j)时,j就比k要优了,此时k就会被去掉了。
因为每个点只会被加入一次,删除一次,所以整个的操作就变成了O(n),就是这么神奇。
同理如果是求最大值,就变成了上凸包,大小关系变了,原理是一样的。
hdu 3507很好的练习题:
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<queue>
#define inf 0x3f3f3f3f
using namespace std;
const int mx = 5e5 + 10;
typedef long long ll;
int a[mx],n,m,dp[mx],sum[mx];
int que[mx],head,tail;
bool better(int x,int y,int i)
{
int a = sum[x]*sum[x]+dp[x]-sum[y]*sum[y]-dp[y];
int b = 2*sum[i]*(sum[x]-sum[y]);
return a < b;
}
bool small(int x,int y,int z)
{
int a = sum[y]*sum[y]+dp[y]-sum[x]*sum[x]-dp[x];
a *= (sum[z]-sum[y]);
int b = sum[z]*sum[z]+dp[z]-sum[y]*sum[y]-dp[y];
b *= (sum[y]-sum[x]);
return a >= b;
}
int main()
{
while(~scanf("%d%d",&n,&m)){
for(int i=1;i<=n;i++) scanf("%d",a+i);
for(int i=1;i<=n;i++) sum[i] = sum[i-1] + a[i];
head = tail = 0;
que[tail++] = 0;
for(int i=1;i<=n;i++){
while(head+1!=tail&&better(que[head+1],que[head],i))
head++;
dp[i]= (sum[i]-sum[que[head]])*(sum[i]-sum[que[head]])+dp[que[head]]+m;
while(head+1!=tail&&small(que[tail-2],que[tail-1],i))
tail--;
que[tail++] = i;
}
printf("%d\n",dp[n]);
}
return 0;
}