【bzoj 3622】已经没有什么好害怕的了

题目

看到这个数据范围就发现我们需要一个\(O(n^2)\)的做法了,那大概率是\(dp\)

看到恰好\(k\)个我们就知道这基本是个容斥了

首先解方程发现我们需要使得\(a>b\)的恰好有\(\frac{n+k}{2}\)

如果不整除我们直接输出\(0\)就好了

之后开始使用套路,直接广义容斥

\[ans=\sum_{i=k}^n(-1)^{i-k}\binom{i}{k}g_i\]

\(g_i\)表配出至少\(i\)\(a>b\)的情况

显然我们现在需要一个\(dp\)来算一下\(g\)

首先发现两个数组是没有顺序的,所以先习惯性排个序

\(dp_{i,j}\)表示从\(a\)数组的前\(i\)个数中,已经配出\(j\)\(a>b\)的方案数

边界\(dp_{0,0}=1\)

我们排序的作用这个时候就体现出来了,我们设\(d_i\)表示满足\(b_j<a_i\)的最大的\(j\)

由于\(a,b\)两个数组都是有序的,我们知道\(d_i\)肯定是单调不降的

于是有这样的方程

\[dp_{i,j}=dp_{i-1,j}+max(d_i-(j-1),0)dp_{i-1,j-1}\]

就是考虑对于第\(i\)个数满足\(a>b\)的只有\(d_i\)个,减去和前\(i-1\)个匹配的\(j-1\)个,剩下的我们随便找出一个来匹配就好了

之后\(g_i=dp_{n,i}(n-i)!\),就是让没有满足\(a<b\)的那些随便匹配一下就好

代码

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define pt putchar(1),puts("")
const int maxn=2e3+5;
const int mod=1e9+9;
inline int read() {
    char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
int n,k;
int dp[maxn][maxn];
int a[maxn],b[maxn];
int fac[maxn],inv[maxn];
inline LL ksm(LL a,int b) {
    LL S=1;
    while(b) {if(b&1) S=S*a%mod;b>>=1;a=a*a%mod;}
    return S;
} 
inline int C(int n,int m) {
    if(m>n) return 0;
    return 1ll*fac[n]*inv[n-m]%mod*inv[m]%mod;
}
int main() {
    n=read();k=read();
    for(re int i=1;i<=n;i++) a[i]=read();
    for(re int i=1;i<=n;i++) b[i]=read();
    if((n+k)&1) {puts("0");return 0;}
    std::sort(a+1,a+n+1),std::sort(b+1,b+n+1);
    fac[0]=1;
    for(re int i=1;i<=n;i++) fac[i]=(1ll*i*fac[i-1])%mod;
    inv[n]=ksm(fac[n],mod-2);
    for(re int i=n-1;i>=0;--i) inv[i]=(1ll*(i+1)*inv[i+1])%mod;
    dp[0][0]=1;
    for(re int i=1;i<=n;i++) {
        int cnt=0;
        for(re int j=1;j<=n;j++)
            cnt+=(a[i]>b[j]);
        for(re int j=0;j<=i;j++)
            dp[i][j]=dp[i-1][j];
        for(re int j=1;j<=i;j++)
            dp[i][j]=(dp[i][j]+1ll*dp[i-1][j-1]*max(cnt-j+1,0)%mod)%mod;
    }
    k=(n+k)/2;
    LL ans=0;
    for(re int i=k;i<=n;i++)
    if((i-k)&1) ans=(ans-1ll*C(i,k)*dp[n][i]%mod*fac[n-i]%mod+mod)%mod;
        else ans=(ans+1ll*C(i,k)*dp[n][i]%mod*fac[n-i]%mod)%mod;
    printf("%d\n",(int)ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/asuldb/p/10632601.html