Link
Difficulty
算法难度7,思维难度6,代码难度6
Description
一个初始为 的二进制数,有 次操作。
第 次操作是将这个二进制数加上 。这个操作以 的概率执行。
如果某次操作执行了并且修改了二进制数的 位,那么它会带来 的代价。
问代价和的期望,答案对 取模。
Solution
这题有两类部分分,一类是 ,另一类是 。
首先考虑第一类,我们需要一个平方级别的算法。
考虑每位的贡献,我们可以设计一个 代表第 位改变 次的概率。
转移的话,首先是加入一个 的概率 的修改:
还有从 位转移到 位所造成的影响:
然后对每一位分别统计贡献即可,时间复杂度 。
接着来考虑 全为 的情况,我们发现一个操作会有 的概率贡献 的改变次数, 的概率不贡献,那么就可以表示成 这样一个多项式。
那么我们只需要一次分治NTT算出来这个式子即可 ,然后就可以逐位讨论贡献了,套用上面的 式子,由于最多影响到 级别的位数,那么总复杂度 。
我们考虑综合这两个做法来推出正解。
我们考虑直接用多项式来记录当前这一位的 状态。
每次加入一位上的修改的时候,我们可以直接用分治NTT,然后再用多项式乘法将原来的状态乘上这一次的变化量,这样子做复杂度很容易证明,由于每一个修改最多影响到 位,那么总复杂度仍然是 。
代码中有完整的部分分做法。
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#define LL long long
using namespace std;
inline int read(){
int x=0,f=1;char ch=' ';
while(ch<'0' || ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return f==1?x:-x;
}
const int N=6e5+5,mod=998244353,g=3,gi=(mod+1)/g;
inline int ksm(int a,int n){
int ans=1;
while(n){
if(n&1)ans=(LL)ans*a%mod;
a=(LL)a*a%mod;
n>>=1;
}
return ans;
}
int n,m,ans,tot;
int dp[N],tmp[N],head[N],Next[N],P[N];
inline void addmsg(int x,int p){
P[++tot]=p;
Next[tot]=head[x];
head[x]=tot;
}
int R[N],bin[1<<21];
inline void NTT(int *a,int n,int f){
int L=bin[n];
for(int i=0;i<n;++i)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
for(int i=0;i<n;++i)if(i<R[i])swap(a[i],a[R[i]]);
for(int i=1;i<n;i<<=1){
int wn=ksm(f==1?g:gi,(mod-1)/(i<<1));
for(int j=0;j<n;j+=(i<<1)){
int w=1;
for(int k=0;k<i;++k,w=(LL)w*wn%mod){
int x=a[j+k],y=(LL)w*a[j+k+i]%mod;
a[j+k]=(x+y)%mod;
a[j+k+i]=(x-y+mod)%mod;
}
}
}
if(f==-1){
int num=ksm(n,mod-2);
for(int i=0;i<n;++i)a[i]=(LL)a[i]*num%mod;
}
}
int a[N],cnt,len,b[N],c[N];
vector<int> t[N<<1];
inline void build(int rt,int l,int r){
t[rt].clear();
if(l==r){
t[rt].push_back((1-a[l]+mod)%mod);
t[rt].push_back(a[l]);
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
int len1=t[rt<<1].size(),len2=t[rt<<1|1].size();
for(int i=0;i<len1;++i)b[i]=t[rt<<1][i];
for(int i=0;i<len2;++i)c[i]=t[rt<<1|1][i];
int n=r-l+2;
for(len=1;len<=n;len<<=1);
for(int i=len1;i<len;++i)b[i]=0;
for(int i=len2;i<len;++i)c[i]=0;
NTT(b,len,1);NTT(c,len,1);
for(int i=0;i<len;++i)b[i]=(LL)b[i]*c[i]%mod;
NTT(b,len,-1);
for(int i=0;i<n;++i)t[rt].push_back(b[i]);
}
inline void solve(){
for(int i=0;i<=20;++i)bin[1<<i]=i;
for(int i=head[0];i;i=Next[i])a[++cnt]=P[i];
build(1,1,cnt);
n+=ceil(log(m)/log(2));
for(int i=1;i<=cnt;++i)a[i]=t[1][i];
for(int x=0;x<=n;++x){
for(int i=0;i<=cnt;++i)ans=(ans+(LL)i*a[i])%mod;
for(int i=1;i<=cnt;++i)tmp[i>>1]=(tmp[i>>1]+a[i])%mod;
for(int i=0;i<=cnt;++i)a[i]=tmp[i],tmp[i]=0;
cnt>>=1;
}
printf("%d\n",ans);
}
vector<int> f;
int ta[N],tb[N];
inline void solve2(){
for(int i=0;i<=20;++i)bin[1<<i]=i;
n+=ceil(log(m)/log(2));
f.push_back(1);
for(int x=0;x<=n;++x){
cnt=0;
for(int i=head[x];i;i=Next[i])a[++cnt]=P[i];
if(cnt)build(1,1,cnt);
else{
t[1].clear();
t[1].push_back(1);
}
int cnt2=f.size()-1;
for(int i=0;i<=cnt;++i)ta[i]=t[1][i];
for(int i=0;i<=cnt2;++i)tb[i]=f[i];
int cnt3=cnt+cnt2+1;
for(len=1;len<=cnt3;len<<=1);
for(int i=cnt+1;i<len;++i)ta[i]=0;
for(int i=cnt2+1;i<len;++i)tb[i]=0;
NTT(ta,len,1);NTT(tb,len,1);
for(int i=0;i<len;++i)ta[i]=(LL)ta[i]*tb[i]%mod;
NTT(ta,len,-1);cnt3--;
for(int i=0;i<=cnt3;++i)ans=(ans+(LL)ta[i]*i)%mod,tmp[i]=0;
for(int i=0;i<=cnt3;++i)tmp[i>>1]=(tmp[i>>1]+ta[i])%mod;
cnt3>>=1;f.clear();
for(int i=0;i<=cnt3;++i)f.push_back(tmp[i]);
}
printf("%d\n",ans);
}
int main(){
n=read();m=read();
for(int i=1;i<=m;++i){
int a=read(),x=read(),y=read();
int p=(LL)x*ksm(y,mod-2)%mod;
addmsg(a,p);
}
if(n==1)solve();
else if(m<=3000){
n+=ceil(log(m)/log(2));
dp[0]=1;
for(int x=0;x<=n;++x){
for(int i=head[x];i;i=Next[i]){
int p=P[i];
for(int j=m;j;--j)
dp[j]=((LL)dp[j]*(1-p+mod)+(LL)dp[j-1]*p)%mod;
dp[0]=(LL)dp[0]*(1-p+mod)%mod;
}
for(int j=1;j<=m;++j)ans=(ans+(LL)j*dp[j])%mod;
for(int j=0;j<=m;++j)tmp[j>>1]=(tmp[j>>1]+dp[j])%mod;
for(int j=0;j<=m;++j)dp[j]=tmp[j],tmp[j]=0;
}
printf("%d\n",ans);
}
else solve2();
return 0;
}