[luogu3157][bzoj3295][CQOI2011]动态逆序对【cdq分治+树状数组】

题目描述

对于序列A,它的逆序对数定义为满足i<j,且Ai>Aj的数对(i,j)的个数。给1到n的一个排列,按照某种顺序依次删除m个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。

分析

关于cdq分治第一篇学习笔记可以戳一下右边:【传送门】
简单的cdq分治,我不会树套树,所以就用cdq分治来做一下。
很明显的是,有答案贡献的都是\(time[i]<time[j]\)\(val[i]<val[j]\)\(pos[i]>pos[j]\)
以及\(time[i]<time[j]\)以及\(val[i]>val[j]\)\(pos[i]<pos[j]\),那么三维偏序就可以解决了。

ac代码

#include <bits/stdc++.h>
#define ll long long
#define N 100005
using namespace std;
struct BIT{
    #define lowbit(x) (x&-x)
    int n, tr[N];
    void add(int x, int val) {
        for (; x <= n; x += lowbit(x)) tr[x] += val;
    }
    int query(int x) {
        int res = 0;
        for (; x; x -= lowbit(x)) res += tr[x];
        return res;
    }
}tr;
struct Que {
    int cnt, v, d, id, t;
}q[N << 1];
int n, m;
ll ans[N];
int a[N], pos[N];
bool cmp(const Que &a, const Que &b) {
    return a.d < b.d;
}
void cdq(int l, int r) {
    if (l == r) return;
    int mid = (l + r) >> 1;
    cdq(l, mid); 
    cdq(mid + 1, r);
    sort(q + l, q + mid + 1, cmp);
    sort(q + mid + 1, q + 1 + r, cmp);
    int l1 = l, l2 = mid + 1;
    while (l2 <= r) {
        while (l1 <= mid && q[l1].d <= q[l2].d) tr.add(q[l1].v, q[l1].cnt), ++ l1;
        ans[q[l2].id] += q[l2].cnt * (tr.query(n) - tr.query(q[l2].v));
        l2 ++;
    }
    for (int i = l; i < l1; i ++) tr.add(q[i].v, -q[i].cnt);
    l1 = r; l2 = mid;
    while (l1 > mid) {
        while (l2 >= l && q[l2].d >= q[l1].d) tr.add(q[l2].v, q[l2].cnt), -- l2;
        ans[q[l1].id] += q[l1].cnt * tr.query(q[l1].v - 1);
        l1 --;
    } 
    for (int i = mid; i > l2; i --) tr.add(q[i].v, -q[i].cnt);
}
int main() {
    scanf("%d%d", &n, &m);
    int tot = 0;
    tr.n = n;
    for (int i = 1; i <= n; i ++) {
        scanf("%d", &a[i]);
        pos[a[i]] = i;
        q[++ tot] = (Que){1, a[i], i, 0, tot};
    }
    for (int i = 1; i <= m; i ++) {
        int x; scanf("%d", &x);
        q[++ tot] = (Que){-1, x, pos[x], i, tot};
    }
    cdq(1, tot);
    for (int i = 1; i <= m; i ++) ans[i] += ans[i - 1];
    for (int i = 0; i < m; i ++) printf("%lld\n", ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/chhokmah/p/10575413.html