模拟赛 我的朋友们(分治NTT)

题意:
有一个长度为 n n 的序列,序列上每个位置有一个物品,每个物品在每次被询问到的时候有 p i p_i 的概率被拒绝,考虑这样一个过程:先取前 L L 个物品,然后每次对这 L L 个物品都询问,如果有 x x 个物品被拒绝,则将 x x 物品丢弃,然后从未被选取的物品中选取最靠前的 x x 个物品重复这个过程,当总物品数 < L \lt L 时停止,问期望询问多少组次。

n , L 1 e 5 n , L \leq 1e5
m o d 998244353 \bmod 998244353

可以写出 n 2 d p n^2dp
f i f_i 表示现在手上的物品是 [ i , i + L 1 ] [i,i+L-1] 时期望还需要多少步。
f i = j = 1 L f i , j ( [ x j ] k = i i + L 1 p k x + ( 1 p k ) ) 1 j = 0 L 1 ( 1 p i + j ) f_i = \frac {\sum_{j=1}^L f_{i,j}\left([x^j]\prod_{k=i}^{i+L-1}p_{k}x+(1-p_k)\right)}{1 - \prod_{j=0}^{L-1} (1 - p_{i+j})}

下面那个可以简单 O ( n ) O(n) 计算,
上面那个。。。
凭什么这个 d p dp 要平白无故地处理多个多项式还开 1 e 5 1e5?
你可以拿起笔,在草稿纸上涂涂画画,写下一堆形如“ H ( u , x ) = u k ( p k x + ( 1 p k ) ) H(u,x)=u^k(p_{k}x+(1-p_k)) ”的东西,然后拍案而起:“不会做,自闭了,我要去洗澡!”
你也可以平心静气,突然发现,唉,上面的多项式长得都差不多啊,然后安静地度过一天。

其实上面那个式子是可以分治 N T T NTT 的。
首先差卷积不太爽,我们将 p p 翻转,那么我们实际上也就是把 f f 也翻转了,最后的答案就是 f n f_n
F j ( x ) = i = 1 j f i x i F_j(x) = \sum_{i=1}^j f_ix^i
则在 j < L j \lt L F j ( x ) = 0 F_j(x) = 0
j L j \geq L F j ( x ) = [ x j ] F j 1 ( x ) P i ( x ) 1 j = 0 L 1 ( 1 p i j ) F_j(x) = \frac {[x^j]F_{j-1}(x)P_i(x)}{1 - \prod_{j=0}^{L-1} (1 - p_{i-j})}
其中 P i ( x ) = j = 0 L 1 p j x + ( 1 p j ) P_i(x) = \prod_{j=0}^{L-1} p_jx + (1-p_j)
然后考虑自顶向下的分治 N T T NTT
当前我们在 [ l , r ] [l,r] ,拥有
A l , r ( x ) = F l 1 × r L + 1 i l p i x + 1 p i A_{l,r}(x) = F_{l-1} \times \prod_{r-L+1 \leq i \leq l} p_ix + 1-p_i

B l , r ( x ) = r L + 1 i l p i x + 1 p i B_{l,r}(x) = \prod_{r-L+1 \leq i \leq l} p_ix + 1-p_i
那么可以发现如果 l = r l = r ,则 A l , l ( x ) = F l 1 ( x ) × P l ( x ) A_{l,l}(x) = F_{l-1}(x) \times P_l(x)
就是我们想要的东西,求个第 l l 项即可。
考虑怎么在走到 [ l , m i d ] [l,mid] [ m i d + 1 , r ] [mid+1,r] 的时候快速的维护出 A l , m i d , A m i d + 1 , r , B l , m i d , B m i d + 1 , r A_{l,mid},A_{mid+1,r},B_{l,mid},B_{mid+1,r}
r r 变成 m i d mid 时,可以发现我们可以直接
A l , m i d = A l , r × i = m i d + 1 r p i L x + 1 p i L A_{l,mid} = A_{l,r}\times \prod_{i=mid+1}^r p_{i-L}x + 1 - p_{i-L}
B l , m i d = B l , r × i = m i d + 1 r p i L x + 1 p i L B_{l,mid} = B_{l,r}\times \prod_{i=mid+1}^r p_{i-L}x + 1 - p_{i-L}
l l 变成 m i d + 1 mid+1 时,我们可以直接
A m i d + 1 , r = A l , r × i = l + 1 m i d + 1 p i L x + 1 p i L + ( F m i d ( x ) F l 1 ( x ) ) × B m i d + 1 , r A_{mid+1,r} = A_{l,r}\times \prod_{i=l+1}^{mid+1} p_{i-L}x + 1 - p_{i-L} + (F_{mid}(x) - F_{l-1}(x)) \times B_{mid+1,r}
B m i d + 1 , r = B l , r × i = l + 1 m i d 1 p i L x + 1 p i L B_{mid+1,r} = B_{l,r}\times \prod_{i=l+1}^{mid-1} p_{i-L}x + 1 - p_{i-L}
把后面的全部分治 N T T NTT 预处理出来即可简单转移、

