「WC2018」州区划分-FWT+状压DP

Description

链接

Solution

首先我们可以轻松预处理出满足条件的点集。

f s f_s 表示点集 s s 的点组成州的方案数的满意度之和, s u m s sum_s 表示点集 s s 的人口的 p p 次方,当 s s 为合法州区时, g s = s u m s g_s=sum_s ,否则 g s = 0 g_s=0

那么,

f S = T S , T g T × f S T s u m S f_S=\sum_{T \subseteq S, T \neq \emptyset} \frac {g_T \times f_{S-T}} {sum_S}

这个问题可以用集合卷积优化,时间复杂度 O ( 2 n n 2 ) O(2^nn^2)

集合卷积

f S = A B = S , A B = g A × h B f_S=\sum_{A\cap B=S,A \cup B = \emptyset}g_A \times h_B

上式等价与 f S = A B = S , A + B = S g A × h B f_S=\sum_{A\cap B=S,|A|+|B|=|S|}g_A \times h_B

所以多设一维状态 f i , s f_{i,s} 表示集合为 s s ,集合大小为 i i 。当 S i |S| \neq i 时, f i , s f_{i,s} 0 0

类似的,定义 h i , s , g i , s h_{i,s},g_{i,s}

所以有 f i , s = A B = S , j + k = i g j , A × h k , B f_{i,s}=\sum_{A\cap B=S,j+k=i}g_{j,A} \times h_{k,B} 。用 F W T FWT 解决即可。

#include <bits/stdc++.h>
using namespace std;

typedef long long lint;
const int maxn = 25, maxm = maxn * maxn / 2, mod = 998244353;

int n, m, p, U;
int u[maxm], v[maxm], w[maxn], deg[maxn];
int sum[1 << 21], g[maxn][1 << 21], f[maxn][1 << 21], bcnt[1 << 21], inv[2100 * 2100 + 5];
int fa[maxn];

inline int gi()
{
	char c = getchar();
	while (c < '0' || c > '9') c = getchar();
	int sum = 0;
	while ('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
	return sum;
}

inline void inc(int &a, int b) {a += b; if (a >= mod) a -= mod;}
inline void dec(int &a, int b) {a -= b; if (a < 0) a += mod;}

inline int find(int x)
{
	if (fa[x] == x) return x;
	return fa[x] = find(fa[x]);
}

void FWT(int *a, int n)
{
	for (int i = 1; i < n; i <<= 1)
		for (int j = 0; j < n; j += (i << 1))
			for (int k = 0; k < i; ++k) inc(a[i + j + k], a[j + k]);
}

void UFWT(int *a, int n)
{
	for (int i = 1; i < n; i <<= 1)
		for (int j = 0; j < n; j += (i << 1))
			for (int k = 0; k < i; ++k) dec(a[i + j + k], a[j + k]);
}

int main()
{
	n = gi(); m = gi(); p = gi();
	
	for (int i = 1; i <= m; ++i) u[i] = gi(), v[i] = gi();
	for (int i = 1; i <= n; ++i) w[i] = gi();

	inv[1] = 1;
	for (int i = 2; i <= 2100 * 2100; ++i) inv[i] = (lint)(mod - mod / i) * inv[mod % i] % mod;
	for (int i = 1; i < (1 << n); ++i) bcnt[i] = bcnt[i ^ (i & (-i))] + 1;
	
	for (int i = 1; i < (1 << n); ++i) {
		for (int j = 1; j <= n; ++j) {
			if ((i >> (j - 1)) & 1) sum[i] += w[j];
			deg[j] = 0; fa[j] = j;
		}
		
		for (int j = 1; j <= m; ++j)
			if (((i >> (u[j] - 1)) & 1) && ((i >> (v[j] - 1)) & 1)) ++deg[u[j]], ++deg[v[j]], fa[find(u[j])] = find(v[j]);
		int flag = 0, root = 0;
		for (int j = 1; j <= n; ++j)
			if ((i >> (j - 1)) & 1) {
				if (deg[j] & 1) {flag = 1; break;}
				if (!root) root = find(j);
				else if (find(j) != root) {flag = 1; break;}
			}

		if (p == 0) sum[i] = 1;
		else if (p == 2) sum[i] = (lint)sum[i] * sum[i] % mod;
		if (flag) g[bcnt[i]][i] = sum[i];
		sum[i] = inv[sum[i]];
	}

	U = 1 << n;
	for (int i = 0; i <= n; ++i) FWT(g[i], U);

	f[0][0] = 1;
	FWT(f[0], U);
	for (int i = 1; i <= n; ++i) {
		for (int j = 0; j < i; ++j)
			for (int s = 0; s < U; ++s)
				inc(f[i][s], (lint)f[j][s] * g[i - j][s] % mod);
		UFWT(f[i], U);
		for (int s = 0; s < U; ++s)
			if (bcnt[s] != i) f[i][s] = 0;
			else f[i][s] = (lint)f[i][s] * sum[s] % mod;
		if (i != n) FWT(f[i], U);
	}

	printf("%d\n", f[n][U - 1]);
	
	return 0;
}

猜你喜欢

转载自blog.csdn.net/DSL_HN_2002/article/details/84989744
今日推荐