LOJ565 mathematican的二进制

版权声明:随意转载,愿意的话提一句作者就好了 https://blog.csdn.net/stone41123/article/details/84567325

Link

Difficulty

算法难度7,思维难度6,代码难度6

Description

一个初始为 0 0 的二进制数,有 m m 次操作。

i i 次操作是将这个二进制数加上 2 a i 2^{a_i} 。这个操作以 p i p_i 的概率执行。

如果某次操作执行了并且修改了二进制数的 k k 位,那么它会带来 k k 的代价。

问代价和的期望,答案对 998244353 998244353 取模。

n = max a i 1 0 5 , m 2 × 1 0 5 n=\operatorname{max}a_i\le10^5,m\le 2\times 10^5

Solution

这题有两类部分分,一类是 n , m 3000 n,m\le 3000 ,另一类是 n = 1 , a i = 0 n=1,a_i=0

首先考虑第一类,我们需要一个平方级别的算法。

考虑每位的贡献,我们可以设计一个 d p ( i , j ) dp(i,j) 代表第 i i 位改变 j j 次的概率。

转移的话,首先是加入一个 2 i 2^i 的概率 p p 的修改: d p ( i , j ) = d p ( i , j 1 ) × p + d p ( i , j ) × ( 1 p ) dp'(i,j)=dp(i,j-1)\times p+dp(i,j)\times (1-p)

还有从 i i 位转移到 i + 1 i+1 位所造成的影响: d p ( i + 1 , j ) = d p ( i , 2 × j ) + d p ( i , 2 × j + 1 ) dp(i+1,j)=dp(i,2\times j)+dp(i,2\times j+1)

扫描二维码关注公众号,回复: 4485451 查看本文章

然后对每一位分别统计贡献即可,时间复杂度 O ( m 2 + ( n + l o g m ) m ) O(m^2+(n+logm)m)

接着来考虑 a i a_i 全为 0 0 的情况,我们发现一个操作会有 p p 的概率贡献 1 1 的改变次数, 1 p 1-p 的概率不贡献,那么就可以表示成 1 p + p x 1-p+px 这样一个多项式。

那么我们只需要一次分治NTT算出来这个式子即可 ( 1 p i + p i x ) \prod (1-p_i+p_ix) ,然后就可以逐位讨论贡献了,套用上面的 d p dp 式子,由于最多影响到 l o g m logm 级别的位数,那么总复杂度 O ( m l o g 2 m + m l o g m ) O(mlog^2m+mlogm)

我们考虑综合这两个做法来推出正解。

我们考虑直接用多项式来记录当前这一位的 d p dp 状态。

每次加入一位上的修改的时候,我们可以直接用分治NTT,然后再用多项式乘法将原来的状态乘上这一次的变化量,这样子做复杂度很容易证明,由于每一个修改最多影响到 O ( l o g m ) O(logm)​ 位,那么总复杂度仍然是 O ( m l o g 2 m ) O(mlog^2m)​

