题目链接
是一道好题了……一直在WA,不知道是哪里写彪了,后来,没想到,竟然是没考虑到前面所有的数都和在一起才是最小值的可能,就举个例子吧:
5 10000000
0 0 0 0 0
ans:10000000
5 100000000
0 2 3 4 5
ans:100000196
但是,我们若是没关注到这个问题的话,初始的值会极其的大,会多出来一个M的值,所以要考虑到初始值的问题,就是给队列的头放成0,这样的话,就可以解决上面的问题,不然的话,就会不断的往上面走,但是初始会多出来一个M。
剩下的就是斜率优化DP的模板了,但具体讲一下推理的公式:
对于dp[i] = min( (sum[i] - sum[j])^2 + M + dp[j] );(j < i)
dp[i] = min( sum[i]^2 + sum[j]^2 - 2*sum[i]*sum[j] + M + dp[j] );
为了寻找到i之前最优的j,我们假设有这样的存在k<j<i,但是j是i之前的最优解,于是有:
sum[i]^2 + sum[j]^2 - 2*sum[i]*sum[j] + M + dp[j] < sum[i]^2 + sum[k]^2 - 2*sum[i]*sum[k] + M + dp[k];
进行化简,会得到:
sum[j]^2 - 2*sum[i]*sum[j] + dp[j] < sum[k]^2 - 2*sum[i]*sum[k] + dp[k];
我们再把带有i的值的移到一边去,于是有:
sum[j]^2 + dp[j] - ( sum[k]^2 + dp[k] ) < 2 * sum[i] * ( sum[j] - sum[k] );
于是,会得到一个可以假设的斜截式:
( sum[j]^2 + dp[j] - ( sum[k]^2 + dp[k] ) ) / (2 * ( sum[j] - sum[k] )) < sum[i];
那么,假设这样的斜率表达式:
令yj = sum[j]^2 + dp[j]; 令yk = sum[k]^2 + dp[k];
令xj = 2 * sum[j]; 令xk = 2 * sum[k];
于是,原式等于:(yj - yk) / (xj - xk) < sum[i];
此时,就是j比k更优的情况,我们根据这个不等式,就可以列写出单调队列了,
在0~j之中,j一定是最优解,因为sum记录的是前缀和,所以,它是单调递增的,存在sum[i+x] > sum[i] > (yj - yk) / (xj - xk);
那么,接下来只需要考虑i+1~i+x之间的点了,若是之间有这样的更优解的话,那么,j这个节点就不需要再保存了,我们把jpop()就是了,此时,点i入队,比较i点与队列尾的点q[tail]、q[tail-1]的斜率:K1、K2;
若是K1<=K2,则将q[tail]去除吧,之后是肯定不会用这个点了的,因为,它没有它的前一个节点优效,我们这样逐一比较,直到队列的节点不剩下的时候(其实那会应该剩下一个我们最初放入的0号节点,保证了可能是从第一位开始向后加的情况的元素),或者是K1>K2的时候,我们结束循环,并且继续下一个元素。
对了,head处理之后,会有dp[]的赋值,此时取队首即可,因为那就是最优解了。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#define lowbit(x) ( x&(-x) )
#define pi 3.141592653589793
#define e 2.718281828459045
#define INF 0x3f3f3f3f
#define SonG_y main
#define MP(x, y) make_pair(x, y)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int maxN = 5e5 + 7;
int N, M;
ll sum[maxN] = {0}, dp[maxN] = {0}; int q[maxN] = {0};
inline ll Det_Y(int i, int j) { return sum[i] * sum[i] + dp[i] - ( sum[j] * sum[j] + dp[j] ); }
inline ll Det_X(int i, int j) { return 2 * (sum[i] - sum[j]); }
inline void solve()
{
int head = 1, tail = 0;
q[++tail] = 0;
for(int i=1; i<=N; i++)
{
while(head < tail && Det_Y(q[head+1], q[head]) <= sum[i] * Det_X(q[head+1], q[head]) ) head++;
dp[i] = (sum[i] - sum[q[head]]) * (sum[i] - sum[q[head]]) + M + dp[q[head]];
while(head < tail && ( Det_Y(i, q[tail]) * Det_X(i, q[tail-1]) ) <= ( Det_Y(i, q[tail-1]) * Det_X(i, q[tail] )) ) tail--;
q[++tail] = i;
}
printf("%lld\n", dp[N]);
}
inline void init()
{
memset(dp, 0, sizeof(dp));
memset(sum, 0, sizeof(sum));
memset(q, 0, sizeof(q));
}
int SonG_y()
{
while(scanf("%d%d", &N, &M)!=EOF)
{
init();
for(int i=1; i<=N; i++)
{
scanf("%lld", &sum[i]);
sum[i] += sum[i-1];
}
solve();
}
return 0;
}