JZOJ ???? Function

没有传送门
题目大意

给定一个长度为 n 的数组 a ,定义一个二元函数 f ( x , y ) 1 x 10 9 1 y n

f ( x , y ) = { a y x = 1 f ( x 1 , y ) + a y y = 1  且  x 1 min ( f ( x 1 , y 1 ) , f ( x 1 , y ) ) + a y 其它情况

给定 m 个询问,每个询问形如 ( x i , y i ) ,你需要求出 f ( x i , y i ) 的值。不要求在线。

n , m 5 × 10 5

Limited Constraint:

  1. n 300 max { x } 10 3
  2. n 300
  3. n , m 10 5
考场上的思路

对于第一个 Limited Constraint,显然可以 O ( n x + m ) 记忆化搜索。

可以把这个函数抽象成一张 max { x } × n 的表格,每一行都是 a 数组。函数 f ( x , y ) 的意思就是:一开始你在 ( x , y ) ,每次可以往上走一格( x 1 ),或者往左上走一格( x 1 y 1 ),直到走到第一行。求一条路径使得权值之和最小,你只需要输出答案即可。

显然最优策略是一开始一路向左上走,然后一直向上走。可以理解成一开始向左走是为了走到一个权值小的格子,然后接下来就一直在这个权值小的格子上走;如果走这个格子不是最优的,那么肯定不会在这个格子上多走一步。

那么我们可以得到一个 O ( n m ) 的做法:对于每个询问,枚举开始向上走的位置。

把计算的式子写出来:

f ( x , y ) = min 1 j y , j > y x { s u m y s u m j + ( x ( y j ) ) × a j }

其中 s u m a 的前缀和。

考虑斜率优化。设决策 j 比决策 k 更优,那么有:

s u m y s u m j + ( x ( y j ) ) × a j < s u m y s u m k + ( x ( y k ) ) × a k

化简:
( s u m j s u m k ) ( j a j k a k ) > ( x y ) ( a j a k )

a j a k > 0 ,得:
( s u m j s u m k ) ( j a j k a k ) a j a k > x y

用 CDQ 分治解决横坐标不单调的斜率优化问题,有一个 O ( log n ) 。注意到,由于要求 j > y x ,因此有些状态的决策点不是 1 y ,而是 y x + 1 y 。在 CDQ 分治里用线段树加单调队列解决这个问题,时间复杂度 O ( n log 2 n ) ,可以得到 73 分。

注意斜率要用 long double ϵ ;归并排序时要稳定排序(用 <=)。

思路

我们设 g y , j ( x ) 表示以 ( x , y ) 为起点且最后在第 j 列一直向上走时,总代价是多少。显然这是一个关于 x 的一次函数。假设对于一个 y ,我们已经有了 g y , j ( 1 j y ) 的函数,显然答案为:

f ( x , y ) = min 1 j y { g y , j ( x ) }

对于相同的 y ,如何快速求得答案呢?我们维护这些一次函数形成的上凸壳,然后在上凸壳上二分即可,具体操作我们一会儿再说。时间复杂度为 O ( log n )

那么现在的问题是如何的得到对于所有 y 的所有 g 函数。考虑 g y , j g y + 1 , j ,对于所有的 j [ 1 , y ] ,它们的变化方式都是向右平移一个单位,向上平移 a y + 1 个单位(注意其实这是一条射线,向右平移后左边是空的)。然后我们加入了一条新的直线。这条直线截距为 0 ,因此当前斜率大于等于它的直线都不如它优。由于刚刚我们已经发现了在后面的变化中所有已经存在的直线只会平移,因此可以直接用单调栈维护这些直线。

二分时,检查中点对应直线与下一条直线的交点的横坐标 x 与当前查询的 x 的关系。如果 x x ,说明当前的 m i d 小了,否则 r 至多是 m i d

这里总结一下如何使用单调栈维护直线形成的凸壳(下面会具体讲前面略过的内容)。直线形成的凸壳必须满足下面的特征:

  1. 斜率单调。
  2. 相邻直线的交点的横坐标单调。

因此前面说的维护单调栈其实分成两步:第一步弹出斜率大于等于当前直线斜率的直线;第二步弹出使得交点不单调的直线:

while (stack[0] > 1)
{
    if (GetX(i, i, stack[stack[0]]) >= GetX(i, stack[stack[0] - 1], stack[stack[0]]))
        stack[0]--;
    else
        break;
}

这样查询时保证了相邻直线横坐标单调,因此可以根据相邻直线交点的横坐标和待查询的横坐标的关系进行二分。

注意到这里并没有处理 j y x 的不合法的情况,用到了非法解不会最优的思想。显然的是,对于一个决策 j j y 中的最小值,否则选择中间更小的那个更优。如果不合法,那么不合法减去的值一定比不合法加上的值小,因此不会遇到不合法的情况。

参考代码
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cassert>
#include <cctype>
#include <climits>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <bitset>
#include <list>
#include <functional>
typedef long long LL;
typedef unsigned long long ULL;
using std::cin;
using std::cout;
using std::endl;
typedef LL INT_PUT;
INT_PUT readIn()
{
    INT_PUT a = 0; bool positive = true;
    char ch = getchar();
    while (!(ch == '-' || std::isdigit(ch))) ch = getchar();
    if (ch == '-') { positive = false; ch = getchar(); }
    while (std::isdigit(ch)) { a = a * 10 - (ch - '0'); ch = getchar(); }
    return positive ? -a : a;
}
void printOut(INT_PUT x)
{
    char buffer[20]; int length = 0;
    if (x < 0) putchar('-'); else x = -x;
    do buffer[length++] = -(x % 10) + '0'; while (x /= 10);
    do putchar(buffer[--length]); while (length);
    putchar('\n');
}

const int maxn = int(5e5) + 5;
int n, m;
int minx = INT_MAX, maxx = INT_MIN;
int a[maxn];
LL sum[maxn];
struct Query
{
    int x, y;
    void read()
    {
        x = readIn();
        y = readIn();
        minx = std::min(minx, x);
        maxx = std::max(maxx, x);
    }
} querys[maxn];

#define RunInstance(x) delete new x
struct brute1
{
    static const int maxN = 305;
    static const int maxx = 1005;
    LL f[maxx][maxN];

    LL DP(int x, int y)
    {
        LL& ans = f[x][y];
        if (~ans) return ans;
        if (x == 1) return ans = a[y];
        if (y == 1) return ans = DP(x - 1, y) + a[y];
        return ans = std::min(DP(x - 1, y - 1), DP(x - 1, y)) + a[y];
    }

    brute1()
    {
        std::memset(f, -1, sizeof(f));
        for (int i = 1; i <= m; i++)
            printOut(DP(querys[i].x, querys[i].y));
    }
};
struct brute2
{
    brute2()
    {
        for (int i = 1; i <= m; i++)
        {
            const Query& q = querys[i];
            LL ans = (LL)q.x * a[q.y];
            for (int j = 1; j < q.x && q.y - j >= 1; j++)
            {
                ans = std::min(ans, sum[q.y] - sum[q.y - j] +
                    (LL)(q.x - j) * a[q.y - j]);
            }
            printOut(ans);
        }
    }
};
struct brute3
{
    int N;
    struct Ins
    {
        int pos;
        int idx;
        Ins() {}
        Ins(int pos, int idx) : pos(pos), idx(idx) {}
        bool operator<(const Ins& b) const
        {
            if (pos != b.pos) return pos < b.pos;
            return idx < b.idx;
        }
    } inss[maxn * 2];
    int idx[maxn * 2];
    int temp[maxn * 2];

    static long double slope(int j, int k)
    {
        LL up = (sum[j] - sum[k]) -
            ((LL)j * a[j] - (LL)k * a[k]);
        if (!up) return 0;
        if (a[j] == a[k]) return up > 0 ? 1e100 : 1e-100;
        return (long double)up / (a[j] - a[k]);
    }
    static int dcmp(long double x)
    {
        const long double EPS = 1e-8;
        if (std::abs(x) <= EPS) return 0;
        return x < 0 ? -1 : 1;
    }

    struct SegTree
    {
        int deque[maxn * 20];
        static inline int code(int l, int r)
        {
            return (l + r) | (l != r);
        }
        int stamp;
        int size;
        int begin[maxn * 2];
        int head[maxn * 2];
        int tail[maxn * 2];
        int vis[maxn * 2];

        void helper(int l, int r)
        {
            int c = code(l, r);
            if (vis[c] != stamp)
            {
                vis[c] = stamp;
                head[c] = tail[c] = begin[c];
            }
        }
        int g_Pos, g_Val, g_L, g_R, X, Y;
        void insert_(int l, int r)
        {
            helper(l, r);
            int c = code(l, r);
            int& h = head[c];
            int& t = tail[c];
            while (t - h > 1 &&
                dcmp(slope(g_Pos, deque[t - 1]) -
                    slope(deque[t - 1], deque[t - 2])) > 0)
                t--;
            deque[t++] = g_Pos;

            if (l == r)
                return;
            int mid = (l + r) >> 1;
            if (g_Pos <= mid) insert_(l, mid);
            else insert_(mid + 1, r);
        }
        LL query_(int l, int r)
        {
            helper(l, r);
            if (g_L <= l && r <= g_R)
            {
                int c = code(l, r);
                int& h = head[c];
                int& t = tail[c];
                if (h == t) return LLONG_MAX;
                while (t - h > 1 &&
                    dcmp(slope(deque[h + 1], deque[h]) - g_Val) > 0)
                    h++;

                int j = deque[h];
                return sum[Y] - sum[j] + (LL)(X - (Y - j)) * a[j];
            }
            int mid = (l + r) >> 1;
            LL ret = LLONG_MAX;
            if (g_L <= mid) ret = std::min(ret, query_(l, mid));
            if (g_R > mid) ret = std::min(ret, query_(mid + 1, r));
            return ret;
        }

    public:
        SegTree() : stamp(), size(), vis() {}
        void build(int l, int r)
        {
            begin[code(l, r)] = size;
            size += r - l + 1;
            if (l == r)
                return;
            int mid = (l + r) >> 1;
            build(l, mid);
            build(mid + 1, r);
        }
        void clear() { stamp++; }
        void insert(int pos)
        {
            g_Pos = pos;
            insert_(1, n);
        }
        LL query(int l, int r, int val, int x, int y)
        {
            g_L = l;
            g_R = r;
            g_Val = val;
            X = x;
            Y = y;
            return query_(1, n);
        }
    } st;

    bool comp(const Ins& x, const Ins& y)
    {
        if (x.idx && y.idx)
            return querys[x.idx].x - querys[x.idx].y >
            querys[y.idx].x - querys[y.idx].y;
        if (x.idx)
            return false;
        if (y.idx)
            return true;
        return a[x.pos] <= a[y.pos];
    }

    LL ans[maxn];
    void cdq(int l, int r)
    {
        if (l == r)
        {
            return;
        }
        int mid = (l + r) >> 1;
        cdq(l, mid);
        cdq(mid + 1, r);

        st.clear();
        for (int i = l; i <= mid; i++)
        {
            if (inss[idx[i]].idx)
                continue;
            st.insert(inss[idx[i]].pos);
        }
        for (int i = mid + 1; i <= r; i++)
        {
            const Ins& ins = inss[idx[i]];
            if (!ins.idx)
                continue;
            const Query& q = querys[ins.idx];

            ans[ins.idx] = std::min(ans[ins.idx],
                st.query(std::max(1, q.y - q.x + 1), q.y, q.x - q.y, q.x, q.y));
        }

        int i = l;
        int j = mid + 1;
        int k = l;
        while (k <= r)
        {
            if (j > r || (i <= mid && comp(inss[idx[i]], inss[idx[j]])))
                temp[k++] = idx[i++];
            else
                temp[k++] = idx[j++];
        }
        for (i = l; i <= r; i++)
            idx[i] = temp[i];
    }
    brute3() : N()
    {
        for (int i = 1; i <= n; i++)
            inss[++N] = Ins(i, 0);
        for (int i = 1; i <= m; i++)
            inss[++N] = Ins(querys[i].y, i);
        std::sort(inss + 1, inss + 1 + N);
        for (int i = 1; i <= N; i++)
            idx[i] = i;

        st.build(1, n);
        for (int i = 1; i <= m; i++)
            ans[i] = LLONG_MAX;
        cdq(1, N);

        for (int i = 1; i <= m; i++)
            printOut(ans[i]);
    }
};
struct work
{
    int idx[maxn];

    static bool comp(const int& x, const int& y)
    {
        return querys[x].y < querys[y].y;
    }

    int stack[maxn];
    LL ans[maxn];

    LL calc(int x, int y, int j)
    {
        return sum[y] - sum[j] + (LL)(x - (y - j)) * a[j];
    }

    long double GetX(int i, int p1, int p2)
    {
        return (long double)((sum[p1] - sum[p2]) +
            (LL)(i - p1) * a[p1] - (LL)(i - p2) * a[p2]) / (a[p1] - a[p2]);
    }

    work()
    {
        for (int i = 1; i <= m; i++)
            idx[i] = i;
        std::sort(idx + 1, idx + 1 + m, comp);

        stack[0] = 0;
        int cnt = 1;
        for (int i = 1; i <= n && cnt <= m; i++)
        {
            while (stack[0] && a[stack[stack[0]]] >= a[i])
                stack[0]--;
            while (stack[0] > 1)
            {
                if (GetX(i, i, stack[stack[0]]) >= GetX(i, stack[stack[0] - 1], stack[stack[0]]))
                    stack[0]--;
                else
                    break;
            }
            stack[++stack[0]] = i;

            while (cnt <= m && querys[idx[cnt]].y == i)
            {
                const Query& q = querys[idx[cnt]];
                int l = 1, r = stack[0];
                while (r - l > 0)
                {
                    int mid = (l + r) >> 1;
                    int p1 = stack[mid];
                    int p2 = stack[mid + 1];
                    if (GetX(i, p1, p2) >= q.x)
                        l = mid + 1;
                    else
                        r = mid;
                }
                ans[idx[cnt]] = calc(q.x, q.y, stack[l]);
                cnt++;
            }
        }
        for (int i = 1; i <= m; i++)
            printOut(ans[i]);
    }
};

void run()
{
    n = readIn();
    for (int i = 1; i <= n; i++)
        sum[i] = sum[i - 1] + (a[i] = readIn());
    m = readIn();
    for (int i = 1; i <= m; i++)
        querys[i].read();

    if (n <= 300 && maxx <= 1000)
        RunInstance(brute1);
    else if (n <= 300)
        RunInstance(brute2);
    else if (n <= int(1e5) && m <= int(1e5))
        RunInstance(brute3);
    else
        RunInstance(work);
}

int main()
{
#ifndef LOCAL
    freopen("function.in", "r", stdin);
    freopen("function.out", "w", stdout);
#endif
    run();
    return 0;
}
总结

set 是不能拿来写斜率优化用的平衡树的(虽然我们没有写过)!不要浪费时间。而且树套树
能做的 CDQ 分治也一定能做。

猜你喜欢

转载自blog.csdn.net/lycheng1215/article/details/80990593