但是当你写出上面的式子的时候你就 n a i v e naive 了,所谓细节调一年。
上面的式子其实应该是:
A l , m i d = A l , r × i = m i d + 1 min ( r , l + L ) p i L x + 1 p i L A_{l,mid} = A_{l,r}\times \prod_{i=mid+1}^{\min(r,l+L)} p_{i-L}x + 1 - p_{i-L}
B l , m i d = B l , r × i = m i d + 1 min ( r , l + L ) p i L x + 1 p i L B_{l,mid} = B_{l,r}\times \prod_{i=mid+1}^{\min(r,l+L)} p_{i-L}x + 1 - p_{i-L}
A m i d + 1 , r = A l , r × i = max ( r L + 1 , l + 1 ) m i d + 1 p i L x + 1 p i L + ( F m i d ( x ) F l 1 ( x ) ) × B m i d + 1 , r A_{mid+1,r} = A_{l,r}\times \prod_{i=\max(r-L+1,l+1)}^{mid+1} p_{i-L}x + 1 - p_{i-L} + (F_{mid}(x) - F_{l-1}(x)) \times B_{mid+1,r}
B m i d + 1 , r = B l , r × i = max ( r L + 1 , l + 1 ) m i d 1 p i L x + 1 p i L B_{mid+1,r} = B_{l,r}\times \prod_{i=\max(r-L+1,l+1)}^{mid-1} p_{i-L}x + 1 - p_{i-L}
然后你仔细思考这个取 max \max 和取 min \min
以取 min \min 为例:
l + L < r l+L \lt r 时,
l + L < m i d + 1 l+L \lt mid+1 则原多项式为 1 1
否则,只有一个 r r 可以使得 m i d + 1 l + L < r mid +1 \leq l+L \lt r
对于这种情况直接再暴力一个自下而上的分治 N T T NTT 再算一遍即可。(但是好像网上其他解法有些可以避开这种情况的高论。)
其他情况可以一个自下而上的分治 N T T NTT 搞定。

还有一些细节比如说 A l , r ( x ) A_{l,r}(x) 我们应该只保留次数在 [ x l ( r l + 1 ) , x r ] [x^{l-(r-l+1)} , x^r] 之间的项来保证复杂度正确。
对于 B l , r ( x ) B_{l,r}(x) 只保留次数在 [ x l , x r + ( r l + 1 ) ] [x^l,x^{r + (r-l+1)}] 之间的项。

A C   C o d e \mathcal AC \ Code

#include<bits/stdc++.h>
#define rep(i,j,k) for(int i=(j),LIM=(k);i<=LIM;i++)
#define per(i,j,k) for(int i=(j),LIM=(k);i>=LIM;i--)
#define maxn 800005
#define Ct const
#define LL long long
#define pii pair<int,int>
#define vc vector
#define vi vc<int>
#define mod 998244353
#define pb push_back
#define mp make_pair
#define db double
using namespace std;

namespace IO{
	char cb[1<<16],*cs=cb,*ct=cb;
	#define getc() (cs==ct&&(ct=(cs=cb)+fread(cb,1,1<<16,stdin),cs==ct)?0:*cs++)
	void read(int &res){
		char ch;bool f=0;
		for(;!isdigit(ch=getc());) if(ch=='-') f=1;
		for(res=ch-'0';isdigit(ch=getc());res=res*10+ch-'0');
		(f) && (res = -res);
	}
}

int Pow(int b,int k){ int r=1;for(;k;k>>=1,b=1ll*b*b%mod) if(k&1) r=1ll*r*b%mod; return r; }
int n,L,p[maxn],f[maxn],*F[maxn],*G[maxn];
int *P[maxn],*P2[maxn];

int Wl,Wl2,w[maxn],lg[maxn],inv[maxn],fac[maxn],invf[maxn];
void init(int n){
	for(Wl=1;n>=Wl<<1;Wl<<=1);
	int pw = Pow(3 , (mod-1) / (Wl2=Wl<<1));
	w[Wl] = inv[0] = inv[1] = fac[0] = fac[1] = invf[0] = invf[1] = 1;
	rep(i,Wl+1,Wl2) w[i] = 1ll * w[i-1] * pw % mod;
	per(i,Wl-1,1) w[i] = w[i<<1];
	rep(i,2,Wl2) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod,fac[i] = 1ll * fac[i-1] * i % mod,
		invf[i] = 1ll * invf[i-1] * inv[i] % mod , lg[i] = lg[i >> 1] + 1;
}
int upd(int x){ return x += x >> 31 & mod; }
void NTT(int *A,int n,int tp){
	static int r[maxn]={};
	static unsigned long long ar[maxn];
	if(tp ^ 1) reverse(A+1,A+n);
	rep(i,0,n-1) r[i] = r[i >> 1] >> 1 | (i&1) << lg[n] - 1 , ar[i] = upd(A[r[i]]);
	for(int L=1;L<n;L<<=1) for(int s=0,L2=L<<1;s<n;s+=L2) for(int k=s,x=L,t;k<s+L;k++,x++)
		t=w[x]*ar[k+L]%mod,ar[k+L]=ar[k]-t+mod,ar[k]+=t;
	rep(i,0,n-1) A[i] = ar[i] % mod;
	if(tp ^ 1) rep(i,0,n-1) A[i] = 1ll * A[i] * inv[n] % mod;
}
void Mul(int *A,int *B,int *C,int n,int m,int shift = 0,int cut = 0x3f3f3f3f){
	static int st[2][maxn];
	int L = 1 << lg[n+m] + 1;
	rep(i,0,L-1) st[0][i] = i <= n ? A[i] : 0 ,
		st[1][i] = i <= m ? B[i] : 0;
	NTT(st[0],L,1) , NTT(st[1],L,1);
	rep(i,0,L-1) st[0][i] = 1ll * st[0][i] * st[1][i] % mod;
	NTT(st[0],L,-1);
	rep(i,shift,min(cut,n+m)) C[i - shift] = st[0][i];
}
#define lc u<<1
#define rc lc|1

