[国家集训队] 矩阵乘法

–> 洛谷 P1527 <–

这是一道整体二分的经典题目。

这道题显然可以给每个询问二分答案,统计该询问矩阵中小于等于mid的元素个数。如果大于等于k,说明猜大了,否则说明猜小了。

如果用这种方法的话,对于每个询问都至少要用O(询问矩阵大小*log值域)的时间复杂度解决,多组询问的话时间不能接受。

发现多个询问的二分答案是可以同时被检验的,我们可以为所有询问同时二分答案,把所有答案小于等于mid的询问放在询问序列的左侧,大于mid的放到询问序列的右侧然后递归处理。

这样为什么会快呢?我们每次可以用把矩阵中小于等于mid的元素染成黑色,剩下的元素保持白色。这样对于每一个询问的检验,就相当于是统计某个子矩阵中黑点的个数。为什么不用二维树状数组维护前缀和呢?

最开始的时候整个矩阵都设为0,然后让所有小于等于mid的位置加1, O ( N 2 log 2 N ) 处理完整个矩阵之后再 O ( l o g 2 N ) 去应付每一个询问。被询问的子矩阵中可能有很多重叠的部分,这样保证了在每次询问过程中,矩阵中的每个元素只对运行时间做一次贡献,所以这个算法要比对于每个询问单独二分要快得多。

这样做的时间复杂度为什么是对的呢?我们是在二分值域,考虑二分过程中的每一层对答案的贡献。

  1. 对于每一层二分,矩阵中的每个元素最多被加入树状数组一次。

  2. 对于每一层二分,每个询问只会被处理一次。

  3. 二分值域的过程中最多只会出现 O ( log N ) 层。

时间复杂度 O ( ( N 2 + Q ) log 3 N ) ,感觉这时间复杂度不开O2过不了啊,想要不开O2 AC可能得大力卡常一发。

发表一下个人见解:感觉CDQ分治和整体二分有异曲同工之妙,是同一种分治思想在不同维度发挥作用的体现。(之所以这么说是因为代码写起来是非常像的。)

还有一点,不要真的去二分值域,不然还得需要离散化,常数更大。把矩阵中的所有元素按照权值排序,二分在答案在排序之后的数组中的位置即可。(因为答案一定是矩阵中存在的值,所以在这些值里面二分就好了。)

给出代码:

// luogu-judger-enable-o2
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
using namespace std;

inline int geti() {
    int ans = 0; bool flag = 0; char c = getchar();
    while(!isdigit(c)) flag |= c == '-', c = getchar();
    while( isdigit(c)) ans=ans*10+c-'0', c = getchar();
    return flag? -ans: ans;
}

inline void puti(int x) {
    if(x < 0) x=-x, putchar('-');
    if(x > 9) puti(x / 10);
    putchar('0' + x%10);
}

const int maxn = 500 + 10, maxq = 60000 + 5;

struct Numbers {
    int x, y, v; /// (x,y) 位置的值为 v 
    bool operator < (const Numbers& rhs) const {
        return v < rhs.v; /// 用于比较大小 
    }
} Matrix[maxn * maxn]; int mcnt; /// 矩阵中的元素个数

struct BIT {
    int n; /// n * n 的二维树状数组
    int C[maxn][maxn];
    BIT(int N = 0) {n = N; memset(C, 0x00, sizeof(C));} /// 初始化
    inline int lowbit(int x) {return x&(-x);}
    inline void add(int x, int y, int v) { /// 单点修改 
        for(register int i = x; i <= n; i += lowbit(i))
            for(register int j = y; j <= n; j += lowbit(j))
                C[i][j] += v;
    }
    inline int pre(int x, int y) { /// 前缀求和 
        int ans = 0;
        for(register int i = x; i > 0; i -= lowbit(i))
            for(register int j = y; j > 0; j -= lowbit(j))
                ans += C[i][j];
        return ans;
    }
    inline int submat(int x1, int y1, int x2, int y2) { /// 子矩阵和 
        int ans = pre(x2, y2);
        ans -= pre(x1 - 1, y2) + pre(x2, y1 - 1);
        ans += pre(x1 - 1, y1 - 1);
        return ans;
    }
} bit; /// 记得初始化 n

struct Events { /// 记录所有询问 
    int x1, y1, x2, y2, k;
    inline void input() { /// 输入一个询问 
        x1 = geti(); y1 = geti();
        x2 = geti(); y2 = geti();
        k  = geti();
    }
} Querys[maxq];

int bcount(Events mat) { /// 查询某个询问中的黑点个数 
    return bit.submat(mat.x1, mat.y1, mat.x2, mat.y2);
}

int id[maxq], t1[maxq], t2[maxq];
int ans[maxq], cur[maxq];

void Sol(int l, int r, int ql, int qr) {
    if(qr < ql) return; /// 该值域区间没有询问
    if(l == r) {
        for(int i = ql; i <= qr; i ++) ans[id[i]] = Matrix[l].v;
        return; /// 找到解 
    }
    int mid = (l + r)/2;
    for(int i = l; i <= mid; i ++) /// 把要统计到答案中的数值染黑 
        bit.add(Matrix[i].x, Matrix[i].y, 1);
    int cnt1 = 0, cnt2 = 0;
    for(int i = ql; i <= qr; i ++) {
        int u = id[i]; /// 当前要处理的询问
        int s = cur[u] + bcount(Querys[u]); /// 考虑当前区间中的黑点个数
        if(s >= Querys[u].k) t1[++ cnt1] = u;
        else t2[++ cnt2] = u, cur[u] = s; 
    }
    int qcnt = ql - 1;
    for(int i = 1; i <= cnt1; i ++) id[++ qcnt] = t1[i]; /// 左右分组 
    for(int i = 1; i <= cnt2; i ++) id[++ qcnt] = t2[i];
    for(int i = l; i <= mid; i ++) /// 谁污染谁治理 
        bit.add(Matrix[i].x, Matrix[i].y, -1);
    Sol(l, mid, ql, ql + cnt1 - 1);
    Sol(mid+1, r, ql + cnt1, qr);
}

int main() {
    int N = geti(), Q = geti(); bit.n = N;/// 输入矩阵大小和询问组数
    for(int i = 1; i <= N; i ++)
        for(int j = 1; j <= N; j ++)
            Matrix[++ mcnt] = (Numbers){i, j, geti()};
    sort(Matrix + 1, Matrix + mcnt + 1); /// 按元素大小从小到大排序
    for(int i = 1; i <= Q; i ++) Querys[i].input();
    for(int i = 1; i <= Q; i ++) id[i] = i;
    Sol(1, mcnt, 1, Q);
    for(int i = 1; i <= Q; i ++) puti(ans[i]), putchar('\n');
    return 0;
}

猜你喜欢

转载自blog.csdn.net/ggn_2015/article/details/79766792