BZOJ4071 [APIO2015]巴邻旁之桥

Address

Solution

  • 先把住宅和办公室在同侧的居民处理掉。
  • 注意到 k 2 ,进行分类讨论:

k = 1

  • 记选择桥的位置为 x ,则答案 = n + i = 1 n ( | s i x | + | t i x | )
  • 很容易发现, x 就是把所有 s i , t i 取出来排序后的中位数。
  • 于是就可以直接做了,时间复杂度 O ( n log n )

k = 2

  • 先介绍一个性质:

    当有多座桥存在时,对于第 i 个居民,选择位置距离 s i + t i 2 最近的桥最优。

  • 这与上面选择中位数最优是一个道理。

  • 因此把 n 个居民按照 s i + t i 2 从小到大排序。
  • 可知必定存在一个分割点,使得该位置左边的居民都会走第一座桥,该位置右边的居民都会走第二座桥。
  • 于是我们枚举这个分割点,现在只需做到每次能够快速计算出分割点两边的距离之和。
  • 距离之和与中位数有关,因此可以用 S p l a y 来维护。
  • 建立两棵 S p l a y ,每次把作为分割点的居民 i s i , t i 从一棵 S p l a y 中删除,再加入另一棵 S p l a y
  • S p l a y 只需维护子树大小和权值和,时间复杂度 O ( n log n )

  • 其实用权值线段树写常数就小很多了。

Code

  • 注意 l o n g   l o n g S p l a y 为空的情况。
  • k = 1 的情况也偷懒用现成的 Splay 写了。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cctype>

using namespace std;

namespace inout
{
    const int S = 1 << 20;
    char frd[S], *ihed = frd + S;
    const char *ital = ihed;

    inline char inChar()
    {
        if (ihed == ital)
            fread(frd, 1, S, stdin), ihed = frd;
        return *ihed++;
    }

    inline int get()
    {
        char ch; int res = 0; bool flag = false;
        while (!isdigit(ch = inChar()) && ch != '-');
        (ch == '-' ? flag = true : res = ch ^ 48);
        while (isdigit(ch = inChar()))
            res = res * 10 + ch - 48;
        return flag ? -res : res; 
    }

    inline char getChar()
    {
        char ch;
        while (ch = inChar(), ch != 'A' && ch != 'B');
        return ch;
    }
};
using namespace inout;

typedef long long ll;
const int N = 2e5 + 5;

struct splay_tree
{
    int lc[N], rc[N], fa[N], sze[N], val[N];
    int rt, T; ll sum[N];

    inline void Clear()
    {
        rt = T = 0;
        memset(lc, 0, sizeof(lc));
        memset(rc, 0, sizeof(rc));
        memset(fa, 0, sizeof(fa));
        memset(sze, 0, sizeof(sze)); 
    }
    inline void Uptdate(int x)
    {
        sum[x] = sum[lc[x]] + sum[rc[x]] + val[x];
        sze[x] = sze[lc[x]] + sze[rc[x]] + 1;
    }

    inline bool Which(int x)
    {
        return lc[fa[x]] == x;
    }

    inline void Rotate(int x)
    {
        int y = fa[x], z = fa[y];
        int b = lc[y] == x ? rc[x] : lc[x];
        fa[y] = x; fa[x] = z;
        if (b) fa[b] = y;
        if (z) (lc[z] == y ? lc[z] : rc[z]) = x;
        if (lc[y] == x) rc[x] = y, lc[y] = b;
            else lc[x] = y, rc[y] = b;
        Uptdate(y);
    }

    inline void Splay(int x, int tar)
    {
        while (fa[x] != tar)
        {
            if (fa[fa[x]] != tar)
                Which(fa[x]) == Which(x) ? Rotate(fa[x]) : Rotate(x);
            Rotate(x);
        }
        Uptdate(x);
        if (!tar) rt = x;
    }

    inline int getKth(int k)
    {
        int x = rt;
        while (x)
        {
            if (k <= sze[lc[x]])
                x = lc[x];
            else 
            {
                k -= sze[lc[x]] + 1;
                if (!k) return x;
                x = rc[x];
            } 
        }
    }

    inline void Insert(int v)
    {
        int x = rt, y = 0, dir;
        while (x)
        {
            ++sze[y = x]; sum[y] += v;
            if (v <= val[x]) x = lc[x], dir = 0;
                else x = rc[x], dir = 1;
        }
        fa[x = ++T] = y; 
        sum[x] = val[x] = v; sze[x] = 1;
        if (y) (dir == 0 ? lc[y] : rc[y]) = x;
        Splay(x, 0);
    }

    inline void Join(int x, int y)
    {
        lc[fa[x]] = rc[fa[y]] = 0; 
        int w = x;
        while (rc[w]) w = rc[w];
        fa[y] = w; rc[w] = y; fa[rt = x] = 0;
        Splay(w, 0); 
    }

    inline int Find(int v)
    {
        int x = rt;
        while (x)
        {
            if (v == val[x]) return x;
            if (v < val[x]) x = lc[x];
                else x = rc[x];
        }
        return x;
    }

    inline void Delete(int v)
    {
        int x = Find(v); Splay(x, 0);
        if (!lc[x] || !rc[x])
        { 
            fa[rt = lc[x] + rc[x]] = 0;
            lc[x] = rc[x] = 0;
        } 
        else 
            Join(lc[x], rc[x]);
    }

    inline ll Query()
    {
        int x = getKth(sze[rt] >> 1);
        Splay(x, 0);
        return (ll)sze[lc[x]] * val[x] - sum[lc[x]]
             + sum[rc[x]] - (ll)sze[rc[x]] * val[x];
    }
}a, b;

int K, n, m; ll Ans;

struct point
{
    int l, r, d;

    point() {}
    point(int L, int R, int D):
        l(L), r(R), d(D) {}

    inline bool operator < (const point &x) const 
    {
        return d < x.d; 
    }
}p[N];

inline ll Abs(ll x) {return x < 0 ? -x : x;}
inline void CkMin(ll &x, ll y) {if (x > y) x = y;}

int main()
{
    K = get(); n = get(); a.Clear(); b.Clear();
    char cl, cr; int dl, dr;
    for (int i = 1; i <= n; ++i)
    {
        cl = getChar(); dl = get();
        cr = getChar(); dr = get();
        if (cl == cr) Ans += Abs(dl - dr);
            else p[++m] = point(dl, dr, dl + dr), ++Ans;
    }
    sort(p + 1, p + m + 1);
    for (int i = 1; i <= m; ++i)
        b.Insert(p[i].l), b.Insert(p[i].r);

    if (K == 1)
    {
        if (m > 0) Ans += b.Query();
        cout << Ans << endl;
    } 
    else 
    {
        if (m > 0)
        {
            ll res = b.Query(); 
            for (int i = 1; i < m; ++i)
            {
                b.Delete(p[i].l); b.Delete(p[i].r);
                a.Insert(p[i].l); a.Insert(p[i].r);
                CkMin(res, a.Query() + b.Query());
            }
            Ans += res;
        }
        cout << Ans << endl;
    }

    return 0;
}

猜你喜欢

转载自blog.csdn.net/bzjr_log_x/article/details/79837146