代码中有完整的部分分做法。

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#define LL long long
using namespace std;
inline int read(){
	int x=0,f=1;char ch=' ';
	while(ch<'0' || ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0' && ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return f==1?x:-x;
}
const int N=6e5+5,mod=998244353,g=3,gi=(mod+1)/g;
inline int ksm(int a,int n){
	int ans=1;
	while(n){
		if(n&1)ans=(LL)ans*a%mod;
		a=(LL)a*a%mod;
		n>>=1;
	}
	return ans;
}
int n,m,ans,tot;
int dp[N],tmp[N],head[N],Next[N],P[N];
inline void addmsg(int x,int p){
	P[++tot]=p;
	Next[tot]=head[x];
	head[x]=tot;
}
int R[N],bin[1<<21];
inline void NTT(int *a,int n,int f){
	int L=bin[n];
	for(int i=0;i<n;++i)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
	for(int i=0;i<n;++i)if(i<R[i])swap(a[i],a[R[i]]);
	for(int i=1;i<n;i<<=1){
		int wn=ksm(f==1?g:gi,(mod-1)/(i<<1));
		for(int j=0;j<n;j+=(i<<1)){
			int w=1;
			for(int k=0;k<i;++k,w=(LL)w*wn%mod){
				int x=a[j+k],y=(LL)w*a[j+k+i]%mod;
				a[j+k]=(x+y)%mod;
				a[j+k+i]=(x-y+mod)%mod;
			}
		}
	}
	if(f==-1){
		int num=ksm(n,mod-2);
		for(int i=0;i<n;++i)a[i]=(LL)a[i]*num%mod;
	}
}
int a[N],cnt,len,b[N],c[N];
vector<int> t[N<<1];
inline void build(int rt,int l,int r){
	t[rt].clear();
	if(l==r){
		t[rt].push_back((1-a[l]+mod)%mod);
		t[rt].push_back(a[l]);
		return;
	}
	int mid=(l+r)>>1;
	build(rt<<1,l,mid);
	build(rt<<1|1,mid+1,r);
	int len1=t[rt<<1].size(),len2=t[rt<<1|1].size();
	for(int i=0;i<len1;++i)b[i]=t[rt<<1][i];
	for(int i=0;i<len2;++i)c[i]=t[rt<<1|1][i];
    int n=r-l+2;
	for(len=1;len<=n;len<<=1);
	for(int i=len1;i<len;++i)b[i]=0;
	for(int i=len2;i<len;++i)c[i]=0;
	NTT(b,len,1);NTT(c,len,1);
	for(int i=0;i<len;++i)b[i]=(LL)b[i]*c[i]%mod;
	NTT(b,len,-1);
	for(int i=0;i<n;++i)t[rt].push_back(b[i]);
}
inline void solve(){
	for(int i=0;i<=20;++i)bin[1<<i]=i;
	for(int i=head[0];i;i=Next[i])a[++cnt]=P[i];
	build(1,1,cnt);
	n+=ceil(log(m)/log(2));
	for(int i=1;i<=cnt;++i)a[i]=t[1][i];
	for(int x=0;x<=n;++x){
		for(int i=0;i<=cnt;++i)ans=(ans+(LL)i*a[i])%mod;
		for(int i=1;i<=cnt;++i)tmp[i>>1]=(tmp[i>>1]+a[i])%mod;
		for(int i=0;i<=cnt;++i)a[i]=tmp[i],tmp[i]=0;
		cnt>>=1;
	}
	printf("%d\n",ans);
}
vector<int> f;
int ta[N],tb[N];
inline void solve2(){
	for(int i=0;i<=20;++i)bin[1<<i]=i;
	n+=ceil(log(m)/log(2));
	f.push_back(1);
	for(int x=0;x<=n;++x){
		cnt=0;
		for(int i=head[x];i;i=Next[i])a[++cnt]=P[i];
		if(cnt)build(1,1,cnt);
		else{
			t[1].clear();
			t[1].push_back(1);
		}
		int cnt2=f.size()-1;
		for(int i=0;i<=cnt;++i)ta[i]=t[1][i];
		for(int i=0;i<=cnt2;++i)tb[i]=f[i];
		int cnt3=cnt+cnt2+1;
		for(len=1;len<=cnt3;len<<=1);
		for(int i=cnt+1;i<len;++i)ta[i]=0;
		for(int i=cnt2+1;i<len;++i)tb[i]=0;
		NTT(ta,len,1);NTT(tb,len,1);
		for(int i=0;i<len;++i)ta[i]=(LL)ta[i]*tb[i]%mod;
		NTT(ta,len,-1);cnt3--;
		for(int i=0;i<=cnt3;++i)ans=(ans+(LL)ta[i]*i)%mod,tmp[i]=0;
		for(int i=0;i<=cnt3;++i)tmp[i>>1]=(tmp[i>>1]+ta[i])%mod;
		cnt3>>=1;f.clear();
		for(int i=0;i<=cnt3;++i)f.push_back(tmp[i]);
	}
	printf("%d\n",ans);
}
int main(){
	n=read();m=read();
	for(int i=1;i<=m;++i){
		int a=read(),x=read(),y=read();
		int p=(LL)x*ksm(y,mod-2)%mod;
		addmsg(a,p);
	}
	if(n==1)solve();
	else if(m<=3000){
		n+=ceil(log(m)/log(2));
		dp[0]=1;
		for(int x=0;x<=n;++x){
			for(int i=head[x];i;i=Next[i]){
				int p=P[i];
				for(int j=m;j;--j)
					dp[j]=((LL)dp[j]*(1-p+mod)+(LL)dp[j-1]*p)%mod;
				dp[0]=(LL)dp[0]*(1-p+mod)%mod;
			}
			for(int j=1;j<=m;++j)ans=(ans+(LL)j*dp[j])%mod;
			for(int j=0;j<=m;++j)tmp[j>>1]=(tmp[j>>1]+dp[j])%mod;
			for(int j=0;j<=m;++j)dp[j]=tmp[j],tmp[j]=0;
		}
		printf("%d\n",ans);
	}
	else solve2();
	return 0;
}

猜你喜欢

转载自blog.csdn.net/stone41123/article/details/84567325