int *B[maxn];

void Build2(int u,int l,int r){
	B[u] = new int [r - l + 2];
	if(l == r) return (void)(B[u][0] = 1 - p[l] , B[u][1] = p[l]);
	int m = l + r >> 1;
	Build2(lc,l,m) , Build2(rc,m+1,r);
	Mul(B[lc],B[rc],B[u],m-l+1,r-m);
}

void Build(int u,int l,int r){
	P[u] = new int [r-l+2];
	P2[u] = new int [r-l+2];
	F[u] = new int [4 * (r-l+2)];
	G[u] = new int [4 * (r-l+2)];
	memset(P[u],0,sizeof (int) * (r-l+2));
	memset(P2[u],0,sizeof (int) * (r-l+2));
	memset(F[u],0,sizeof (int) * 4 * (r-l+2));
	memset(G[u],0,sizeof (int) * 4 * (r-l+2));
	if(l == r) return (void)( 
		P[u][0] = 1 - p[l+1] , P[u][1] = p[l+1] , 
		P2[u][0] = 1 - (l - L >= 1 ? p[l - L] : 0) , P2[u][1] = (l - L >= 1 ? p[l-L] : 0)
		);
	int m = l + r >> 1;
	Build(lc,l,m) , Build(rc,m+1,r);
	Mul(P[lc],P[rc],P[u],m-l+1,r-m);
	Mul(P2[lc],P2[rc],P2[u],m-l+1,r-m);	

	if(l + L < r){
		memset(P2[rc],0,4 * (r-m+1));
		if(l + L <= m){
			P2[rc][0] = 1;
		}
		else{
			Build2(1,max(m+1-L,1),l);
			rep(i,0,l-max(m+1-L,1)+1) P2[rc][i] = B[1][i];
		}
	}
	if(r - L + 1 > l + 1){	
		//printf("#%d %d\n",l,r);
		memset(P[lc],0,4 * (m-l+2));
		if(r - L + 1 > m + 1){
			P[lc][0] = 1;
		}
		else{
			Build2(1,r-L+1,m+1);
			rep(i,0,m+1-(r-L+1)+1) P[lc][i] = B[1][i];
		}
	}
}

void Solve(int u,int l,int r){
	if(l == r){
	//printf("@%d %d %d\n",l,r,F[u][1]);
		f[l] = 1ll * f[l] * (1 + F[u][1]) % mod;
		return;
	}
	int m = l + r >> 1;
	Mul(F[u],P2[rc],F[lc],2*(r-l+1)-1,r-m,r-m,r-m+(m-l+1)*2);
	Mul(G[u],P2[rc],G[lc],2*(r-l+1)-1,r-m,0,(m-l+1)*2);
	Solve(lc,l,m);
	static int t[maxn];
	Mul(G[u],P[lc],G[rc],2*(r-l+1)-1,m+1-l,0,(r-m)*2);
	Mul(F[u],P[lc],F[rc],2*(r-l+1)-1,m+1-l,(m+1-l)*2,(m+1-l)*2 + (r-m)*2);
	Mul(&f[l],G[rc],t,m-l,(r-m)*2);
	rep(i,0,(r-m)*2) F[rc][i+(l+r-2*m-1)] = (F[rc][i+(l+r-2*m-1)] + t[i]) % mod /*, printf("@%d %d %d\n",i,(l+r-2*m-1),t[i])*/;
	Solve(rc,m+1,r);
}

int main(){

	freopen("friends.in","r",stdin);
	freopen("friends.out","w",stdout);

	scanf("%d%d",&n,&L);
	int sm = 1;
	rep(i,1,n){
		int a,b;scanf("%d%d",&a,&b);
		p[i] = 1ll * a * Pow(b , mod - 2) % mod;
	}
	reverse(p+1,p+n+1);
	rep(i,1,n){
		sm = 1ll * sm * (1 - p[i]) % mod * (i - L >= 1 ? Pow(1 - p[i-L] , mod-2) : 1) % mod;
		if(i >= L) f[i] = Pow(1 - sm , mod-2);
	}
	init(n << 2);
	Build(1,1,n);
	G[1][0] = 1;
	Solve(1,1,n);
	printf("%d\n",(f[n]+mod)%mod);
}

猜你喜欢

转载自blog.csdn.net/qq_35950004/article/details/107771449