区间DP入门及平行四边形优化

区间DP, 指的就是对区间的DP, 主要的思想是依旧是最优子结构和无后效性的确保, 一般思路就是先对小区间进行操作得到最优解, 然后通过小区间的最优解来得到大区间的最优解。

利用dp[i][j]数组来表示从 i 到 j 区间合并的最优值。

这里笔者给出基本的区间DP模板帮助理解:

//n是区间长度,dp[i][j]存从i 到 j 区间合并的最优值
//w[i][j]表示从i 到 j的花费
for(i = 1;i <= n;i++)
    dp[i][i] = 初始值;
for(len = 2;len <= n;len++){//len选择区间长度
    for(i = 1;i <= n;i++){//枚举起点
        j = i + len - 1;//合并终点
        if(j > n)break;//不可越界
        for(k = i;k < j;k++)//枚举分割点,寻找最优分割
            dp[i][j] = max(dp[i][j], dp[i][k] + dp[k + 1][j] + w[i][j]);//状态转移
    }
}
    

学校OJ上有一个模板题:石子合并

石子合并1

Time Limit: 1000 MS    Memory Limit: 32768 KB
Total Submission(s): 143    Accepted Submission(s): 65

Description

有n堆石子排成一行,每次选择相邻的两堆石子,将其合并为一堆,记录该次合并的得分为两堆石子个数之和。已知每堆石子的石子个数,求当所有石子合并为一堆时,最小的总得分。

Input

第一行一个整数n(1 <= n <= 200),表示石子堆数; 第二行n个整数a(1 <= a <= 100),表示每堆石子的个数。

Output

一个整数,表示最小总得分。

Sample Input

5
7 6 5 7 100

Sample Output

175

AC代码:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
int main()
{
    int i, j, n, k, len, st;
    int sum[205] = {0};//sum[i]存合并前n石子的花费
    int dp[205][205];
    memset(dp, 0x3f, sizeof(dp));//初始化为较大数值以更新
    scanf("%d", &n);
    for(i = 1;i <= n;i++){
        scanf("%d", &st);
        dp[i][i] = 0;//合并一堆石子不需要花费
        sum[i] = sum[i - 1] + st;//求sum数组
    }
    for(len = 2;len <= n;len++){
        for(i = 1;i < n;i++){
            j = i + len - 1;//计算区间终点
            if(j > n)//越界跳出
                break;
            for(k = i;k < j;k++)
                dp[i][j] = min(dp[i][j], dp[i][k] + dp[k + 1][j] + sum[j] - sum[i - 1]);
        }
    }
    printf("%d\n", dp[1][n]);
    return 0;
}

经验比较多的朋友可能会清醒地看到那个三层循环, 没错,这个算法的时间复杂度属于n ^ 3, 因为题目数据比较小所以用时也不是很高,但题目数据较大时是不可以接受的。但第三层寻找最优分割点的时候会有许多重复的过程, 这里我们可以用一个s[i][j]数组记录下从i 到 j 最优分割点的下标, 在下次寻找时减少寻找次数, 这样就可以将时间降低到 n ^ 2的复杂度, 就是平行四边形优化。

扫描二维码关注公众号,回复: 3598151 查看本文章
​

//n是区间长度,dp[i][j]存从i 到 j 区间合并的最优值
//w[i][j]表示从i 到 j的花费, s[i][j]记录从i 到 j的最优分割点
for(i = 1;i <= n;i++){
    dp[i][i] = 初始值;
    s[i][i] = i;
}
for(len = 2;len <= n;len++){//len选择区间长度
    for(i = 1;i <= n;i++){//枚举起点
        j = i + len - 1;//合并终点
        if(j > n)break;//不可越界
        for(k = s[i][j - 1];k < s[i + 1][j];k++)//在最优分割点范围内枚举分割点
            if(dp[i][j] > dp[i][k] + dp[k + 1][j] + w[i][j]){
                dp[i][j] = dp[i][k] + dp[k + 1][j] + w[i][j];
                s[i][j] = k;//更新最佳分割点
            }
    }
}
    

[点击并拖拽以移动]
​

优化后的代码:

#include<cstdio>
#include<iostream>
#include<cstring>
using namespace std;
int main()
{
    int i, j, k, n, len, m;
    int sum[205];
    int s[205][205], dp[205][205];
    memset(dp, 0x3f, sizeof(dp));
    scanf("%d", &n);
    sum[0] = 0;
    for(i = 1;i <= n;i++){
        scanf("%d", &m);
        sum[i] = sum[i - 1] + m;
        dp[i][i] = 0;
        s[i][i] = i;
    }
    for(len = 2;len <= n;len++){
        for(i = 1;i < n;i++){
            j = i + len - 1;
            if(j > n)break;
            for(k = s[i][j - 1];k <= s[i + 1][j];k++){
                if(dp[i][k] + dp[k + 1][j] + sum[j] - sum[i - 1] < dp[i][j]){
                    dp[i][j] = dp[i][k] + dp[k + 1][j] + sum[j] - sum[i - 1];
                    s[i][j] = k;
                }
            }
        }
    }
    printf("%d\n", dp[1][n]);
    return 0;
}

数据不够大所以时间缩短得也不是很多看不出来优化后的优势, 不过读者可以去做一下数据较大的题目自行感受一下时间的优化~

猜你喜欢

转载自blog.csdn.net/LxcXingC/article/details/81291901