-
这种求 “取到所有物品的期望时间” 的题一般都用 m i n − m a x min-max min−max容斥 解决:
设 t ( i , j ) t(i,j) t(i,j)为取到格子 ( i , j ) (i,j) (i,j)的期望时间,集合 S = ∪ c ( i , j ) = ′ ∗ ′ { t ( i , j ) } S=\cup_{c(i,j)='*'}\{t(i,j)\} S=∪c(i,j)=′∗′{ t(i,j)}
那么根据 m i n − m a x min-max min−max容斥有: max ( S ) = ∑ T ⊆ S , T ≠ ∅ ( − 1 ) ∣ T ∣ − 1 min ( T ) \max(S) = \sum_{T\subseteq S, T \neq \varnothing} (-1)^{|T|-1}\min(T) max(S)=T⊆S,T=∅∑(−1)∣T∣−1min(T) -
m i n ( T ) min(T) min(T)即为首次取到 T T T中的格子的期望时间(记为 E T E_T ET),考虑转成求概率:
设 P T P_{T} PT为取到 T T T中的格子的概率,由 E T = 1 + ( 1 − P T ) E T E_{T}=1+(1-P_T)E_T ET=1+(1−PT)ET解得 E T = 1 P T E_T=\frac{1}{P_T} ET=PT1
设有覆盖到 T T T中的格子的相邻对个数为 x x x,因为总共的相邻对个数为 m ( n − 1 ) + n ( m − 1 ) m(n-1)+n(m-1) m(n−1)+n(m−1),所以 P T = x m ( n − 1 ) + n ( m − 1 ) P_T=\frac{x}{m(n-1)+n(m-1)} PT=m(n−1)+n(m−1)x,得到 E T = m ( n − 1 ) + n ( m − 1 ) x E_T=\frac{m(n-1)+n(m-1)}{x} ET=xm(n−1)+n(m−1) -
发现子集 T T T数量很多,但是 x x x非常小,于是神奇地转换思路:
求出对于每个 x x x,有多少个子集 T T T满足恰有 x x x个相邻对有覆盖到 T T T中的点。
上插头dp,设 d p [ i ] [ j ] [ s ] [ x ] dp[i][j][s][x] dp[i][j][s][x]表示做到了 ( i , j ) (i,j) (i,j),当前状态为 s s s,有 x x x个相邻对。
我们dp的内容是系数和,如果选了一个格子,集合大小改变,要乘一个 − 1 -1 −1作为系数。
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
const int mod=998244353;
int add(int a,int b){
return a+b>=mod?a+b-mod:a+b;}
int dec(int a,int b){
return a<b?a-b+mod:a-b;}
int mul(int a,int b){
return 1ll*a*b%mod;}
void Add(int &a,int b){
a=add(a,b);}
int ksm(int a,int b){
int res=1;
while(b){
if(b&1) res=mul(res,a);
b>>=1;a=mul(a,a);
}
return res;
}
int n,m;
char c[10][110];
int dp[2][1<<6][1250],ans;//N*M*2=1200
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%s",c[i]+1);
int tot=n*(m-1)+m*(n-1);
int cur=1,pre=0;
dp[cur][0][0]=dec(0,1);
int sta=(1<<n)-1;
for(int i=1;i<=m;++i){
for(int j=1;j<=n;++j){
swap(cur,pre);
memset(dp[cur],0,sizeof(dp[cur]));
for(int s=0;s<=sta;s++){
for(int k=0;k<=tot;k++){
if(dp[pre][s][k]){
int ss=s&(sta^(1<<(j-1)));
Add(dp[cur][ss][k],dp[pre][s][k]);
if(c[j][i]=='*'){
ss|=(1<<j-1);
int delta=(i!=1&&(s&(1<<j-1))==0)+(j!=1&&(s&(1<<j-2))==0)+(i<m)+(j<n);
Add(dp[cur][ss][k+delta],dec(0,dp[pre][s][k]));
}
}
}
}
}
}
for(int i=1;i<=tot;i++){
ll inv=ksm(i,mod-2);
for(int s=0;s<=sta;s++)
Add(ans,mul(dp[cur][s][i],inv));
}
ans=mul(ans,tot);
cout<<ans;
return 0;
}
参考文章:
https://www.cnblogs.com/ZH-comld/p/11014880.html
https://www.cnblogs.com/huyufeifei/p/10498429.html