HDU 6851:Heart(dp + 子集卷积优化)

在这里插入图片描述


数据中不会有 x = 0 , b i = 0 x = 0,b_i=0 x=0,bi=0 的数据
显然可以令 d p [ i ] dp[i] dp[i] 表示使用的碎片集合为 i i i 的贡献,转移方程为 d p [ i ] = ∑ j & k = 0 , j ∣ k = i d p [ j ] ∗ s u m [ k ] dp[i]=\sum_{j \& k=0,j|k=i}dp[j]*sum[k] dp[i]=j&k=0,jk=idp[j]sum[k]
其中 s u m [ j ] sum[j] sum[j] 表示 b = j b = j b=j 的碎片的 danmakus 总和, d p [ 0 ] = 1 dp[0] = 1 dp[0]=1

注意到这样会有重复,按最高位来枚举 k,即 k k k 的最高位和 i i i 的最高位相同,即可去重。

注意到这是一个类似卷积的形式,但是他是子集卷积。

子集卷积满足: i & j = 0 , i ∣ j = k i \&j = 0,i | j=k i&j=0,ij=k直接用 FWT 做或卷积是不行的,设 f ( i ) f(i) f(i) 表示 i i i 的二进制表示下 1 的个数,那么子集卷积也可以转化为 : f ( i ) + f ( j ) = f ( k ) , i ∣ j = k f(i) + f(j) = f(k),i | j = k f(i)+f(j)=f(k),ij=k

转移方程为: d p [ x ] [ i ] = ∑ y = 0 x ∑ j ∣ k = i d p [ y ] [ j ] ∗ s u m [ x − y ] [ k ] dp[x][i]=\sum_{y=0}^x\sum_{j|k=i}dp[y][j]*sum[x-y][k] dp[x][i]=y=0xjk=idp[y][j]sum[xy][k]

由于要枚举最高位和 i i i 相同的 danmakus 进行转移,不妨按最高位进行分类,对每一类进行转移时,由于 FWT 和 IFWT 都有可加性,将所有的 d p [ x ] dp[x] dp[x] s u m [ x ] sum[x] sum[x] 进行 FWT 运算后再转移,最后再全部进行逆运算,可以在 n 2 ∗ 2 n n^2*2^n n22n 的时间内计算子集卷积。


代码:

#include<bits/stdc++.h>
using namespace std;
#define pii pair<int,int>
#define fir first
#define sec second
const int maxn = 2100000;
const int mod = 998244353;
int n, m, p[maxn], b[maxn], bin[maxn], up[maxn], val[23][maxn];
int dp[23][maxn], tp[maxn];
vector<pii> g[23];
void fwt(int a[],int len) {
    
    
	for(int s = 2; s <= len; s <<= 1) {
    
    
		for(int j = 0; j < len; j += s) {
    
    
			for(int k = 0; k < s / 2; k++) {
    
    
				int x = a[j + k];
				int y = a[j + k + s / 2];
				a[j + k] = x, a[j + k + s / 2] = (x + y) % mod;
				//对不同的卷积不同的变换 
				//xor:a[j+k] = x+y, a[j+k+s/2] = x-y;
                //and:a[j+k] = x+y, a[j+k+s/2] = y;
                //or:a[j+k] = x, a[j+k+s/2] = x+y;
			}
		}
	}
}
void ufwt(int a[],int len) {
    
    
	for(int s = 2; s <= len; s <<= 1) {
    
    
		for(int i = 0; i < len; i += s) {
    
    
			for(int j = 0; j < s / 2; j++) {
    
    
				int x = a[i + j], y = a[i + j + s / 2];
				a[i + j] = x, a[i + j + s / 2] = (y - x + mod) % mod;
				//xor: a[i+j] = (x + y) / 2, a[i+j+s/2] = (x - y) / 2;
				//and: a[i+j] = x - y, a[i+j+s/2] = y;
				//or: a[i+j] = x, a[i+j+s/2] = y - x;
			}
		}
	}	
}
void solve(int i) {
    
    			//i:1 - 21,最高位在第 (i - 1) 位 
	int n = i, len = (1 << i);
	for (int j = 0; j <= n; j++)
		for (int k = 0; k <= len; k++)
			val[j][k] = 0;
	val[0][0] = 1;	
	for (auto it : g[i]) {
    
    
		val[bin[it.fir]][it.fir] += it.sec;
		if (val[bin[it.fir]][it.fir] >= mod) 
			val[bin[it.fir]][it.fir] -= mod;
	}
	dp[0][0] = 1;
	for (int j = 0; j <= n; j++) {
    
    
		fwt(val[j],len);
		fwt(dp[j],len);
	}
	for (int j = n; j >= 0; j--) {
    
    
		for (int k = 0; k <= j; k++) {
    
    
			for (int v = 0; v < len; v++)
				tp[v] = (tp[v] + 1ll * dp[k][v] * val[j - k][v] % mod) % mod;
		}
		for (int v = 0; v < len; v++)
			dp[j][v] = tp[v], tp[v] = 0;
		ufwt(dp[j],len);
	}
}
int main() {
    
    
	bin[0] = 0; up[0] = 0;
	for (int i = 1; i < (1 << 21); i++)
		bin[i] = bin[i >> 1] + (i & 1), up[i] = up[i / 2] + 1; 
	scanf("%d",&n);
	int highest = 0;
	for (int i = 1; i <= n; i++) {
    
    
		int p, b;
		scanf("%d%d",&p,&b);
		g[up[b]].push_back(pii(b,p));
		highest = max(up[b],highest);
	}
	for (int i = 1; i <= highest; i++)
		solve(i);
	scanf("%d",&m);
	while (m--) {
    
    
		int x, ans = 0; scanf("%d",&x);
		printf("%d\n",dp[bin[x]][x]);
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_41997978/article/details/108620799
今日推荐