SDUST 2018校赛 问题 E: 搬砖——splay

题目描述

舒荷·王,最近开始了他的搬砖生活。搬砖生活中,他要先将砖排好。已知他的面前有N柱砖,因任务所需,至少要连续X个的砖柱高度是一样的。他可以选择从一个砖柱扔掉一块砖,也可加上一块砖(假设他有无限的砖块储备),每次加砖或者扔砖都要消耗一点体力。他想知道他该怎么做,才能让自己消耗的体力最小。

输入

第一行给出两个数N和X,代表多少柱砖和连续多少个砖块,第二行有N个数,代表砖柱的高度h。

其中1 <= X <= N <= 100000, 0 <= H <= 1000000

输出

一行代表最小体力

样例输入

3 21 2 1

样例输出

1

思路:虽然题目中说的是至少x,不过很明显可以直接按照x来做,从左到右枚举每个大小为x的区间,对于每个区间,很明显把所有堆的高度变成他们的中位数是最优的,可以用区间第k大从spaly中获得中位数(由于splay的性质顺便把这个中位数装到了根上),然后根据左右子树的size和sum来进行求解,具体参考代码中的solve函数

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 10;
const long long INF = 1e12;
struct Splay {
    int sz, root, f[maxn], ch[maxn][2], cnt[maxn], size[maxn];
    long long key[maxn], sum[maxn];
    void clear() { sz = root = 0; }
    void newnode(int x, int fa, int v) {
        f[x] = fa; ch[x][0] = ch[x][1] = 0; key[x] = sum[x] = v; cnt[x] = size[x] = 1;
    }
    void delnode(int x) {
        f[x] = ch[x][0] = ch[x][1] = key[x] = sum[x] = cnt[x] = size[x] = 0;
    }
    void pushup(int x) {
        if (!x) return;
        sum[x] = key[x]*cnt[x];
        size[x] = cnt[x];
        if (ch[x][0]) sum[x] += sum[ch[x][0]], size[x] += size[ch[x][0]];
        if (ch[x][1]) sum[x] += sum[ch[x][1]], size[x] += size[ch[x][1]];
    }
    int witch(int x) { return ch[f[x]][1] == x; }
    void rotate(int x) {
        int y = f[x], z = f[y], px = witch(x), py = witch(y);
        ch[y][px] = ch[x][px^1]; f[ch[y][px]] = y;
        ch[x][px^1] = y; f[y] = x;
        f[x] = z; if (z) ch[z][py] = x;
        pushup(y); pushup(x);
    }
    void splay(int x) {
        for (int fa; (fa = f[x]); rotate(x))
            if (f[fa]) rotate(witch(x) == witch(fa) ? fa : x);
        root = x;
    }
    void insert(int v) {
        if (!root) { newnode(++sz, 0, v); root = sz; return; }
        int fa = 0, x = root;
        while (true) {
            if (x == 0) {
                newnode(++sz, fa, v); ch[fa][key[fa] < v] = sz;
                splay(sz); break;
            }
            if (key[x]==v) {
                sum[x] += v; cnt[x]++; size[x]++;
                splay(x); break;
            }
            fa = x;
            x = ch[x][key[x]<v];
        }
    }
    void delet(int v) {
        int t = rank(v);
        if (cnt[root] > 1) { sum[root] -= v; cnt[root]--; size[root]--; return; }
        if (!ch[root][0]&&!ch[root][1]) { delnode(root); root = 0; return; }
        if (!ch[root][0]) {
            t = root; root = ch[root][1]; f[root] = 0; delnode(t); return;
        }
        if (!ch[root][1]) {
            t = root; root = ch[root][0]; f[root] = 0; delnode(t); return;
        }
        int p = pre(); t = root;
        splay(p); f[ch[t][1]] = root; ch[root][1] = ch[t][1]; delnode(t);
        pushup(root);
    }
    int rank(int v) {
        int ans = 0, x = root;
        while (true) {
            if (!x) return -1;
            if (v < key[x]) x = ch[x][0];
            else {
                ans += (ch[x][0]?size[ch[x][0]]:0);
                if (v == key[x]) { splay(x); return ans+1; }
                ans += cnt[x]; x = ch[x][1];
            }
        }
    }
    int kth(int v) {
        int x = root;
        while (true) {
            if (!x) return -1;
            if (ch[x][0] && v <= size[ch[x][0]]) x = ch[x][0];
            else {
                int t = (ch[x][0]?size[ch[x][0]]:0)+cnt[x];
                if (v <= t) { splay(x); return x; }
                v -= t; x = ch[x][1];
            }
        }
    }
    int pre(){
        int x=ch[root][0];
        while (ch[x][1]) x=ch[x][1];
        return x;
    }
    int next(){
        int x=ch[root][1];
        while (ch[x][0]) x=ch[x][0];
        return x;
    }
    long long solve(int v) {
        int x = kth(v);
        long long t1 = ch[x][0] ? key[x] * size[ch[x][0]] - sum[ch[x][0]] : 0;
        long long t2 = ch[x][1] ? sum[ch[x][1]] - key[x] * size[ch[x][1]] : 0;
        return t1 + t2;
    }
}ac;
int n, x, h[maxn];
int main() {
    //freopen("1.in", "r", stdin);
    scanf("%d%d", &n, &x);
    for (int i = 1; i <= n; i++) scanf("%d", &h[i]);
    long long ans = INF;
    ac.clear();
    for (int i = 1; i <= x; i++) ac.insert(h[i]);
    if (x & 1) ans = min(ans, ac.solve(x/2+1));
    else ans = min(ans, min(ac.solve(x/2), ac.solve(x/2+1)));
    for (int i = 2; i + x - 1 <= n; i++) {
        ac.delet(h[i-1]);
        ac.insert(h[i+x-1]);
        if (x & 1) ans = min(ans, ac.solve(x/2+1));
        else ans = min(ans, min(ac.solve(x/2), ac.solve(x/2+1)));
    }
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/hao_zong_yin/article/details/80116246