Codeforces 1209E2 Rotate Columns (hard version) dp

可以发现最多只有n个列是有用的, 然后状压dp一下就好了。

#include<bits/stdc++.h>
using namespace std;

const int inf = 0x3f3f3f3f;

int n, m, a[12][2007], b[12][12];
int sum[12][1 << 12];
int dp[2][1 << 12];

int *f = dp[0];
int *g = dp[1];

int mx[2007], id[2007];

inline int getNext(int mask) {
    int lastbit = mask & 1;
    mask >>= 1;
    mask |= lastbit << (n - 1);
    return mask;
}

int main() {
    int T; scanf("%d", &T);
    while(T--) {
        scanf("%d%d", &n, &m);
        for(int i = 0; i < m; i++) mx[i] = 0;
        for(int i = 0; i < n; i++) {
            for(int j = 0; j < m; j++) {
                scanf("%d", &a[i][j]);
                mx[j] = max(mx[j], a[i][j]);
            }
        }
        for(int i = 0; i < m; i++) id[i] = i;
        sort(id, id + m, [&](int x, int y) {
            return mx[x] > mx[y];
        });
        for(int o = 0; o < n & o < m; o++) {
            int x = id[o];
            for(int j = 0; j < n; j++) b[o][j] = a[j][x];
        }
        for(int i = 0; i < min(n, m); i++) {
            for(int mask = 0; mask < (1 << n); mask++) {
                sum[i][mask] = 0;
                for(int j = 0; j < n; j++) {
                    if(mask >> j & 1) {
                        sum[i][mask] += b[i][j];
                    }
                }
            }
        }
        for(int i = 0; i < (1 << n); i++) f[i] = 0;
        for(int i = 0; i < min(n, m); i++) {
            swap(f, g);
            for(int mask = 0; mask < (1 << n); mask++) {
                f[mask] = g[mask];
                for(int smask = mask; ; smask = (smask - 1) & mask) {
                    f[mask] = max(f[mask], g[smask] + sum[i][mask ^ smask]);
                    if(smask == 0) break;
                }
            }
            for(int mask = 0; mask < (1 << n); mask++) {
                int cur = mask;
                for(int i = 1; i < n; i++) {
                    cur = getNext(cur);
                    f[cur] = max(f[cur], f[mask]);
                }
            }
        }
        printf("%d\n", f[(1 << n) - 1]);
    }
    return 0;
}

/*
*/

猜你喜欢

转载自www.cnblogs.com/CJLHY/p/11734075.html