Codeforces 722E 组合数学 DP

题意:有一个n * m的棋盘,你初始在点(1, 1),你需要去点(n, m)。你初始有s分,在这个棋盘上有k个点,经过一次这个点分数就会变为s / 2(向上取整),问从起点到终点的分数的数学期望是多少?

思路:按照套路,先把这k个点按照pair的方式进行排序,设dp[i][j]为从起点到点i之前经过了至少j个减分点,到点i的数学期望。那么所有在它之前的可以向它转移的点向它转移。那么dp[i][j] = Σ(dp[u][j - 1] - dp[u][j]) * g(u, i)。其中g(u, i)是u, i之间没有限制条件的走法数目,用组合数学的方法计算即可。这样相当于是前面恰好走过j个点 + 可能走过大于一个点的方式转移过来,这样可以保证计数的不重不漏。

代码:

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define db double
#define LL long long
#define pii pair<int, int>
using namespace std;
const int maxn = 200010;
const LL mod = 1e9 + 7;
LL dp[2010][40];
pii a[2010];
LL v[maxn], inv[maxn];
LL qpow(LL x, LL y) {
	LL ans = 1;
	for (; y; y >>= 1) {
		if(y & 1) ans = (ans * x) % mod;
		x = (x * x) % mod;
	}
	return ans;
}
void init(int n) {
	v[0] = 1;
	for (int i = 1; i <= n; i++) {
		v[i] = (v[i - 1] * i) % mod;
	}
	inv[n] = qpow(v[n], mod - 2);
	for (int i = n - 1; i >= 0; i--) {
		inv[i] = (inv[i + 1] * (i + 1)) % mod;
	}
}
LL C(LL n, LL m) {
	return (((v[n] * inv[m]) % mod) * inv[n - m]) % mod;
}
LL cal(int x, int y) {
	LL tmp = abs(a[y].first - a[x].first), tmp1 = tmp + (a[y].second - a[x].second);
	return C(tmp1, tmp);
}
LL b[50];
int main() {
	int n, m, k, t;
	scanf("%d%d%d%d", &n, &m, &k, &t);
	init(n + m);
	for (int i = 1; i <= k; i++) {
		scanf("%d%d", &a[i].first, &a[i].second);
	}
	int lim = 0;
	while(t > 1) {
		b[++lim] = t;
		t = (t + 1) / 2;
	}
	b[++lim] = 1;
	b[lim + 1] = 1;
	sort(a + 1, a + 1 + k);
	k++;
	a[k] = make_pair(n, m);
	for (int i = 1; i <= k; i++) {
		dp[i][0] = C(a[i].first + a[i].second - 2, a[i].first - 1);
	}
	LL ans = 0;
	for (int j = 1; j <= lim; j++) {
		for (int i = 1; i <= k; i++) {
			for (int t = 1; t < i; t++) {
				if(a[t].first <= a[i].first && a[t].second <= a[i].second) {
					LL tmp1 = (dp[t][j - 1] - dp[t][j] + mod) % mod;
					LL tmp2 = cal(t, i);
					assert(tmp1 >= 0);
					assert(tmp2 >= 0);
					dp[i][j] += (tmp1 * tmp2) % mod;
					dp[i][j] %= mod;
				}
			}
		}
	}
	for (int i = 0; i <= lim; i++) {
		ans = (ans + (((dp[k][i] - dp[k][i + 1] + mod) % mod) * b[i + 1]) % mod) % mod;
	}
	ans = (ans * qpow(C(n + m - 2, n - 1), mod - 2)) % mod;
	printf("%lld\n", ans);
} 

  

猜你喜欢

转载自www.cnblogs.com/pkgunboat/p/11429676.html
今日推荐