[luogu3628][bzoj1911][APIO2010]特别行动队【动态规划+斜率优化DP】

题目描述

给你一个数列,让你将这个数列分成若干段,使其每一段的和的\(a \times sum^2 + b \times sum + c\)的总和最大。

分析

算是一道斜率优化的入门题。
首先肯定是考虑\(O(n^2)\)的暴力DP。
定义状态\(f[i]\)表示最后一段的结尾是\(i\)的最大答案。
那么枚举j,得到转移方程为\(f[i]=max(f[i],f[j]+a\times (sum[i]-sum[j])^2+b\times(sum[i]-sum[j])+c\)
注意这里的转移方程不是\(sum[i]-sum[j-1]\)而是\(sum[i]-sum[j]\),因为j是属于前一段的,所以不能算j-1这一个格子。
40分蜜汁错误和全部剩下全部T飞。

#include <bits/stdc++.h>
#define ll long long
#define ms(a, b) memset(a, b, sizeof(a))
#define inf 0x3f3f3f3f
#define N 1000005
using namespace std;
template <typename T>
inline void read(T &x) {
    x = 0; T fl = 1;
    char ch = 0;
    while (ch < '0' || ch > '9') {
        if (ch == '-') fl = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = getchar();
    }
    x *= fl;
}
ll sqr(ll x) {
    return x * x;
}
int n;
ll a, b, c;
ll x[N], sum[N], f[N];
int main() {
    read(n); 
    read(a); read(b); read(c);
    for (int i = 1; i <= n; i ++)  {
        read(x[i]);
        sum[i] = sum[i - 1] + x[i];
    }
    for (int i = 1; i <= n; i ++) 
        for (int j = 0; j < i; j ++) 
            f[i] = max(f[i], f[j] + sqr(sum[i] - sum[j]) * a + (sum[i] - sum[j]) * b + c);
    printf("%lld\n", f[n]);
    return 0;
}

很显然这不是正解,那么我们就考虑优化DP。
考虑斜率优化。按照斜率优化的标准套路。
假设\(0<=k<j<i\)时,j的状态比k要优。。
那么很明显就得到了以下的式子,再将其化简:
\[f_j+a\times(sum_i-sum_j)^2+b\times(sum_i-sum_j)+c>=f_k+a\times a\times(sum_i-sum_j)^2+b\times(sum[i]-sum[j])+c\]
以下的所有操作都是初中内容。(可能会跳步,但是看得懂)
\[f_j-2asum_isum_j+asum_j^2-bsum_j>=f_k-2asum_isum_k+asum_k^2-bsum_k\]
\[\frac{(f_j+asum_j^2-bsum_j)-(f_k+asum_k^2-bsum_k)}{sum_j-sum_k}>=2sum_i\]
可以发现:左边单调递增,右边单调递减。那么维护斜率上凸包。

#include <bits/stdc++.h>
#define ll long long
#define ms(a, b) memset(a, b, sizeof(a))
#define inf 0x3f3f3f3f
#define N 1000005
#define db double 
using namespace std;
template <typename T>
inline void read(T &x) {
    x = 0; T fl = 1;
    char ch = 0;
    while (ch < '0' || ch > '9') {
        if (ch == '-') fl = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = getchar();
    }
    x *= fl;
}
ll sqr(ll x) {
    return x * x;
}
int n;
ll a, b, c;
ll x[N], sum[N], f[N];
int q[N];
ll get_x(int i) {
    return sum[i];
}
ll get_y(int i) {
    return f[i] + a * sqr(sum[i]) - b * sum[i];
}
db get_slope(int i, int j) {
    return (1.0 * (get_y(j) - get_y(i))) / (1.0 * (get_x(j) - get_x(i)));
}
int main() {
    read(n); 
    read(a); read(b); read(c);
    for (int i = 1; i <= n; i ++)  {
        read(x[i]);
        sum[i] = sum[i - 1] + x[i];
    }
    int h = 0, t = 0;
    for (int i = 1; i <= n; i ++) {
        while (h < t && get_slope(q[h] , q[h + 1]) >= 2.0 * a * sum[i]) h ++;
        int j = q[h];
        f[i] = f[j] + a * sqr(sum[i] - sum[j]) + b * (sum[i] - sum[j]) + c;
        while (h < t && get_slope(q[t - 1], q[t]) <= get_slope(q[t], i)) t --;
        q[++ t] = i;
    }
    printf("%lld\n", f[n]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/chhokmah/p/10593218.html
今日推荐