#LOJ2541. 「PKUWC2018」猎人杀(分治+NTT)

版权声明:_ https://blog.csdn.net/lunch__/article/details/85268627

题意

n n 个人,每个人有一个权值 w i w_i ,每次随机杀一个人,杀第 i i 个人的概率是 w i j [ j   i s   a l i v e ] \frac{w_i}{\sum_j[j\ is \ alive]} ,求第一个人最后一个死的概率,对 998244353 998244353 取模。

题解

这题不是很难但是我自己太sb了所以看了很久。

第一个人最后一个死就代表恰好有 0 0 个人在第一个人之后死,算这个是个很套路的东西用容斥转化为至少,那么设 f ( S ) f(S) 为至少有 S S 集合的人在第一个人之后死的概率,那么 a n s = ( 1 ) S f ( S ) ans=\sum(-1)^{|S|}f(S) ,现在我们转化成了求 f ( S ) f(S) ,实际上这个东西我们可以写出一个式子:
f ( S ) = w 1 w 1 + w [ w S ] f(S)=\frac{w_1}{w_1+\sum w[w \in S]} 这个东西理解起来也不是很难,每次只有选择 S S 集合内的人击杀或者 1 1 击杀的时候才会影响他们的相对死亡顺序,而选中这里面的人击杀时必须要击杀第一个人才满足 S S 集合都在 1 1 之后死,那么我们发现 f ( S ) f(S) 只与 w [ w S ] \sum_w[w\in S] 有关,又因为题目中的条件有 w 1 0 5 \sum w\le10^5 ,我们可以考虑计算出每个值的容斥系数,可以直接设 f [ i ] [ j ] f[i][j] 表示考虑前 i i 个人,权值和为 j j 时的容斥系数,用背包的方法转移就是 f [ i ] [ j ] = f [ i 1 ] [ j ] f [ i 1 ] [ j w i ] f[i][j] = f[i - 1][j] - f[i - 1][j - w_i] ,这样子就可以得到 50 50 分的好成绩了,但实际上由于数据比较水,在LOJ上可以通过 80 80 分。

#include <bits/stdc++.h>

#define x first
#define y second
#define pb push_back
#define mp make_pair
#define inf (0x3f3f3f3f)

using namespace std;

typedef long long ll;
typedef pair<int, int> PII;

template<class T>inline T read(T &_) {
	T __ = getchar(), ___ = 1; _ = 0;
	for (; !isdigit(__); __ = getchar()) if (__ == '-') ___ = -1;
	for (; isdigit(__); __ = getchar()) _ = (_ << 3) + (_ << 1) + (__ ^ 48);
	return _ *= ___;
}

template<class T>inline bool chkmax(T &_, T __) { return _ < __ ? _ = __, 1 : 0; }
template<class T>inline bool chkmin(T &_, T __) { return _ > __ ? _ = __, 1 : 0; }

inline void proStatus() {
	ifstream t("/proc/self/status");
	cerr << string(istreambuf_iterator<char>(t), istreambuf_iterator<char>());
}

const int N = 1 << 18 | 1; 
const int mod = 998244353;

int n, sum, ans, w[N], f[N];

inline int add(int x, int y) { return (x += y) < mod ? x : x - mod; }

inline int qpow(int _, int __) {
	int ___ = 1; 
	for (; __; __ >>= 1, _ = 1ll * _ * _ % mod) 
		if (__ & 1) ___ = 1ll * ___ * _ % mod;
	return ___;
}

int main() {
#ifdef ylsakioi
	freopen("2541.in", "r", stdin);
	freopen("2541.out", "w", stdout);
#endif

	read(n), f[0] = 1;
	for (int i = 1; i <= n; ++ i) 
		sum = add(sum, read(w[i]));
	sum = add(sum, mod - w[1]);
	for (int i = 2; i <= n; ++ i) 
		for (int j = sum; j >= w[i]; -- j) 
			f[j] = add(f[j], (mod - f[j - w[i]]));
	for (int i = 0; i <= sum; ++ i) 
		ans = add(ans, 1ll * f[i] * qpow(i + w[1], mod - 2) % mod);
	printf("%lld\n", 1ll * ans * w[1] % mod);

	return 0;
}

