【题解】Atcoder ARC#67 F-Yakiniku Restaurants

  觉得我的解法好简单,好优美啊QAQ

  首先想想暴力怎么办。暴力的话,我们就枚举左右端点,然后显然每张购物券都取最大的值。这样的复杂度是 \(O(n ^{2} m)\) 的。但是这样明显能够感觉到我们重复计算了很多东西,因为区间 \((l, r)\) 的答案与区间 \((l + 1, r)\) 的答案并不是独立的。

  我们可以考虑一下扫描线的做法。用一根扫描线从右往左扫左端点,同步维护所有以 \(l\) 为左端点的区间。由于我们现已经求出了所有以 \(l + 1\) 为左端点的区间答案(这里的答案指从 \(l -> r\) 中吃东西所能获得的最大权值),我们可以求出 \(l + 1, r\) 到 \(l, r\) 的增量变化,那么 \(ans[l][r] = ans[l + 1][r] + t\)。

  这个答案的增量显然只与 \(l\) 端点所能获得的权值有关。考虑第 j 个购物券,我们可以维护一个值单调递增的单调栈表示在每一个地点使用 j 购物券能获得最大权值的区间。弹栈的时候,我们用 \(val[i][j] - S[j][top].num\) 即可求出增量。这个增量会增加在 \(ans[i][j] -> ans[i][k]\) 这样的一个区间中。差分就可以解决了。

  感觉自己讲起来好混乱啊……すみません……

#include <bits/stdc++.h>
using namespace std;
#define maxn 5005
#define int long long
#define maxm 250 
int n, m, dis[maxn], val[maxn][maxm];
int Ans, ans[maxn][maxn], Q[maxn];

int read()
{
    int x = 0, k = 1;
    char c; c = getchar();
    while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * k;
}

struct node
{
    int num, id;
    node(int _id = 0, int _num = 0) { num = _num, id = _id; }
}S[maxm][maxn];

signed main()
{
    n = read(), m = read();
    for(int i = 2; i <= n; i ++) dis[i] = read() + dis[i - 1];
    for(int i = 1; i <= n; i ++) 
        for(int j = 1; j <= m; j ++) val[i][j] = read();
    for(int i = 1; i <= m; i ++) S[i][0].id = n + 1;
    for(int i = n; i >= 1; i --)
    {
        for(int j = 1; j <= m; j ++)
        {
            int top = Q[j];
            ans[i][i] += val[i][j]; ans[i][i + 1] -= val[i][j];
            while(top && S[j][top].num <= val[i][j]) 
            {
                int l = S[j][top].id, r = S[j][top - 1].id;
                int t = val[i][j] - S[j][top].num; 
                ans[i][l] += t, ans[i][r] -= t; 
                top --;
            }
            S[j][++ top] = node(i, val[i][j]);
            Q[j] = top;
        }
    }
    for(int i = n; i; i --)
    {
        for(int j = i; j <= n; j ++)
            ans[i][j] += ans[i][j - 1];
        for(int j = i; j <= n; j ++) ans[i][j] += ans[i + 1][j];
        for(int j = i; j <= n; j ++) 
            Ans = max(Ans, ans[i][j] - dis[j] + dis[i]); 
    }
    printf("%lld\n", Ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/twilight-sx/p/9916732.html