UVA 11990 `Dynamic'' Inversion CDQ分治, 归并排序, 树状数组, 尺取法, 三偏序统计 难度: 2

题目

https://uva.onlinejudge.org/index.php?option=com_onlinejudge&Itemid=8&page=show_problem&problem=3141


题意

一个1到n的排列,每次随机删除一个,问删除前的逆序数

思路

综合考虑,对每个数点,令value为值,pos为位置,time为出现时间(总时间-消失时间),明显是统计value1 > value2, pos1 < pos2, time1 < time2的个数

首先对其中一个轴排序,比如value,这样在归并过程中,左子树的value总是小于右子树的,可以分治。

当左右子树包含哪些数点已经确定后,可以用自下而上的归并排序使得子树上的数点按照第二维相对有序,方便用尺取法统计子树之间的逆序数。

第三维通过树状数组进行压缩,加快统计速度。

注意仅仅统计左子树对右子树的影响,就会错过右子树中的数点出现的比较晚的情况。因此需要统计右子树对左子树的影响,此时注意别把同一时间出现的重复计数。

感想

1. 注意long long!!!

2. BIT的上限要>=n!

3. 注意统计影响完成后需要清空树状数组(区间大小已经减少了所以可以浪费地使用),此时不能直接用memset清空整个数组,时间会成为O(n2),超时。

代码

时间: 0.250s 

时间复杂度O(cnlogn)

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <queue>
#include <tuple>
#include <cassert>

using namespace std;

const int MAXN = int(4e5 + 4);

#define LEFT_CHILD(x) ((x) << 1)
#define RIGHT_CHILD(x) (((x) << 1) + 1)
#define FATHER(x) ((x) >> 1)
#define IS_LEFT_CHILD(x) (((x) & 1) == 0)
#define IS_RIGHT_CHILD(x) (((x) & 1) == 1)
#define BROTHER(x) ((x) ^ 1)
#define LOWBIT(x) ((x) & (-x))

#define LOCAL_DEBUG 

struct Node{
    int value, pos, time;
}nodes[MAXN], tmpNodes[MAXN];

int timeCnt[MAXN * 4];
long long revNum[MAXN];
int clearStack[MAXN];
int clearLen;
int n, m;
int bitLimit;

int getHigherBit(int n) {
    int x = 1;
    while (x < n) { x <<= 1; }
    return x;
}

void update(int id) {
    while (id <= bitLimit) {
        if (timeCnt[id] == 0) {
            clearStack[clearLen++] = id;
        }
        timeCnt[id]++;
        id += LOWBIT(id);
    }
}

void clearCnt() {
    while (clearLen > 0) {
        timeCnt[clearStack[--clearLen]] = 0;
    }
}

int countTimesSmaller(int id) {
    if (id < 0)return 0;
    int sum = 0;
    int tmp = 0;
    while (id > 0) {
        sum += timeCnt[id];
        id -= LOWBIT(id);
    }
    return sum;
}

void merge_by_pos(int root_ind, int internal_l, int internal_r) {
    int internal_mid = (internal_l + internal_r) >> 1;
    for (int i = internal_l; i <= internal_r; i++) {
        tmpNodes[i] = nodes[i];
    }
    for (int i = internal_l, j = internal_mid + 1, ind = internal_l; ind <= internal_r; ) {
        if (i > internal_mid) {
            nodes[ind++] = tmpNodes[j++];
        }
        else if (j > internal_r) {
            nodes[ind++] = tmpNodes[i++];
        }
        else if (tmpNodes[i].pos < tmpNodes[j].pos) {
            nodes[ind++] = tmpNodes[i++];
        }
        else {
            nodes[ind++] = tmpNodes[j++];
        }
    }
}
void cal(int root_ind, int internal_l, int internal_r) {
    if (internal_l == internal_r)return;
    int internal_mid = (internal_l + internal_r) >> 1;
    if(internal_l != internal_mid)cal(LEFT_CHILD(root_ind), internal_l, internal_mid);
    if (internal_mid + 1 != internal_r)cal(RIGHT_CHILD(root_ind), internal_mid + 1, internal_r);
//    printf("L Node: %d[%d, %d] LC: %d[%d, %d], RC: %d[%d, %d]\n", root_ind, internal_l, internal_r, LEFT_CHILD(root_ind), internal_l, internal_mid, RIGHT_CHILD(root_ind), internal_mid + 1, internal_r);
    for (int i = internal_l, j = internal_mid + 1; i <= internal_mid; i++) {
        while (j <= internal_r && nodes[i].pos > nodes[j].pos) {
            update(nodes[j].time);
            j++;
        }
        revNum[nodes[i].time] += countTimesSmaller(nodes[i].time);
//        printf("L (%d, %d, %d): +%d\n", nodes[i].value, nodes[i].pos, nodes[i].time, countTimesSmaller(nodes[i].time));
    }
    clearCnt();

    for (int i = internal_mid, j = internal_r; j > internal_mid; j--) {
        while (i >= internal_l && nodes[i].pos > nodes[j].pos) {
            update(nodes[i].time);
            i--;
        }
        revNum[nodes[j].time] += countTimesSmaller(nodes[j].time - 1);
//        printf("R (%d, %d, %d): +%d\n", nodes[j].value, nodes[j].pos, nodes[j].time, countTimesSmaller(nodes[j].time - 1));
    }
    clearCnt();
    merge_by_pos(root_ind, internal_l, internal_r);

}

int main() {
#ifdef LOCAL_DEBUG
    freopen("input.txt", "r", stdin);
    freopen("output2.txt", "w", stdout);
#endif // LOCAL_DEBUG
    for (int ti = 1; scanf("%d%d", &n, &m) == 2; ti++) {
        bitLimit = getHigherBit(n);
        for (int i = 1; i <= n; i++) {
            int tmp;
            scanf("%d", &tmp);
            nodes[tmp].value = tmp;
            nodes[tmp].pos = i;
            nodes[tmp].time = 1;
        }
        for (int i = 1; i <= m + 1; i++) { revNum[i] = 0; }
        for (int i = 0; i < m; i++) {
            int tmp;
            scanf("%d", &tmp);
            nodes[tmp].time = m - i + 1;
        }
        cal(1, 1, n);
        long long ans = 0;
        for (int i = 1; i <= m + 1; i++) { ans += revNum[i]; }
        for (int i = 0; i < m; i++) {
            printf("%lld\n", ans);
            ans -= revNum[m - i + 1];
        }
    }
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/xuesu/p/10356068.html