【LOJ3071】「2019 集训队互测 Day 2」神树大人挥动魔杖

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_39972971/article/details/89762619

【题目链接】

【思路要点】

  • 设 $ways_i$ 表示从 $1$ 号点走到 $i$ 号点的方案数(约定 $ways_0=0,\ ways_1=1$),有 $ways_{i}=p\times ways_{i-1}+q\times ways_{i-2}\ (i\geq2)$。
  • 考虑容斥原理,我们强制一部分格子没有被踩过,计算在剩余格子中行走的方案数,乘上容斥系数计入答案。
  • 设 $f_i$ 表示从 $1$ 走到 $i$、强制 $i-1$ 未被踩过、带有容斥系数的方案数(其中 $f_0=1$),$g_i$ 表示 $N=i$ 时的答案,则有
    $$f_{i}=-\sum_{j=0}^{i-2}(q\times ways_{j+1})^Mf_{i-j-2}$$
    $$g_{i}=\sum_{j=0}^{i}(ways_{j+1}+q\times ways_j)^Mf_{i-j}$$
  • 记 $Ways(x)=\sum_{i\geq0}ways_ix^i$,则有 $Ways(x)=pxWays(x)+qx^2Ways(x)+x$,即 $Ways(x)=\frac{x}{1-px-qx^2}$。
  • 将其分解为部分分式 $Ways(x)=\frac{x}{1-px-qx^2}=\frac{A}{1-\alpha x}+\frac{B}{1-\beta x}$(其中 $A,B$ 为常数),解得
    $$\alpha=\frac{p+\sqrt{p^2+4q}}{2},\quad \beta=\frac{p-\sqrt{p^2+4q}}{2},\quad A=\frac{1}{\sqrt{p^2+4q}},\quad B=-\frac{1}{\sqrt{p^2+4q}}$$
    这里的开方在扩域 $\mathbb{Z}_P[\sqrt{t}]\ (t=p^2+4q)$ 中进行,代码中以 $r+i\sqrt{t}$ 的形式表示。
  • 因此 $ways_i=A\alpha^i+B\beta^i$。
  • 记 $H(x)=\sum_{i\geq0}ways_i^Mx^i$,则有 $H(x)=\sum_{i\geq0}\sum_{j=0}^{M}\binom{M}{j}A^j\alpha^{ij}B^{M-j}\beta^{i(M-j)}x^i$。
  • 进而 $H(x)=\sum_{j=0}^{M}\binom{M}{j}A^jB^{M-j}\sum_{i\geq0}\alpha^{ij}\beta^{i(M-j)}x^i=\sum_{j=0}^{M}\frac{\binom{M}{j}A^jB^{M-j}}{1-\alpha^j\beta^{M-j}x}$。
  • 因此,记 $H(x)=\frac{P(x)}{Q(x)}$,其中 $P(x),Q(x)$ 均为次数 $O(M)$ 的多项式,且 $Q(x)=\prod_{j=0}^{M}(1-\alpha^j\beta^{M-j}x)$,其常数项为 $1$。$Q(x)$ 的系数关于 $\alpha,\beta$ 对称,因此都落在 $\mathbb{Z}_P$ 中(不含 $\sqrt{t}$ 部分)。进而 $H(x)$ 为线性递推数列,其递推式即为 $1-Q(x)$。
  • 记 $F(x)=\sum_{i\geq0}f_ix^i,\ G(x)=\sum_{i\geq0}g_ix^i,\ Cf(x)=\sum_{i\geq0}(q\times ways_{i+1})^Mx^i,\ Cg(x)=\sum_{i\geq0}(q\times ways_{i}+ways_{i+1})^Mx^i$。由于 $Cf(x),Cg(x)$ 均为递推式与 $H(x)$ 相同的线性递推数列,它们同样可以写成 $\frac{P'(x)}{Q(x)}$ 的形式,其中 $P'(x)$ 为次数 $O(M)$ 的多项式。
  • 记 $Cf(x)=\frac{C(x)}{Q(x)},\ Cg(x)=\frac{D(x)}{Q(x)}$,那么 $F(x)=\frac{1}{1+x^2\frac{C(x)}{Q(x)}}$,$G(x)=F(x)\frac{D(x)}{Q(x)}=\frac{D(x)}{Q(x)+x^2C(x)}$。注意到 $ways_0=0$,故当 $M\geq1$ 时 $H(0)=0$ 且 $Cf(x)=q^M\frac{H(x)}{x}$,于是 $\deg C\leq M-1$,分母 $Q(x)+x^2C(x)$ 的次数不超过 $M+1$,即 $g$ 为 $M+1$ 阶线性递推。
  • 可以发现,$G(x)$ 也为线性递推数列,采用多项式乘法、带余除法优化线性递推即可。
  • 时间复杂度 $O(M\log^2M+M\log M\log N)$。

【代码】

#include<bits/stdc++.h>
using namespace std;
// Size of the global per-index tables used by main (e.g. ways[]): 200005 entries.
const int MAXN = 200005;
// Prime modulus for all arithmetic; 998244353 = 119 * 2^23 + 1 supports NTT.
const int P = 998244353;
// Short aliases for the wide integer / floating-point types.
using ll = long long;
using ld = long double;
using ull = unsigned long long;
// chkmax: x = max(x, y).  chkmin: x = min(x, y).  x is only written when y wins.
template <typename T> void chkmax(T &x, T y) { if (x < y) x = y; }
template <typename T> void chkmin(T &x, T y) { if (y < x) x = y; }
// Read a decimal integer from stdin into x.
// Leading non-digit characters are skipped; every '-' among them flips the sign.
// If input ends before any digit appears, x becomes 0. The original looped
// forever on EOF in that case.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// Keep getchar()'s result as int: EOF must stay distinguishable from a real
	// byte, and isdigit() is undefined for negative values other than EOF.
	int c = getchar();
	for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') f = -f;
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';  // isdigit(EOF) is false
	x *= f;
}
// Write integer x to stdout in decimal, with no trailing newline.
// Digits are peeled off without ever negating x, because -x overflows (UB)
// for the most negative value of a signed type, e.g. INT_MIN.
template <typename T> void write(T x) {
	char digits[48];  // enough for any built-in integer type, even 128-bit
	int len = 0;
	const bool negative = x < 0;
	do {
		T d = x % 10;  // lies in (-10, 0] when x is negative (truncating division)
		digits[len++] = static_cast<char>('0' + (d < 0 ? -d : d));
		x /= 10;
	} while (x != 0);
	if (negative) putchar('-');
	while (len > 0) putchar(digits[--len]);
}
// Write x via write(), then end the line.
template <typename T> void writeln(T x) {
	write(x);
	putchar('\n');
}
// Problem input read in main: n = the N whose answer is printed, m = exponent M,
// p and q = coefficients of the ways recurrence; t = p^2 + 4q (mod P), set in main.
int n, m, p, q, t;
// (a + b) mod P for a, b in [0, P).
inline int pls(const int &a, const int &b) {
	const int sum = a + b;
	return sum < P ? sum : sum - P;
}
// (a - b) mod P for a, b in [0, P).
inline int mns(const int &a, const int &b) {
	const int diff = a - b;
	return diff < 0 ? diff + P : diff;
}
// (a * b) mod P, using a 64-bit intermediate product.
inline int mul(const int &a, const int &b) {
	return static_cast<int>(static_cast<long long>(a) * b % P);
}
// x^y mod P by recursive squaring (y >= 0 expected).
int power(int x, int y) {
	if (!y) return 1;
	const int half = power(x, y / 2);
	const int sq = mul(half, half);
	return (y % 2) ? mul(sq, x) : sq;
}
// Element r + i * sqrt(t) of Z_P[sqrt(t)], where t is the global set in main.
// The operators reuse pls/mns/mul, which expect components in [0, P).
struct INT {int r, i; };
// Scale both components by the scalar b.
INT operator * (const INT &a, const int &b) {
	INT out;
	out.r = mul(a.r, b);
	out.i = mul(a.i, b);
	return out;
}
// Componentwise sum.
INT operator + (const INT &a, const INT &b) {
	INT out;
	out.r = pls(a.r, b.r);
	out.i = pls(a.i, b.i);
	return out;
}
// Componentwise difference.
INT operator - (const INT &a, const INT &b) {
	INT out;
	out.r = mns(a.r, b.r);
	out.i = mns(a.i, b.i);
	return out;
}
// (a.r + a.i*s)(b.r + b.i*s) with s^2 = t:
//   rational part a.r*b.r + a.i*b.i*t, sqrt(t) part a.r*b.i + b.r*a.i.
INT operator * (const INT &a, const INT &b) {
	const int re = pls(mul(a.r, b.r), mul(mul(a.i, b.i), t));
	const int im = pls(mul(a.r, b.i), mul(b.r, a.i));
	INT out;
	out.r = re;
	out.i = im;
	return out;
}
namespace Poly {
	// NTT-based polynomial arithmetic modulo P = 998244353 (= 119 * 2^23 + 1).
	// init() must run once before any transform. All routines share the scratch
	// state N / Log / home / tmpa / tmpb, so they are not reentrant.
	// Products longer than MAXN are rejected by assert instead of overrunning buffers.
	const int MAXN = 262144;  // maximum transform length, 2^18
	const int P = 998244353;
	const int LOG = 25;       // capacity of inv[]; init() fills indices 0..18
	const int G = 3;          // primitive root modulo P
	// x^y mod P for y >= 0.
	int power(int x, int y) {
		if (y == 0) return 1;
		int tmp = power(x, y / 2);
		if (y % 2 == 0) return 1ll * tmp * tmp % P;
		else return 1ll * tmp * tmp % P * x % P;
	}
	// N = current transform length (a power of two), Log = log2(N);
	// home[i] = i with its Log low bits reversed (filled by NTTinit).
	int N, Log, tmpa[MAXN], tmpb[MAXN], home[MAXN];
	// forward[k] = w^k and bckward[k] = w^(-k), where w = G^((P-1)/MAXN) has
	// order MAXN; inv[k] = (2^k)^(-1) mod P.
	bool initialized; int forward[MAXN], bckward[MAXN], inv[LOG];
	void init() {
		initialized = true;
		forward[0] = bckward[0] = inv[0] = 1;
		for (int len = 2, lg = 1; len <= MAXN; len <<= 1, lg++)
			inv[lg] = power(len, P - 2);
		int delta = power(G, (P - 1) / MAXN);
		for (int i = 1; i < MAXN; i++)
			forward[i] = bckward[MAXN - i] = 1ll * forward[i - 1] * delta % P;
	}
	// Fill home[] with the bit-reversal permutation for the current N / Log.
	void NTTinit() {
		for (int i = 0; i < N; i++) {
			int ans = 0, tmp = i;
			for (int j = 1; j <= Log; j++) {
				ans <<= 1;
				ans += tmp & 1;
				tmp >>= 1;
			}
			home[i] = ans;
		}
	}
	// In-place iterative NTT of a[0..N): mode 1 = forward, -1 = inverse
	// (the inverse also multiplies by N^(-1)). Inputs in [0, P) give outputs in [0, P).
	void NTT(int *a, int mode) {
		assert(initialized);
		for (int i = 0; i < N; i++)
			if (home[i] < i) swap(a[i], a[home[i]]);
		int *g;
		if (mode == 1) g = forward;
		else g = bckward;
		for (int len = 2; len <= N; len <<= 1) {
			// g[MAXN / len * k] is the k-th power of a primitive len-th root of unity.
			for (int i = 0; i < N; i += len) {
				for (int j = i, k = i + len / 2; k < i + len; j++, k++) {
					int tmp = a[j];
					int tnp = 1ll * a[k] * g[MAXN / len * (j - i)] % P;
					// ">=" (the original used ">") so that a sum of exactly P
					// reduces to 0 and every value stays canonical in [0, P).
					a[j] = (tmp + tnp >= P) ? (tmp + tnp - P) : (tmp + tnp);
					a[k] = (tmp - tnp < 0) ? (tmp - tnp + P) : (tmp - tnp);
				}
			}
		}
		if (mode == -1) {
			for (int i = 0; i < N; i++)
				a[i] = 1ll * a[i] * inv[Log] % P;
		}
	}
	// c = a * b (full product, a.size() + b.size() - 1 coefficients).
	void times(vector <int> &a, vector <int> &b, vector <int> &c) {
		assert(a.size() >= 1), assert(b.size() >= 1);
		int goal = a.size() + b.size() - 1;
		assert(goal <= MAXN);  // the scratch arrays hold at most MAXN entries
		N = 1, Log = 0;
		while (N < goal) {
			N <<= 1;
			Log++;
		}
		for (unsigned i = 0; i < a.size(); i++)
			tmpa[i] = a[i];
		for (int i = a.size(); i < N; i++)
			tmpa[i] = 0;
		for (unsigned i = 0; i < b.size(); i++)
			tmpb[i] = b[i];
		for (int i = b.size(); i < N; i++)
			tmpb[i] = 0;
		NTTinit();
		NTT(tmpa, 1);
		NTT(tmpb, 1);
		for (int i = 0; i < N; i++)
			tmpa[i] = 1ll * tmpa[i] * tmpb[i] % P;
		NTT(tmpa, -1);
		c.resize(goal);
		for (int i = 0; i < goal; i++)
			c[i] = tmpa[i];
	}
	// c = a * b * b (a.size() + 2 * b.size() - 2 coefficients); used by getinv.
	void timesabb(vector <int> &a, vector <int> &b, vector <int> &c) {
		assert(a.size() >= 1), assert(b.size() >= 1);
		int goal = a.size() + b.size() * 2 - 2;
		assert(goal <= MAXN);
		N = 1, Log = 0;
		while (N < goal) {
			N <<= 1;
			Log++;
		}
		for (unsigned i = 0; i < a.size(); i++)
			tmpa[i] = a[i];
		for (int i = a.size(); i < N; i++)
			tmpa[i] = 0;
		for (unsigned i = 0; i < b.size(); i++)
			tmpb[i] = b[i];
		for (int i = b.size(); i < N; i++)
			tmpb[i] = 0;
		NTTinit();
		NTT(tmpa, 1);
		NTT(tmpb, 1);
		for (int i = 0; i < N; i++)
			tmpa[i] = 1ll * tmpa[i] * tmpb[i] % P * tmpb[i] % P;
		NTT(tmpa, -1);
		c.resize(goal);
		for (int i = 0; i < goal; i++)
			c[i] = tmpa[i];
	}
	// b = a^(-1) mod x^(a.size()) by Newton iteration b <- 2b - a*b^2,
	// doubling the precision each round. Requires a[0] != 0.
	void getinv(vector <int> &a, vector <int> &b) {
		assert(a.size() >= 1), assert(a[0] != 0);
		b.clear(), b.push_back(power(a[0], P - 2));
		while (b.size() < a.size()) {
			vector <int> c, ta = a;
			ta.resize(b.size() * 2);
			timesabb(ta, b, c);
			b.resize(b.size() * 2);
			for (unsigned i = 0; i < b.size(); i++)
				b[i] = (2ll * b[i] - c[i] + P) % P;
		}
		b.resize(a.size());
	}
	// q = floor(a / b), with a.size() - b.size() + 1 coefficients after padding.
	// Requires b's leading (last) coefficient to be nonzero. Side effect kept for
	// callers such as getmod: a is zero-padded to at least b.size(). b is unchanged.
	void getdiv(vector <int> &a, vector <int> &b, vector <int> &q) {
		assert(!b.empty());
		a.resize(max(a.size(), b.size()));
		const unsigned qlen = a.size() - b.size() + 1;
		// On reversed polynomials: rev(q) = rev(a) * rev(b)^(-1) (mod x^qlen).
		vector <int> ra(a.rbegin(), a.rend()), rb(b.rbegin(), b.rend());
		// The inverse must be exact to qlen terms. qlen exceeds b.size() when
		// deg a > 2 deg b, which the original (inverting to b.size() terms) got wrong.
		rb.resize(qlen);
		vector <int> invb; getinv(rb, invb);
		ra.resize(qlen);  // higher terms cannot influence q mod x^qlen
		times(ra, invb, q), q.resize(qlen);
		reverse(q.begin(), q.end());
	}
	// r = a mod b (b.size() - 1 coefficients). Pads a like getdiv.
	void getmod(vector <int> &a, vector <int> &b, vector <int> &r) {
		vector <int> q, p;
		getdiv(a, b, q);
		times(b, q, p), r.clear();
		for (unsigned i = 0; i < b.size() - 1; i++)
			r.push_back((a[i] - p[i] >= 0) ? (a[i] - p[i]) : (a[i] - p[i] + P));
	}
	// Same transform as NTT(int*), over INT values r + i*sqrt(t); twiddles are scalars.
	void NTT(INT *a, int mode) {
		assert(initialized);
		for (int i = 0; i < N; i++)
			if (home[i] < i) swap(a[i], a[home[i]]);
		int *g;
		if (mode == 1) g = forward;
		else g = bckward;
		for (int len = 2; len <= N; len <<= 1) {
			for (int i = 0; i < N; i += len) {
				for (int j = i, k = i + len / 2; k < i + len; j++, k++) {
					INT tmp = a[j];
					INT tnp = a[k] * g[MAXN / len * (j - i)];
					a[j] = tmp + tnp;
					a[k] = tmp - tnp;
				}
			}
		}
		if (mode == -1) {
			for (int i = 0; i < N; i++)
				a[i] = a[i] * inv[Log];
		}
	}
	// c = a * b for polynomials with INT coefficients.
	void times(vector <INT> &a, vector <INT> &b, vector <INT> &c) {
		assert(a.size() >= 1), assert(b.size() >= 1);
		int goal = a.size() + b.size() - 1;
		assert(goal <= MAXN);
		N = 1, Log = 0;
		while (N < goal) {
			N <<= 1;
			Log++;
		}
		static INT tmpa[MAXN], tmpb[MAXN];
		for (unsigned i = 0; i < a.size(); i++)
			tmpa[i] = a[i];
		for (int i = a.size(); i < N; i++)
			tmpa[i] = (INT) {0, 0};
		for (unsigned i = 0; i < b.size(); i++)
			tmpb[i] = b[i];
		for (int i = b.size(); i < N; i++)
			tmpb[i] = (INT) {0, 0};
		NTTinit();
		NTT(tmpa, 1);
		NTT(tmpb, 1);
		for (int i = 0; i < N; i++)
			tmpa[i] = tmpa[i] * tmpb[i];
		NTT(tmpa, -1);
		c.resize(goal);
		for (int i = 0; i < goal; i++)
			c[i] = tmpa[i];
	}
}
namespace LinearSequence {
	// n-th term of a linear recurrence of order d = a.size():
	//     h[k] = a[0] h[k-1] + a[1] h[k-2] + ... + a[d-1] h[k-d]   (k >= d),
	// evaluated as h[n] = sum_i r_i h[i], where r(x) = x^n mod the characteristic
	// polynomial. Not reentrant: query() stores that polynomial in `mod` / `inv`.
	const int MAXN = 262144;
	const int P = 998244353;
	// mod = x^d - a[0] x^(d-1) - ... - a[d-1], lowest degree first.
	// inv = inverse of 1 - a[0] x - ... - a[d-1] x^d (the reversed mod), to d+1 terms.
	vector <int> mod, inv;
	// x^y mod P for y >= 0.
	int power(int x, long long y) {
		if (y == 0) return 1;
		const int half = power(x, y / 2);
		const int sq = 1ll * half * half % P;
		return (y % 2 == 0) ? sq : int(1ll * sq * x % P);
	}
	// res = a * b reduced modulo `mod`. The quotient uses `inv`, which is exact
	// to mod.size() terms. That suffices for the products formed in work(), whose
	// size never exceeds 2 * (mod.size() - 1).
	void times(vector <int> &a, vector <int> &b, vector <int> &res) {
		vector <int> prod;
		Poly :: times(a, b, prod);
		if (prod.size() < mod.size()) {  // degree already below deg(mod)
			res = prod;
			return;
		}
		vector <int> quot, back;
		reverse(prod.begin(), prod.end());
		Poly :: times(prod, inv, quot);
		quot.resize(prod.size() - inv.size() + 1);
		reverse(prod.begin(), prod.end());
		reverse(quot.begin(), quot.end());
		// Remainder = prod - mod * quot; only its low mod.size()-1 terms survive.
		Poly :: times(mod, quot, back);
		res.clear();
		for (unsigned k = 0; k + 1 < mod.size(); k++) {
			const int diff = prod[k] - back[k];
			res.push_back(diff < 0 ? diff + P : diff);
		}
	}
	// x^n modulo `mod`, by recursive squaring.
	vector <int> work(ll n) {
		if (n == 0) return vector <int> (1, 1);
		vector <int> half = work(n / 2), other = half, ans;
		if (n & 1) half.insert(half.begin(), 0);  // one extra factor of x
		times(half, other, ans);
		return ans;
	}
	// h[n] for recurrence coefficients a and initial terms h (h.size() >= a.size()).
	int query(long long n, vector <int> a, vector <int> h) {
		assert(h.size() >= a.size());
		if (a.size() == 1) return 1ll * power(a[0], n) * h[0] % P;  // h[n] = a0^n h[0]
		// Reversed characteristic polynomial 1 - a[0] x - ... - a[d-1] x^d.
		mod.assign(1, 1);
		for (unsigned k = 0; k < a.size(); k++)
			mod.push_back(a[k] == 0 ? 0 : P - a[k]);
		Poly :: getinv(mod, inv);
		reverse(mod.begin(), mod.end());  // now the characteristic polynomial itself
		vector <int> res = work(n);
		int ans = 0;
		for (unsigned k = 0; k < res.size(); k++)
			ans = (ans + 1ll * res[k] * h[k]) % P;
		return ans;
	}
}
// alpha = (p + sqrt t) / 2 and beta = (p - sqrt t) / 2; assigned in main.
INT alpha, beta;
// x^y in Z_P[sqrt t] by recursive squaring (y >= 0 expected).
INT power(INT x, int y) {
	if (!y) return (INT) {1, 0};
	const INT half = power(x, y / 2);
	const INT sq = half * half;
	return (y % 2 == 0) ? sq : sq * x;
}
vector <INT> work(int l, int r) {
	if (l == r) {
		vector <INT> ans;
		ans.push_back((INT) {1, 0});
		ans.push_back((INT) {0, 0} - power(alpha, l) * power(beta, m - l));
		return ans;
	}
	int mid = (l + r) / 2;
	vector <INT> a = work(l, mid), b = work(mid + 1, r), ans;
	Poly :: times(a, b, ans);
	return ans;
}
// ways[i]: number of ways to walk from cell 1 to cell i (ways[0] = 0 through
// zero-initialization, ways[1] = 1); main uses indices 0..m+2.
int ways[MAXN];
int main() {
	read(n), read(m), read(p), read(q);
	t = pls(mul(p, p), mul(4, q));
	// alpha = (p + sqrt t) / 2, beta = (p - sqrt t) / 2 in Z_P[sqrt t].
	// beta's sqrt-t coefficient is written as P - 1 (== -1 mod P), not the
	// literal -1. pls/mns/mul and both NTTs assume operands in [0, P). A
	// negative component leaks negative, non-canonical values into Q and every
	// later product, risking int overflow in the butterflies and a negative
	// printed answer.
	alpha = (INT) {p, 1} * ((P + 1) / 2);
	beta = (INT) {p, P - 1} * ((P + 1) / 2);
	// Q(x) = prod_{j=0}^{m} (1 - alpha^j beta^(m-j) x), of degree m + 1. Its
	// coefficients are symmetric in alpha/beta, so their sqrt-t parts vanish.
	Poly :: init(); vector <INT> tmp = work(0, m);
	vector <int> Q;
	for (int i = 0; i <= m + 1; i++) {
		Q.push_back(tmp[i].r);
		assert(tmp[i].i == 0);
	}
	ways[1] = 1;
	for (int i = 2; i <= m + 2; i++)
		ways[i] = pls(mul(ways[i - 1], p), mul(ways[i - 2], q));
	// First m+2 terms of Cf_i = (q*ways_{i+1})^m and Cg_i = (q*ways_i + ways_{i+1})^m.
	vector <int> f, g;
	for (int i = 0; i <= m + 1; i++) {
		f.push_back(power(mul(q, ways[i + 1]), m));
		g.push_back(power(pls(mul(q, ways[i]), ways[i + 1]), m));
	}
	// F = Q * Cf and G = Q * Cg. Only coefficients 0..m+1 are meaningful (the
	// numerators C and D); later ones are truncation garbage, so drop them
	// before the more expensive inversion and multiplication below.
	vector <int> F, G, invF, Ans;
	Poly :: times(Q, f, F);
	Poly :: times(Q, g, G);
	F.resize(m + 2), G.resize(m + 2);
	// Turn F into the denominator Q(x) + x^2 C(x) (degree <= m+1). Iterating
	// downward guarantees F[i-2] still holds the numerator C when read.
	for (int i = m + 1; i >= 0; i--)
		F[i] = pls(Q[i], (i >= 2) ? F[i - 2] : 0);
	// First m+2 answers: g_k = [x^k] D(x) / (Q(x) + x^2 C(x)).
	Poly :: getinv(F, invF);
	Poly :: times(G, invF, Ans);
	Ans.resize(m + 2);
	// Recurrence of order m+1: g_k = sum_{j=1}^{m+1} (-F[j]) g_{k-j}.
	for (int i = 0; i <= m; i++)
		F[i] = mns(0, F[i + 1]);
	F.resize(m + 1);
	writeln(LinearSequence :: query(n, F, Ans));
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_39972971/article/details/89762619