[THUSC2016]成绩单 题解

传送门 LOJ2292

题目大意:给一个长度为 n n 的数列 { w n } \{w_n\} ,每次抽走一段连续的子段,直到全部抽完。如果一共抽了 k k 次,总代价为 a × k + b × i = 1 k ( max i min i ) 2 a\times k+b\times\sum\limits_{i=1}^k(\max_i-\min_i)^2 ,其中 a a b b 是输入参数, max i \max_i min i \min_i 分别表示第 i i 次抽取的数的最大值和最小值。请最小化总代价。

n 50 n\leq 50

考虑到这题的抽取方式:每次从中间抽取一段,然后两边的又会拼起来,所以区别于序列划分类问题,前缀DP的状态不够清晰无法转移,可以想到区间DP。

先设 g ( i , j ) g(i,j) 表示将 [ i , j ] [i,j] 这段区间的数全部消掉的最小代价,那么最终答案就是 g ( 1 , n ) g(1,n) 。但是这样会有问题:我们需要知道每次消去的 max i \max_i min i \min_i 是多少,才能方便转移。那么我们可以这样做:首先对 w i w_i 离散化,假设 t m p i tmp_i 表示原数组 w w 中第 i i 小的数;对于一个 g ( i , j ) g(i,j) ,我们可以枚举值域区间 [ l , r ] [l,r] ,表示将 [ i , j ] [i,j] 这段区间里的数全部消掉之前最后一次消去的数都在值域范围 [ l , r ] [l,r] 当中。那么最后一次消去显然会产生 a + b × ( t m p r t m p l ) 2 a+b\times(tmp_r-tmp_l)^2 的代价。那么接下来的问题就是:将区间 [ i , j ] [i,j] 中的数消到只剩下值域范围在 [ l , r ] [l,r] 中的数,最小代价是多少,不妨设其为 f ( i , j , l , r ) f(i,j,l,r)

这样一来我们 g ( i , j ) g(i,j) 的转移方程就有了: g ( i , j ) = min l r { f ( i , j , l , r ) + a + b × ( t m p r t m p l ) 2 } g(i,j)=\min\limits_{l\leq r}\{f(i,j,l,r)+a+b\times(tmp_r-tmp_l)^2\}

下面考虑 f ( i , j , l , r ) f(i,j,l,r) 的转移。一个显然的思路是 f ( i , j , l , r ) = min i k < j { f ( i , k , l , r ) + f ( k + 1 , j , l , r ) } f(i,j,l,r)=\min\limits_{i\leq k<j}\{f(i,k,l,r)+f(k+1,j,l,r)\} ,然而很不幸这样做会漏掉一些情况。

正确的做法是:首先我们可以找到 [ i , j ] [i,j] 区间里左边第一个不在值域范围 [ l , r ] [l,r] 中的数的位置 p p ,以及右边第一个不在值域范围 [ l , r ] [l,r] 中的数的位置 q q ,如果存在这样的区间 [ p , q ] [p,q] (如果不存在当然就不用管了),那么

f ( i , j , l , r ) = min ( g ( p , q ) , min i k < j { f ( i , k , l , r ) + f ( k + 1 , j , l , r ) } ) f(i,j,l,r)=\min(g(p,q),\min\limits_{i\leq k<j}\{f(i,k,l,r)+f(k+1,j,l,r)\})

好了,这两个转移方程出来之后,剩下的就是区间DP套路了。当然是枚举区间长度再枚举左端点做DP就行了。

时间复杂度 O ( n 5 ) O(n^5)

#include <cctype>
#include <cstdio>
#include <climits>
#include <algorithm>

template <typename T> inline void read(T& x) {
    int f = 0, c = getchar(); x = 0;
    while (!isdigit(c)) f |= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - 48, c = getchar();
    if (f) x = -x;
}
template <typename T, typename... Args>
inline void read(T& x, Args&... args) {
    read(x); read(args...); 
}
template <typename T> void write(T x) {
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + 48);
}
template <typename T> void writeln(T x) { write(x); puts(""); }
template <typename T> inline bool chkmin(T& x, const T& y) { return y < x ? (x = y, true) : false; }
template <typename T> inline bool chkmax(T& x, const T& y) { return x < y ? (x = y, true) : false; }

const int maxn = 57, inf = 1e9;
int w[maxn], tmp[maxn];
int f[maxn][maxn][maxn][maxn], g[maxn][maxn];
int n, a, b, all;

int main() {
    read(n, a, b);
    for (int i = 1; i <= n; ++i) read(w[i]), tmp[++all] = w[i];
    std::sort(tmp + 1, tmp + all + 1);
    all = std::unique(tmp + 1, tmp + all + 1) - (tmp + 1);
    for (int i = 1; i <= n; ++i)
        w[i] = std::lower_bound(tmp + 1, tmp + all + 1, w[i]) - tmp;
    for (int i = 1; i <= n; ++i) {
        g[i][i] = a;
        for (int l = 1; l <= all; ++l)
            for (int r = l; r <= all; ++r)
                if (w[i] > r || w[i] < l)
                    f[i][i][l][r] = a;
    }
    for (int len = 2; len <= n; ++len)
        for (int i = 1, j = i + len - 1; j <= n; ++i, ++j) {
            g[i][j] = inf;
            for (int l = 1; l <= all; ++l)
                for (int r = l; r <= all; ++r) {
                    int p = 0, q = 0;
                    for (int t = i; t <= j; ++t)
                        if (w[t] > r || w[t] < l) { p = t; break; }
                    for (int t = j; t >= i; --t)
                        if (w[t] > r || w[t] < l) { q = t; break; }
                    if (p && q) f[i][j][l][r] = g[p][q];
                    else f[i][j][l][r] = inf;
                    for (int k = i; k < j; ++k)
                        chkmin(f[i][j][l][r], f[i][k][l][r] + f[k + 1][j][l][r]);
                    chkmin(g[i][j], f[i][j][l][r] + a + b * (tmp[r] - tmp[l]) * (tmp[r] - tmp[l]));
                }
        }
    writeln(g[1][n]);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_39677783/article/details/86898654