gym/102576/B SOSDP

SOSDP顾名思义就是,救命啊是DP Sum over Subsets(SOS)DP
本文约等于https://codeforces.com/blog/entry/45223的机翻中文版

问题引入

给一个2^n长度的数组A,现在对于任意x要预处理出函数F(x)的返回值。
F(x)的定义:SUM(A[i] | x & i == i)
即 i 的二进制表示被x包含,F(x)返回所有满足条件 i 的A[ i ]总和。

解法

暴力 O(4^n)

//枚举每一种x
for(int mask = 0;mask < (1<<N); ++mask){
    
    
	//枚举每一个i
	for(int i = 0;i < (1<<N); ++i){
    
    
		if((mask&i) == i){
    
    
			F[mask] += A[i];
		}
	}
}

这种解法枚举了很多无用的i,对于给定的x,可以用更高效的方法枚举其包含的二进制数。


优化后的暴力O(3^n)

//依旧枚举所有的x
for (int mask = 0; mask < (1<<n); mask++){
    
    
	F[mask] = A[0];
    // 枚举了所有有效的i
    for(int i = mask; i > 0; i = (i-1) & mask){
    
    
    	F[mask] += A[i];
    }
}

复杂度证明:枚举i含有k个1位,则有C(n,k)种可能,每种可能有2^k种情况。
在这里插入图片描述
暴力的做法最大的浪费在于把x的二进制位混在一起枚举。


DP

要把枚举的x,按照二进制拆分。
设计dp状态:
dp[被数字x包含][且最右边y位和x相同] = 的下标贡献总和

初始状态:
dp[i][0] = A[i];

转移方程:

for(int i = 1 ; i < (1 << maxn) ; i++){
    
    
	for(int j = 1 ; j < maxn ; j++){
    
    
		//如果j位是1,例如??1xx,则包含了??(1)xx,??(0)xx
		if(i & (1 << (j-1)))
			dp[i][j] = dp[i ^ (1 << (j-1))][j - 1] + dp[i][j - 1];
		else
			dp[i][j] = dp[i][j - 1];
	}
}

原作者在举例这个dp时下标到了-1,(因为只是为了过度并不打算实现)
下图为注释图,逗号右边的数字比我实现的小1

改进SOSDP

可以发现无论第j位的数字是0或者1,都会继承j - 1位的状态。利用滚动数组的思想(滚动的是位数,而不是枚举的数字x),可以把dp优化成1维的,且非常容易实现。

		//滚动数组优化而来
		for(int j = 1 ; j < maxn ; j++){
    
    
			for(int i = 1 ; i < (1 << maxn) ; i++){
    
    
				if(i & (1 << (j-1)))dp[i] += dp[i ^ (1 << (j-1))];
			}
		}

思考:
因为滚动背包优化后,外循环是当前位数,内循环是枚举的数字。
数字应该从大到小枚举,不然就会成为完全背包,一种状态可能被统计多次,但是这里并没有这么做,为什么?



例题

https://codeforces.ml/gym/102576/problem/B

AC源码

改进前

#include<bits/stdc++.h>
using namespace std;
const int maxn = 21;

//dp[数字i][最右j位包含的] 数字个数(包括本身)
//x包含y, 指x & y = y,且最右边j位相同
long long dp[1 << maxn][maxn];
int a[1 << maxn];

int main(){
    
    
	int T;
	scanf("%d",&T);

	while(T--){
    
    
		for(int i = 1 ; i < (1 << maxn) ; i++){
    
    
			for(int j = 0 ; j < maxn ; j++)dp[i][j] = 0;
		}

		int n;
		scanf("%d",&n);
		for(int i = 1 ; i <= n ; i++){
    
    
			scanf("%d", &a[i]);
			dp[a[i]][0]++;
		}

		for(int i = 1 ; i < (1 << maxn) ; i++){
    
    
			for(int j = 1 ; j < maxn ; j++){
    
    
				//如果j位是1,例如??1xx,则包含了??(1)xx,??(0)xx
				//
				if(i & (1 << (j-1)))
					dp[i][j] = dp[i ^ (1 << (j-1))][j - 1] + dp[i][j - 1];
				else

					dp[i][j] = dp[i][j - 1];
			}
		}

		long long ans = 0;
		for(int i = 1 ; i <= n ; i++)ans += dp[a[i]][maxn - 1];
		printf("%lld\n",ans);

	}
}

改进后

#include<bits/stdc++.h>
using namespace std;
const int maxn = 21;

long long dp[1 << maxn];
int a[1 << maxn];

int main(){
    
    
	int T;
	scanf("%d",&T);

	while(T--){
    
    
		for(int i = 1 ; i < (1 << maxn) ; i++){
    
    
			dp[i] = 0;
		}

		int n;
		scanf("%d",&n);
		for(int i = 1 ; i <= n ; i++){
    
    
			scanf("%d", &a[i]);
			dp[a[i]]++;
		}

		//滚动数组优化而来
		for(int j = 1 ; j < maxn ; j++){
    
    
			for(int i = 1 ; i < (1 << maxn) ; i++){
    
    
				if(i & (1 << (j-1)))dp[i] += dp[i ^ (1 << (j-1))];
			}
		}

		long long ans = 0;
		for(int i = 1 ; i <= n ; i++)ans += dp[a[i]];
		printf("%lld\n",ans);

	}
}

猜你喜欢

转载自blog.csdn.net/qq_35068676/article/details/112861776
今日推荐