我们设出这个东西的生成函数,第 i i a i x i a_ix^i 代表当前 d p i = a i dp_i=a_i ,那么每次转移就是乘上 ( 1 x w i ) (1-x^{w_i}) ,分别代表是否把当前这个人加入集合的决策,那么这个生成函数就是:
i = 2 n ( 1 x w i ) \prod_{i = 2}^n(1-x^{w_i})
这个东西直接分治,合并两个区间的信息的时候用NTT计算就好了,这个东西开数组有点麻烦,那么我们可以类似于线段树动态开点回收空间的方法一样,用完以后利用以前的空间,分治出最深的一条链深度应该是 log n \log n 的,那么我们开 log n \log n 个数组就好了,复杂度 O ( n log 2 n ) O(n\log ^2n)

Codes

#include <bits/stdc++.h>

#define x first
#define y second
#define pb push_back
#define mp make_pair
#define inf (0x3f3f3f3f)

using namespace std;

typedef long long ll;
typedef pair<int, int> PII;

template<class T>inline T read(T &_) {
	T __ = getchar(), ___ = 1; _ = 0;
	for (; !isdigit(__); __ = getchar()) if (__ == '-') ___ = -1;
	for (; isdigit(__); __ = getchar()) _ = (_ << 3) + (_ << 1) + (__ ^ 48);
	return _ *= ___;
}

template<class T>inline bool chkmax(T &_, T __) { return _ < __ ? _ = __, 1 : 0; }
template<class T>inline bool chkmin(T &_, T __) { return _ > __ ? _ = __, 1 : 0; }

inline void proStatus() {
	ifstream t("/proc/self/status");
	cerr << string(istreambuf_iterator<char>(t), istreambuf_iterator<char>());
}

const int N = 1 << 18 | 1; 
const int mod = 998244353;

int w[N], f[N], A[N], B[N], rev[N], S[N], tp[33][N], cnt = -1;

inline int add(int x, int y) { return (x += y) < mod ? x : x - mod; }

inline int qpow(int _, int __) {
	int ___ = 1; 
	for (; __; __ >>= 1, _ = 1ll * _ * _ % mod) 
		if (__ & 1) ___ = 1ll * ___ * _ % mod;
	return ___;
}

inline void NTT(int *a, int n, int fh) {
	for (int i = 0; i < n; ++ i) 
		if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int Wn, limit = 2; limit <= n; limit <<= 1) {
		Wn = qpow(fh ^ 1 ? qpow(3, mod - 2) : 3, (mod - 1) / limit);
		for (int W = 1, j = 0; j < n; j += limit, W = 1) 
			for (int i = j; i < j + (limit >> 1); ++ i, W = 1ll * W * Wn % mod) {
				int a1 = a[i], a2 = 1ll * a[i + (limit >> 1)] * W % mod; 
				a[i] = add(a1, a2), a[i + (limit >> 1)] = add(a1, mod - a2);
			}
	}
	if (fh ^ 1) for (int inv = qpow(n, mod - 2), i = 0; i < n; ++ i) 
		a[i] = 1ll * a[i] * inv % mod;
}

inline void calc(int *a, int *b, int *c, int limit) {
	NTT(a, limit, 1), NTT(b, limit, 1);
	for (int i = 0; i < limit; ++ i) 
		c[i] = 1ll * a[i] * b[i] % mod;
	NTT(c, limit, -1);
}

inline void Solve(int l, int r, int *a) {
	if (l == r) return (void) (a[0] = 1, a[w[l]] = mod - 1);
	int mid = (l + r) >> 1, limit = 1, k = 0, a1 = ++ cnt, a2 = ++ cnt; 
	Solve(l, mid, tp[a1]), Solve(mid + 1, r, tp[a2]);
	for (; limit <= S[r] - S[l - 1]; ++ k) limit <<= 1; 
	for (int i = 0; i < limit; ++ i) 
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
	calc(tp[a1], tp[a2], a, limit), cnt -= 2;
	for (int i = 0; i < limit; ++ i) tp[a1][i] = tp[a2][i] = 0;
}

int main() {
#ifdef ylsakioi
	freopen("2541.in", "r", stdin);
	freopen("2541.out", "w", stdout);
#endif

	int n, ans = 0; 

	read(n);
	for (int i = 1; i <= n; ++ i) 
		S[i] = S[i - 1] + read(w[i]);
	Solve(2, n, f);
	for (int i = 0; i <= S[n] - S[1]; ++ i) 
		ans = add(ans, 1ll * f[i] * qpow(w[1] + i, mod - 2) % mod);
	printf("%lld\n", 1ll * ans * w[1] % mod);

	return 0;
}

猜你喜欢

转载自blog.csdn.net/lunch__/article/details/85268627