一、题目
点此看题
二、解法
考虑每个灯的生成函数,那么答案的生成函数为它们的乘积,设
P=∑pi(这里用
e不是特别严谨,因为本题不需要指数型生成函数,但是为了方便表示无穷级数,使用
e可能会更加简洁):
F(x)=i=1∏n2ePpix+(−1)sie−Ppix但是还有一个问题,就是上面的生成函数并不会在第一次满足条件是就停止,我们考虑转满一圈,也就是操作会原状态的生成函数:
G(x)=i=1∏n2ePpix+e−Ppix求出
F,G直接暴力背包,时间复杂度
O(np),考虑答案的生成函数为
H(x),则满足
h(x)g(x)=f(x)(
h,g,f为
H,G,F的普通生成函数),考虑使用闭形式来表示它们之间的转化,它们之间满足这样的转化关系:
f(x)=i=−P∑P1−Pixai此时要求的答案是
h′(1),求导就是把每种概率乘上他的权值(步数),算出来就是期望,可以使用这个公式,(
g(x)1的求导可以当做复合函数求导,也就是加入一个辅助函数
f(x)=x1):
(g(x)f(x))′=(f(x)g(x)−1)′=g(x)2f′(x)g(x)−f(x)g′(x)发现这样算出来的
h是不收敛的,我们把
f,g同乘
∏1−pi,现在的问题在于解决
f(x)的导数,我们来推式子:
f(x)=∑aij=i∏1−pjx为了方便表示,我们设
gi(x)=ai∏j=ipjx,hi(x)=1−pix,我们想对
g求导:
gi′(x)=aij=i∏hj(x)′
gi′(x)=aik=i∑(j=i,j=k∏hj(x))×hk(x)′
gi′(x)=aik=i∑(j=i∏hj(x))×hk(x)hk(x)′......∗
gi′(x)=ai(j=i∏hj(x))k=i∑hk(x)hk(x)′
gi′(x)=ai(j=i∏1−pjx)k=i∑1−pkx−pk那么
f′(x)就算出来了,长成这样:
f′(x)=i∑ai(j=i∏1−pjx)k=i∑1−pkx−pk考虑求除
f′(1)的值,可以分类讨论,由于
x=1是,
1−pi很特殊,我们考虑
i是否等于
p
-
i=p,考虑到
j的枚举中会有
0这一项产生(
j=p时),那就都消完了?等等,看我标注了
*
号的柿子,我们为了提出相同的一项而乘上又除去了
1−pkx,所以在
k中产生的
0−1必须要与
0的一项对消,产生
−1,所以当
i=p时,
k只能等于
p,这样我们就可以对
g函数进行化简:
gi′(x)=−aij=i,j=p∏1−pjx
gi′(x)=−j=p∏1−pjx×1−pixai
-
i=p,直接带推出的
f′(x)的柿子去算,比较简单。
分类讨论之后,我们可以知道
f′(1)的值:
f′(1)=−(i=p∏1−pi)(i=p∑1−piai+api=p∑1−pipi)知道了
f′(1)的求导之后,
g′(1)也很容易(好像和推导的重名了,写到这里才发现,请谅解),我们可以带入上面对分式求导的柿子,就可以算出答案(设
ai为
f的系数,
bi为
g的系数,容易发现
ap=bp=2n1):
2ni=p∑1−pibi−ai写的我要吐了,贴上简洁的代码:
#include <cstdio>
#include <cstring>
#define int long long
const int M = 50005;
const int MOD = 998244353;
int read()
{
int x=0,flag=1;
char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*flag;
}
int n,sum,inv,ans,s[105],tmp[2*M],f[2*M],g[2*M];
int qkpow(int a,int b)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%MOD;
a=a*a%MOD;
b>>=1;
}
return r;
}
signed main()
{
n=read();
for(int i=1; i<=n; i++)
s[i]=read();
inv=qkpow(2,MOD-2);
f[M]=g[M]=1;
for(int i=1; i<=n; i++)
{
int x=read();
sum+=x;
memset(tmp,0,sizeof tmp);
for(int j=-sum; j+x<=sum; j++) tmp[j+x+M]=(tmp[j+x+M]+inv*f[j+M])%MOD;
for(int j=-sum+x; j<=sum; j++) tmp[j-x+M]=(tmp[j-x+M]+(s[i]?MOD-inv:inv)*f[j+M])%MOD;
memcpy(f,tmp,sizeof tmp);
memset(tmp,0,sizeof tmp);
for(int j=-sum; j+x<=sum; j++) tmp[j+x+M]=(tmp[j+x+M]+inv*g[j+M])%MOD;
for(int j=-sum+x; j<=sum; j++) tmp[j-x+M]=(tmp[j-x+M]+inv*g[j+M])%MOD;
memcpy(g,tmp,sizeof tmp);
}
for(int i=-sum; i<=sum; i++) ans=(ans+(g[i+M]-f[i+M]+MOD)*qkpow(sum-i,MOD-2))%MOD;
printf("%lld\n",ans*sum%MOD*qkpow(2,n)%MOD);
}