题意
小Z是养鸽子的人。一天,小Z给鸽子们喂玉米吃。一共有\(n\)只鸽子,小Z每秒会等概率选择一只鸽子并给他一粒玉米。一只鸽子饱了当且仅当它吃了的玉米粒数量\(≥k\)。 小Z想要你告诉他,期望多少秒之后所有的鸽子都饱了。
数据范围: \(n≤50,k≤1000\),答案模\(998244353\)输出
题解
显然非常适合min-max容斥.令 \(f(n)\) 为有 \(n\) 只鸽子,将其中一只喂到饱的期望次数,就得到:
\[ ans = \sum_{i = 0} ^ {n} (-1) ^ {i + 1} \dbinom{n}{i} \frac{n}{i} f(i) \]
要乘以 \(\frac{n}{i}\) 是因为期望投喂 \(\frac{n}{i}\) 次才能有一次投喂到\(i\)只鸽子中的一只.
来大力计算 \(f(n)\) ,假设被喂饱的是第一只鸽子,总共投喂了 \(j + k\) 次,第 \(j+k\) 次投喂把第 \(1\) 只喂饱了. \(g(n,c)\) 表示 \(n\) 只鸽子,喂了 \(c\) 颗玉米,没有一只被喂饱的方案数.因为 \(n\) 只鸽子都可能被喂饱,最后还要乘以 \(n\) :
\[ f(n) = n\sum_{j} \frac{(j + k)\dbinom{j+k-1}{j}g(n-1,j)}{n^{j+k}} \]
而 \(g(n,c)\) 可以由EGF得到:
\[ g(n,c) = ((\sum_{i = 0} ^ {k - 1} \frac{x^i}{i!}) ^ n [x^c]) c! \]
令 \(g = \sum_{i = 0} ^ {k - 1} \frac{x^i}{i!}\) ,暴力算出它的前 \(n\) 次方就可以在 \(O(nk^2log(nk))\) 的复杂度内解决本题.
然而还有更优美的做法.
\(e^x = \sum_{i = 0} ^ {\infty} \frac{x^i}{i!}\) ,而 \((e^x)' = e^x\)
对于\(g\),可以类似得到 \(g' = g - \frac{x^{k - 1}}{(k - 1)!}\)
从而
\[ \begin{split} (g^n)' &= g^{n - 1} g' \\ &= (g^{n-1}) (g - \frac{x^{k-1}}{(k-1)!}) \\ &= g^n - \frac{x^{k-1}}{(k-1)!} g^{n-1} \end{split} \]
从中可以得出 \(g^n[x^n]\) 和 \(g^n[x^{n-1}]\) 之间的关系, \(O(1)\) 计算一项,总复杂度 \(O(nk^2)\)
#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define LL long long
#define pii pair<int, int>
using namespace std;
const int N = 5e4 + 10;
const int mod = 998244353;
template <typename T> T read(T &x) {
int f = 0;
register char c = getchar();
while (c > '9' || c < '0') f |= (c == '-'), c = getchar();
for (x = 0; c >= '0' && c <= '9'; c = getchar())
x = (x << 3) + (x << 1) + (c ^ 48);
if (f) x = -x;
return x;
}
namespace Comb {
const int Maxn = 1e6 + 10;
int fac[Maxn], fav[Maxn], inv[Maxn];
void comb_init() {
fac[0] = fav[0] = 1;
inv[1] = fac[1] = fav[1] = 1;
for (int i = 2; i < N; ++i) {
fac[i] = 1LL * fac[i - 1] * i % mod;
inv[i] = 1LL * -mod / i * inv[mod % i] % mod + mod;
fav[i] = 1LL * fav[i - 1] * inv[i] % mod;
}
}
inline int C(int x, int y) {
if (x < y || y < 0) return 0;
return 1LL * fac[x] * fav[y] % mod * fav[x - y] % mod;
}
inline int Qpow(int x, int p) {
int ans = 1;
for (; p; p >>= 1) {
if (p & 1) ans = 1LL * ans * x % mod;
x = 1LL * x * x % mod;
}
return ans;
}
inline int Inv(int x) {
return Qpow(x, mod - 2);
}
inline void upd(int &x, int y) {
(x += y) >= mod ? x -= mod : 0;
}
inline int add(int x, int y) {
return (x += y) >= mod ? x - mod : x;
}
inline int dec(int x, int y) {
return (x -= y) < 0 ? x + mod : x;
}
}
using namespace Comb;
int n, k;
int f[51], g[51][N];
int main() {
comb_init();
read(n); read(k);
g[0][0] = 1;
for (int i = 0; i < k; ++i) {
g[1][i] = fav[i];
}
for (int i = 2; i <= n; ++i) {
g[i][0] = 1;
for (int j = 1; j <= i * (k - 1); ++j) {
g[i][j] = 1LL * i * g[i][j - 1] % mod;
if (j >= k) g[i][j] = dec(g[i][j], 1LL * i * fav[k - 1] % mod * g[i - 1][j - k] % mod);
g[i][j] = 1LL * g[i][j] * inv[j] % mod;
}
}
int ans = 0;
for (int i = 1; i <= n; ++i) {
int pw = Qpow(inv[i], k);
for (int j = 0; j <= (i - 1) * (k - 1); ++j) {
int d = 1LL * (j + k) * C(j + k - 1, j) % mod * g[i - 1][j] % mod * fac[j] % mod * pw % mod;
upd(f[i], 1LL * i * d % mod);
pw = 1LL * pw * inv[i] % mod;
}
int del = 1LL * C(n, i) * n % mod * inv[i] % mod * f[i] % mod;
if (!(i & 1)) del = mod - del;
upd(ans, del);
}
cout << ans << endl;
return 0;
}