[GXOI/GZOI2019]与或和 题解

传送门

题意:给出一个 n × n n\times n 的矩阵,求所有子矩阵的 and \text{and} 值之和和 or \text{or} 值之和。

显然可以把每一位分开来求,那么对于某一位而言矩阵中的元素不是 0 0 就是 1 1 。一个仅由 0 0 1 1 构成的子矩阵的 and \text{and} 值是 1 1 当且仅当这个子矩阵中全是 1 1 ,否则是 0 0 ;类似地, or \text{or} 值是 0 0 当且仅当这个子矩阵中全是 0 0 ,否则是 1 1

所以问题转化为求一个 0 1 0-1 矩阵有多少个仅由 x x 构成的子矩阵,其中 x { 0 , 1 } x\in\{0,1\} 原题传送

x = 1 x=1 为例:首先求出 r ( i , j ) r(i,j) 表示第 i i 行第 j j 列这个位置开始往右最长连续 1 1 的个数,这个递推非常简单。然后考虑 f ( i , j ) f(i,j) 表示以 ( i , j ) (i,j) 为左下角的全 1 1 子矩阵的个数。我们从右往左一列一列地做。对于第 i i 列第 j j 个位置,应该有 f ( j , i ) = f ( p , i ) + ( j p ) × r ( j , i ) f(j,i)=f(p,i)+(j-p)\times r(j,i) ,其中 p p 是使得 r ( p , i ) < r ( j , i ) r(p,i)<r(j,i) 的最大的下标。这个 p p 显然可以用单调栈维护。所以这里的复杂度 O ( n 2 ) O(n^2) 。另外 f f 数组可以去掉一维,因为每一列都是独立的。

总复杂度 O ( n 2 log max { a i j } ) O(n^2\log\max\{a_{ij}\})

#include <cctype>
#include <cstdio>
#include <climits>
#include <algorithm>

template <typename T> inline void read(T& x) {
    int f = 0, c = getchar(); x = 0;
    while (!isdigit(c)) f |= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - 48, c = getchar();
    if (f) x = -x;
}
template <typename T, typename... Args>
inline void read(T& x, Args&... args) {
    read(x); read(args...); 
}
template <typename T> void write(T x) {
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + 48);
}
template <typename T> inline void writeln(T x) { write(x); puts(""); }
template <typename T> inline bool chkmin(T& x, const T& y) { return y < x ? (x = y, true) : false; }
template <typename T> inline bool chkmax(T& x, const T& y) { return x < y ? (x = y, true) : false; }

const int mod = 1e9 + 7;
const int maxn = 1007;

int a[maxn][maxn];
bool b[maxn][maxn];
int r[maxn][maxn];
int f[maxn], stk[maxn], tp;
int n, mx, as, os;

inline int count(bool x) {
    for (int i = 1; i <= n; ++i)
        for (int j = n; j; --j)
            r[i][j] = b[i][j] == x ? r[i][j + 1] + 1 : 0;
    int ans = 0;
    for (int i = n; i; --i) {
        stk[tp = 0] = 0;
        for (int j = 1; j <= n; ++j) {
            while (tp && r[stk[tp]][i] >= r[j][i]) --tp;
            int pos = stk[tp];
            stk[++tp] = j;
            f[j] = (f[pos] + 1ll * (j - pos) * r[j][i]) % mod;
            if ((ans += f[j]) >= mod) ans -= mod;
        }
    }
    return ans;
}

int main() {
    read(n);
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j)
            read(a[i][j]), chkmax(mx, a[i][j]);
    int all = n * (n + 1) >> 1;
    all = 1ll * all * all % mod;
    for (int w = 0; (1ll << w) <= mx; ++w) {
        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= n; ++j)
                b[i][j] = a[i][j] & (1 << w);
        as = (as + 1ll * ((1ll << w) % mod) * count(1) % mod) % mod;
        os = (os + 1ll * ((1ll << w) % mod) * ((all - count(0) + mod) % mod) % mod) % mod;
    }
    write(as); putchar(' '); write(os);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_39677783/article/details/89610132