[BZOJ5292][Bjoi2018]治疗之雨(期望DP+高斯消元)

Address

洛谷P4457
BZOJ5292
LOJ#2513

Solution

首先,一个显然的 DP 状态:
f [ i ] f[i] 表示第一个数当前为 i i ,将其变成 0 0 的期望步数。
边界当然是 f [ 0 ] = 0 f[0]=0
讨论一波转移:
P ( i , x ) P(i,x) 表示当第一个数为 i i 时, k k 轮减操作让第一个数减少 x x 的概率。
这样转移就很显然了:
i < n i<n 时:
f [ i ] = 1 + 1 m + 1 j = 0 i + 1 P ( i + 1 , j ) × f [ i j + 1 ] + m m + 1 j = 0 i P ( i , j ) × f [ i j ] f[i]=1+\frac 1{m+1}\sum_{j=0}^{i+1}P(i+1,j)\times f[i-j+1]+\frac m{m+1}\sum_{j=0}^iP(i,j)\times f[i-j]
i = n i=n 时:
f [ i ] = 1 + j = 0 i P ( i , j ) × f [ i j ] f[i]=1+\sum_{j=0}^iP(i,j)\times f[i-j]
要解决两个小问题:
(1) P ( i , j ) P(i,j) 的值。
分下类: 当 j < i j<i 时,相当于在 k k 次操作中选出 j j 次操作对第一个数进行,剩下的 k j k-j 次操作对剩下的 m m 个数进行。
所以:
P ( i , j ) = { C k j × m k j ( m + 1 ) k j < i 1 k = 0 i 1 P ( i , k ) j = i P(i,j)=\begin{cases}\frac{C_k^j\times m^{k-j}}{(m+1)^k}&j<i\\1-\sum_{k=0}^{i-1}P(i,k)&j=i\end{cases}
特别地,如果 k < j k<j P ( i , j ) = 0 P(i,j)=0
(2) 转移的后效性。
把每个 f [ i ] f[i] 当作一个未知变量,使用高斯消元解方程。
但这样复杂度是 O ( T n 3 ) O(Tn^3) 的。
发现系数矩阵长这个样子:
[ X 0 0 0 0 0 0 0 X X X 0 0 0 0 0 X X X X 0 0 0 0 X X X X X 0 0 0 X X X X X X 0 0 X X X X X X X 0 X X X X X X X X X X X X X X X X X X ] \begin{bmatrix}X&0&0&0&0&0&0&\dots&0\\X&X&X&0&0&0&0&\dots&0\\X&X&X&X&0&0&0&\dots&0\\X&X&X&X&X&0&0&\dots&0\\X&X&X&X&X&X&0&\dots&0\\X&X&X&X&X&X&X&\dots&0\\\vdots&\vdots&\vdots&\vdots&\vdots&\vdots&\vdots&\ddots&\vdots\\X&X&X&X&X&X&X&X&X\\X&X&X&X&X&X&X&X&X\end{bmatrix}
从第一列到第 n + 1 n+1 列分别表示 f [ 0 ] f[0] f [ n ] f[n] ,第一行到第 n + 1 n+1 行分别表示 f [ 0 ] f[0] f [ n ] f[n] 的转移。
这矩阵已经非常接近于下三角矩阵。
我们只需要从最后一行开始网上,对于第 i i i > 2 i>2 )行,只需要用第 i i 行去消第 i i 行使得第 i i 行第 i + 1 i+1 列为 0 0 即可。
这样系数矩阵就变成了下三角矩阵,从 f [ 0 ] f[0] 开始一一代入即可。
注:如果出现了除以 0 0 的情况则方程组无解,输出 1 -1
时间复杂度 O ( T n 2 ) O(Tn^2)

Code

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Rof(i, a, b) for (i = a; i >= b; i--)

inline int read()
{
	int res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	return bo ? ~res + 1 : res;
}

template <class T>
T Min(T a, T b) {return a < b ? a : b;}

const int N = 1505, ZZQ = 1e9 + 7;

int n, p, m, k, inv[N], f[N][N], pw[N], C[N], a[N];

int qpow(int a, int b)
{
	int res = 1;
	while (b)
	{
		if (b & 1) res = 1ll * res * a % ZZQ;
		a = 1ll * a * a % ZZQ;
		b >>= 1;
	}
	return res;
}

void work()
{
	int i, j, alls, orz, tmp, rp;
	n = read(); p = read(); m = read(); k = read();
	orz = qpow(m + 1, ZZQ - 2);
	alls = qpow(qpow(m + 1, k), ZZQ - 2);
	C[0] = 1;
	For (i, 1, n) C[i] = 1ll * C[i - 1] * inv[i] % ZZQ * (k - i + 1) % ZZQ;
	For (i, 0, Min(n, k))
		pw[i] = 1ll * qpow(m, k - i) * C[i] % ZZQ * alls % ZZQ;
	f[0][0] = 1; f[0][n + 1] = 0;
	For (i, 1, n) f[0][i] = 0;
	For (i, 1, n)
	{
		For (j, 0, n + 1) f[i][j] = 0;
		f[i][n + 1] = 1; f[i][i] = rp = 1;
		For (j, 0, i)
		{
			tmp = j <= k ? 1ll * pw[j] * (i == n ? 0 : orz) % ZZQ : 0;
			f[i][i - j + 1] -= tmp; rp -= tmp;
			if (f[i][i - j + 1] < 0) f[i][i - j + 1] += ZZQ;
			if (rp < 0) rp += ZZQ;
		}
		if (i < n) f[i][0] -= 1ll * rp * orz % ZZQ;
		if (f[i][0] < 0) f[i][0] += ZZQ;
		rp = 1;
		For (j, 0, i - 1)
		{
			tmp = j <= k ? 1ll * pw[j]
				* (i == n ? 1 : 1ll * m * orz % ZZQ) % ZZQ : 0;
			f[i][i - j] -= tmp; rp -= tmp;
			if (f[i][i - j] < 0) f[i][i - j] += ZZQ;
			if (rp < 0) rp += ZZQ;
		}
		f[i][0] -= i == n ? rp : 1ll * rp * m % ZZQ * orz % ZZQ;
		if (f[i][0] < 0) f[i][0] += ZZQ;
	}
	Rof (i, n, 2)
	{
		if (!f[i][i]) return (void) puts("-1");
		int tmp = qpow(f[i][i], ZZQ - 2);
		For (j, 0, n + 1) f[i][j] = 1ll * f[i][j] * tmp % ZZQ;
		tmp = f[i - 1][i];
		For (j, 0, n + 1)
		{
			f[i - 1][j] -= 1ll * f[i][j] * tmp % ZZQ;
			if (f[i - 1][j] < 0) f[i - 1][j] += ZZQ;
		}
	}
	if (!f[1][1]) return (void) puts("-1");
	tmp = qpow(f[1][1], ZZQ - 2);
	For (i, 0, n + 1) f[1][i] = 1ll * f[1][i] * tmp % ZZQ;
	a[0] = 0;
	For (i, 1, p)
	{
		a[i] = f[i][n + 1];
		For (j, 0, i - 1)
		{
			a[i] -= 1ll * f[i][j] * a[j] % ZZQ;
			if (a[i] < 0) a[i] += ZZQ;
		}
	}
	printf("%d\n", a[p]);
}

int main()
{
	int i, T = read();
	inv[1] = 1;
	For (i, 2, 1500) inv[i] = 1ll * (ZZQ - ZZQ / i) * inv[ZZQ % i] % ZZQ;
	while (T--) work();
	return 0;
}

猜你喜欢

转载自blog.csdn.net/xyz32768/article/details/83217209