【NOIP 2017 提高组】列队

题目

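有一个 \(n\times m\) 的方阵,第 \(i\) 行第 \(j\) 列的人初始编号为 \((i-1)\times m+j\)。每次有一个位于 \((x,y)\) 的人离队,之后队伍先向左看齐(第 \(x\) 行中该位置右侧的人依次左移一格),再向前看齐(最后一列中第 \(x\) 行以下的人依次前移一格),离队的人回到第 \(n\) 行第 \(m\) 列的空位。询问每次离队的人的编号。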

\(n, m, q\le 3\times 10^5\)(\(q\) 为离队次数)

分析

我们考虑离队本质上只有两种操作:

  • 删除
  • 放入末尾

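发现这显然可以用平衡树处理,这里我选用 Splay:对每一行的前 \(m-1\) 列各维护一棵 Splay,再对最后一列单独维护一棵 Splay。离队时,先从对应行的树中删除该人并把他放入最后一列那棵树的末尾,再把最后一列中第 \(x\) 个人删除并放入该行那棵树的末尾(若离队的人本身就在最后一列,则只需在最后一列的树中删除后放入末尾)。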

需要注意的是,空间不可能开到 \(9\times 10^{10}\),只能动态开点,发现队列总是有大部分是连续的编号,所以每个节点储存一个标号范围,若需要访问到中间的标号 \(k\),将该节点分裂成三个节点\([l,k - 1],\ [k, k],\ [k + 1, r]\)即可。

时间复杂度为均摊 \(O((n+q)\log n)\);每次操作只会新建常数个节点,因此空间复杂度为 \(O(n+q)\)。可以通过。

代码

#include <bits/stdc++.h>

using namespace std;

// 64-bit integer alias used for labels, sizes and indices throughout this file.
typedef long long ll;

// Number of slots in the per-row tree array declared later in this file.
const int MAXN = 3e5 + 5;

// Grid dimensions (n rows, m columns); main() reads them from the input.
ll n, m;

struct node {
    ll left, right, size;
    node *child[2], *father;
    
    void updata() {size = child[0]->size + child[1]->size + right - left + 1;}
} *nul = new node;

// Set up the shared sentinel `nul`: every link points back to itself and it
// holds the empty run [1, 0] with size 0, so it adds nothing to any size sum.
void init() {
    nul->father = nul;
    nul->child[0] = nul;
    nul->child[1] = nul;
    nul->left = 1;
    nul->right = 0;
    nul->size = 0;
}

// Allocate a leaf node holding the label run [l, r] with parent `fa`.
// Both children start as the sentinel; size is the run length r - l + 1.
// Only the new node's father pointer is set: the caller must store the
// returned pointer in the appropriate child slot of `fa` itself.
node *newNode(ll l, ll r, node *fa) {
    node *fresh = new node;
    fresh->left = l;
    fresh->right = r;
    fresh->size = r - l + 1;
    fresh->child[0] = nul;
    fresh->child[1] = nul;
    fresh->father = fa;
    return fresh;
}

// Connect x underneath y in child slot p (0 or 1), writing both directions
// of the link: y's child pointer and x's father pointer.
void cect(node *x, node *y, ll p) {
    y->child[p] = x;
    x->father = y;
}
// Report which slot x occupies in its parent: true when x is child[1],
// false otherwise.
bool getPos(node *x) {
    node *parent = x->father;
    return x == parent->child[1];
}

// Single rotation: lift x one level so that it takes its parent's place
// under the grandparent, and the parent becomes x's child on the opposite
// side. Sizes are recomputed bottom-up (old parent first, then x).
void rotate(node *x) {
    node *parent = x->father;
    node *grand = parent->father;
    const ll side = getPos(x);            // slot of x inside parent
    const ll parentSide = getPos(parent); // slot of parent inside grand
    node *inner = x->child[side ^ 1];     // subtree that changes owner

    cect(inner, parent, side);
    cect(parent, x, side ^ 1);
    cect(x, grand, parentSide);

    parent->updata();
    x->updata();
}

// Shrink x so that it holds exactly the single label k (which must lie in
// x's current run). Labels before k move into a new node inserted as x's
// child[0], labels after k into a new node inserted as x's child[1]; x's
// previous children are re-hung below those new nodes so the in-order
// sequence is unchanged. x->size is left as is because its subtree still
// contains the same labels.
void split(node *x, ll k) {
    node *oldLeft = x->child[0];
    node *oldRight = x->child[1];

    if (k > x->left) {
        node *prefix = newNode(x->left, k - 1, x);
        x->child[0] = prefix;
        if (oldLeft != nul) {
            cect(oldLeft, prefix, 0);
        }
        prefix->updata();
    }

    if (k < x->right) {
        node *suffix = newNode(k + 1, x->right, x);
        x->child[1] = suffix;
        if (oldRight != nul) {
            cect(oldRight, suffix, 1);
        }
        suffix->updata();
    }

    x->right = k;
    x->left = k;
}

struct splayTree {
    node *rt;

    splayTree() {rt = newNode(1, 0, nul);}

    void build(node *&cur, node *fa, ll l, ll r) {
        if(l <= r) {
            ll mid = (l + r) >> 1;
            cur = newNode(mid * m, mid * m, fa);
            build(cur->child[0], cur, l, mid - 1);
            build(cur->child[1], cur, mid + 1, r);
            cur->updata();
        }
    }

    void splay(node *cur, node *goal) {
        while(cur->father != goal) {
            node *fa = cur->father, *gfa = cur->father->father;
            if(gfa != goal) getPos(cur) == getPos(fa) ? rotate(fa) : rotate(cur);
            rotate(cur);
        }
    }

    node *findKth(ll k) {
        node *cur = rt->child[1];
        while(true) {
            if(k <= cur->child[0]->size) cur = cur->child[0];
            else if(k > cur->child[0]->size + cur->right - cur->left + 1) {
                k = k - cur->child[0]->size - (cur->right - cur->left + 1);
                cur = cur->child[1];
            } else {
                split(cur, cur->left + k - cur->child[0]->size - 1);
                splay(cur, rt);
                return cur;
            }
        }
    }

    ll earse(node *goal) {
        node *rplNode = goal->child[0];
        while(rplNode->child[1] != nul) rplNode = rplNode->child[1];
        if(rplNode != nul) {
            splay(rplNode, goal);
            cect(rplNode, rt, 1);
            cect(goal->child[1], rplNode, 1);
            rplNode->updata();
        } else cect(goal->child[1], rt, 1);
        ll id = goal->left;
        delete goal;
        return id;
    }

    void insLast(ll k) {
        node *cur = rt;
        while(cur->child[1] != nul) cur = cur->child[1];
        if(cur != rt) splay(cur, rt);
        cur->child[1] = newNode(k, k, cur);
        cur->size++;
    }
} row[MAXN], lastLine;

// Process one departure at row x, column y and return the label of the
// person who left.
//  - y == m: the person is taken out of lastLine at position x and appended
//    to the end of lastLine.
//  - otherwise: the person is taken out of row[x] at position y and appended
//    to lastLine; then the x-th entry of lastLine is moved to the end of
//    row[x].
ll leave(ll x, ll y) {
    if (y == m) {
        const ll id = lastLine.earse(lastLine.findKth(x));
        lastLine.insLast(id);
        return id;
    }

    const ll id = row[x].earse(row[x].findKth(y));
    lastLine.insLast(id);

    const ll shifted = lastLine.earse(lastLine.findKth(x));
    row[x].insLast(shifted);
    return id;
}

int main() {
    init();
    ll q, x, y;
    scanf("%lld%lld%lld", &n, &m, &q);
    lastLine.build(lastLine.rt->child[1], lastLine.rt, 1, n);
    for(ll i = 1; i <= n; i++)
        (row[i].rt)->child[1] = newNode((i - 1) * m + 1, i * m - 1, row[i].rt);
    while(q--) {
        scanf("%lld%lld", &x, &y);
        printf("%lld\n", leave(x, y));
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/zhylj/p/9879595.html
今日推荐