[AGC019E]Shuffle and Swap

[AGC019E]Shuffle and Swap

题目大意:

给出两个长度为\(n(n\le10000)\)\(01\)\(A_{1\sim n}\)\(B_{1\sim n}\)。两个串均有\(k\)'1'。令\(a_{1\sim k}\)\(b_{1\sim k}\)分别表示\(A\)\(B\)中所有'1'出现的位置。将\(a\)\(b\)等概率随机排列,按\(1\sim k\)的顺序交换\(A_{a_i}\)\(A_{b_i}\)。令\(P\)表示操作完成后\(A\)\(B\)相等的概率,求\(P\times(k!)^2\)在模\(998244353\)意义下的值。

思路:

\(k\)'1'分为两类:\(S_1\)\(S_2\)\(S_1\)表示\(A\)\(B\)中均为'1',有\(s_1\)个;\(S_2\)表示\(A\)\(B\)中只有一个为'1',有\(s_2\)个。用\(f[i][j]\)表示\(S_1\)交换完\(i\)个,\(S_2\)交换完\(j\)个后的方案数。

显然,初始情况下,\(S_1\)可以交换\(2\)次,\(S_2\)可以交换\(1\)次。方便起见,假设我们的\(A\)\(B\)分别为:

110011
111100

对于\(i=0\)的初始情况,由于\(S_2\)之间任意配对都可以使得\(A=B\),因此有\((j!)^2\)种方案。

若我们交换一对\(S_2\),如\(4\)\(5\),那么交换后,可以交换\(1\)次的\(S_2\)减少了一对。有转移\(f[i][j]+=f[i][j-1]\times j^2\)

若我们交换一个\(S_1\)和一个\(S_2\),如\(2\)\(6\),那么交换后,原来那个\(S_2\)已经不能再交换了,而原来可以交换\(2\)次的\(S_1\)现在只能交换一次,可以算入\(S_2\)中,也就是相当于减少了一个\(S_1\)。有转移\(f[i][j]+=f[i-1][j]\times i\times j\)

最后统计答案时,对于\(S_1\)没有用完的情况,无论如何配对都能满足\(A=B\)。因此\(f[s_1-i][s_2]\)对答案的贡献为\(f[s_1-i][s_2]\times(i!)^2\times\binom{s_1}i\times\binom ki\)

时间复杂度\(\mathcal O(n^2)\)

源代码:

#include<cstdio>
#include<cstring>
using int64=long long;
constexpr int N=10001,mod=998244353;
char a[N],b[N];
int fact[N],factinv[N],f[N][N];
void exgcd(const int &a,const int &b,int &x,int &y) {
    if(!b) {
        x=1,y=0;
        return;
    }
    exgcd(b,a%b,y,x);
    y-=a/b*x;
}
inline int inv(const int &x) {
    int ret,tmp;
    exgcd(x,mod,ret,tmp);
    return (ret%mod+mod)%mod;
}
inline int C(const int &n,const int &m) {
    if(n<m||n<0||m<0) return 0;
    return (int64)fact[n]*factinv[m]%mod*factinv[n-m]%mod;
}
int main() {
    scanf("%s%s",a,b);
    const int n=strlen(a);
    for(register int i=fact[0]=1;i<=n;i++) {
        fact[i]=(int64)fact[i-1]*i%mod;
    }
    factinv[n]=inv(fact[n]);
    for(register int i=n;i;i--) {
        factinv[i-1]=(int64)factinv[i]*i%mod;
    }
    int s1=0,s2=0;
    for(register int i=0;i<n;i++) {
        if(a[i]=='1'&&b[i]=='1') s1++;
        if(a[i]=='1'&&b[i]=='0') s2++;
    }
    for(register int i=0;i<=s2;i++) {
        f[0][i]=(int64)fact[i]*fact[i]%mod;
    }
    for(register int i=1;i<=s1;i++) {
        for(register int j=1;j<=s2;j++) {
            f[i][j]=((int64)f[i-1][j]*i%mod*j%mod+(int64)f[i][j-1]*j%mod*j%mod)%mod;
        }
    }
    int ans=0;
    for(register int i=0;i<=s1;i++) {
        (ans+=(int64)f[s1-i][s2]*fact[i]%mod*fact[i]%mod*C(s1,i)%mod*C(s1+s2,i)%mod)%=mod;
    }
    printf("%d\n",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/skylee03/p/9166950.html