【省内训练2018-12-23】Counting

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_39972971/article/details/85224695

【思路要点】

  • 问题等价于求不定方程 i = 1 N a i x i = C \sum_{i=1}^{N}a_ix_i=C 的非负整数解的数量。
  • 考虑将 C , x i C,x_i 用二进制表示,我们从高位向低位依次决策 x i x_i 的某一位是否为 1 1
  • 假设当前决策的是第 i i 位,那么当前决策的 i = 1 N a i x i \sum_{i=1}^{N}a_ix_i 应当为 2 i + 1 2^{i+1} 的倍数,并且剩余的尚未决策的位置上即使全部为 1 1 ,其总和也不会超过 2 2 i i = 1 N a i 2*2^i*\sum_{i=1}^{N}a_i
  • 因此,记 d p i , j , k dp_{i,j,k} 表示决策到第 i i 位,第 j j 个数, C i = 1 N a i x i 2 i = k \lfloor\frac{C-\sum_{i=1}^{N}a_ix_i}{2^i}\rfloor=k 的方案数,由上面的推理, k > 2 i = 1 N a i k>2*\sum_{i=1}^{N}a_i 的状态不会对答案产生贡献,不计之。
  • 时间复杂度 O ( N 2 M a x { s i } L o g C ) O(N^2*Max\{s_i\}*LogC)

【代码】

#include<bits/stdc++.h>
using namespace std;
const int MAXN = 55;
const int MAXM = 505;
const int MAXS = 5e4 + 5;
const int P = 1e9 + 7;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
template <typename T> void chkmin(T &x, T y) {x = min(x, y); } 
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
template <typename T> void write(T x) {
	if (x < 0) x = -x, putchar('-');
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
char s[MAXM];
int n, m, a[MAXN], bits[MAXM];
int dp[2][MAXN][MAXS];
void convert() {
	int len = strlen(s + 1);
	for (int i = 1; i <= len; i++)
		s[i] = s[i] - '0';
	reverse(s + 1, s + len + 1);
	while (len != 0) {
		bits[++m] = s[1] & 1, s[1] -= bits[m];
		for (int i = len; i >= 1; i--) {
			if (s[i] & 1) s[i - 1] += 10;
			s[i] /= 2;
		}
		while (len && s[len] == 0) len--;
	}
}
void update(int &x, int y) {
	x += y;
	if (x >= P) x -= P;
}
int main() {
	freopen("counting.in", "r", stdin);
	freopen("counting.out", "w", stdout);
	scanf("%d %s", &n, s + 1);
	int sum = 0;
	for (int i = 1; i <= n; i++)
		read(a[i]), sum += a[i];
	convert();
	dp[m & 1][0][1] = 1;
	for (int i = m, now = m & 1, dest = now ^ 1; i >= 1; i--, swap(now, dest)) {
		for (int j = 1; j <= n; j++)
		for (int k = 0; k <= 2 * sum; k++) {
			update(dp[now][j][k], dp[now][j - 1][k]);
			if (k + a[j] <= 2 * sum) update(dp[now][j][k], dp[now][j - 1][k + a[j]]);
		}
		memset(dp[dest], 0, sizeof(dp[dest]));
		for (int k = 0; k <= sum; k++)
			if (k * 2 + bits[i - 1] <= 2 * sum) update(dp[dest][0][k * 2 + bits[i - 1]], dp[now][n][k]); 
	}
	printf("%d\n", dp[1][n][0]);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_39972971/article/details/85224695
今日推荐