Educational Codeforces Round 56 (Rated for Div. 2) E. Intersection of Permutations(分块 + 树状数组)

题目链接:https://codeforces.com/contest/1093/problem/E

题目大意:给出两个1~n的排列 a 和 b;对这两个排列进行如下两种操作:

1 la ra lb rb:查询排列 a 的区间 [la,ra] 与排列 b 的区间 [lb,rb]内有多少个相同的数;

2 x y:将b[x] 与 b[y]的值进行交换。

题目思路:·由于更新操作中只会对排列 b 进行修改,所以我们可以用一个pos数组来记录排列 a 中各个数的位置。

这样我们就可以知道排列 b 中的各个数在 a 中所对应的位置是哪里。

这样我们就可以用一个二维树状数组 bit[i][j] 来维护,bit[i][j] 表示在b[1] ~ b[i] 中与 a[1]~a[j]中有多少个相同的数。

一开始的时候可以就是对于每一个 b[i] 将 bit[i][pos[b[i]] ~ bit[n][n]加1,代表在这个区间内是有一个相同的数。

查询的时候就是 sum(ra,rb) - sum(la-1,rb) - sum(ra,lb-1) + sum(la-1,lb-1)。(sum(x,y) 就是一个正常的二维树状数组查询)

但现在由于n 最大可以达到 2e5,无法直接开一个bit[2e5][2e5]的数组,所以我们就考虑分块。

将b排列 分为sqrt(n)个部分,前面的更新是一个点一个点的更新,现在更新是一个部分一个部分的更新,更新一个点的时候也就是更新相应的块的内容。

这样就可以将数组的大小缩减为:bit[sqrt(2e5)][2e5],这就是一个可接受的范围了。

在查询的时候,就查询其所在块的情况就可以了,但还有一点要注意的就是,查询区间右端点所在块的信息的时候,要将这一块单独取出来查询计算(因为并不是将这一整块的信息查询,只需要到右端点部分的信息)。

分块之后的时间复杂度就大概是O(m*log(sqrt(n))*log(n))。6s的时间就绰绰有余了。

具体实现看代码:

#include <bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define MP make_pair
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define clr(a) memset(a,0,sizeof(a))
#define _inf(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define fuck(x) cout<<"["<<#x<<" " << (x) << "]"<<endl
using namespace std;
typedef long long ll;
typedef pair<int, int>pii;
const int MX = 2e5 + 5;
const int inf = 0x3f3f3f3f;

int n, m;
int bsz, nsz, pos[MX], b[MX];
int block[MX];
int bit[500][MX];
inline int lowbit(int x) {return x & -x;}

inline void upd(int x, int y, int d) {
    int tmp = block[y] + 1;
    for (int i = tmp; i <= nsz; i += lowbit(i)) {
        for (int j = x; j <= n; j += lowbit(j))
            bit[i][j] += d;
    }
}
inline int ask(int x, int y) {
    int res = 0;
    int tmp = block[y];
    for (int i = max(tmp * bsz, 1); i <= y; i++) if (b[i] <= x) res++;
    for (int i = tmp; i; i -= lowbit(i)) {
        for (int j = x; j; j -= lowbit(j))
            res += bit[i][j];
    }
    return res;
}

int main() {
    // FIN;
    scanf("%d%d", &n, &m);
    bsz = (int)sqrt(n);
    nsz = n / bsz;
    for (int i = 1; i <= n; i++) {
        int x; scanf("%d", &x);
        pos[x] = i;
        block[i] = i / bsz;
    }
    for (int i = 1; i <= n; i++) {
        scanf("%d", &b[i]);
        b[i] = pos[b[i]];
        upd(b[i], i, 1);
    }
    int op, la, ra, lb, rb;
    while (m--) {
        scanf("%d", &op);
        if (op == 1) {
            scanf("%d%d%d%d", &la, &ra, &lb, &rb);
            int ans = ask(ra, rb) - ask(la - 1, rb) - ask(ra, lb - 1) + ask(la - 1, lb - 1);
            printf("%d\n", ans);
        } else {
            scanf("%d%d", &lb, &rb);
            upd(b[lb], lb, -1); upd(b[rb], rb, -1);
            swap(b[lb], b[rb]);
            upd(b[lb], lb, 1); upd(b[rb], rb, 1);
        }
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/Lee_w_j__/article/details/85092100