[CF1209E2]Rotate Columns (hard version)

题目

传送门 to luogu

思路

初始思路

将所有元素拿出来排序,然后从大到小考虑,贪心。

但是有误。假如剩下的行必须隔着两行,就还是应该动态规划。

正解思路

既然是“最大值之和”,完全可以理解为,随意选择能够拿到的最大值。最优情况会自动地选择到每一行的最大值。

既然如此,用 f ( i , S ) f(i,S) 表示,处理了前 i i 列,哪些行已经选择了。然后我们继续优化。显然每一列只选择一个时,完全可以利用移位都选取到。所以,我们以每列的最大值为关键字,进行排序。接下来就只需要考虑前 n n 列了——此时的 d p \tt dp 值一定不小于这 n n 个最大值之和,后面的所有数字都不大于这 n n 个数,所以不可能继续更新了。

复杂度降至每组数据 O [ min ( n , m ) 3 n + m log n + n m ] \mathcal O[\min(n,m)\cdot 3^n+m\log n+nm]

代码

#include <cstdio>
#include <iostream>
#include <vector>
#include <queue>
using namespace std;
inline int readint(){
	int a = 0; char c = getchar(), f = 1;
	for(; c<'0'||c>'9'; c=getchar())
		if(c == '-') f = -f;
	for(; '0'<=c&&c<='9'; c=getchar())
		a = (a<<3)+(a<<1)+(c^48);
	return a*f;
}
template < class T >
void getMax(T&a,const T&b){if(a<b)a=b;}
template < class T >
void getMin(T&a,const T&b){if(b<a)a=b;}

const int MaxN = 12, MaxM = 2000;
int a[MaxN][MaxM+1];
int c[MaxM+1]; // 每一列的最大值
int n, m;
void input(){
	n = readint(), m = readint();
	for(int j=1; j<=m; ++j) c[j] = 0;
	for(int i=0; i<n; ++i)
		for(int j=1; j<=m; ++j){
			a[i][j] = readint();
			getMax(c[j],a[i][j]);
		}
}

struct CMP{
	bool operator()(const int &a,const int &b){
		return c[a] > c[b];
	}
};
priority_queue<int,vector<int>,CMP> pq;
int dp[MaxN+1][1<<MaxN], val[1<<MaxN];
bool vis[1<<MaxN]; // 用来降低不重要的复杂度
# define nxt (s>>1)^((s&1)<<n>>1)
void solve(){
	while(!pq.empty()) pq.pop();
	for(int i=1; i<=m&&i<=n; ++i)
		pq.push(i);
	for(int i=min(n,m)+1; i<=m; ++i)
		if(c[i] > c[pq.top()])
			pq.pop(), pq.push(i);
	m = min(n,m); // 只需要这么多
	for(int i=1; i<=m; ++i,pq.pop()){
		for(int j=0; j<n; ++j)
			val[1<<j] = a[j][pq.top()];
		for(int S=1; S<(1<<n); ++S)
			val[S] = val[S&-S]+val[S^(S&-S)];
		for(int S=1; S<(1<<n); ++S)
			vis[S] = false;
		for(int S=1; S<(1<<n); ++S)
		if(!vis[S])
		for(int j=0,s=S; j<n*2; ++j,s=nxt){
			vis[s] = true; // 将环标记
			getMax(val[nxt],val[s]);
		}
		for(int S=1; S<(1<<n); ++S){
			dp[i][S] = dp[i-1][S];
			for(int s=S; s; s=(s-1)&S)
				getMax(dp[i][S],dp[i-1][S^s]+val[s]);
		}
	}
	printf("%d\n",dp[m][(1<<n)-1]);
}

int main(){
	for(int T=readint(); T; --T)
		input(), solve();
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_42101694/article/details/107886793