洛谷P5289 [十二省联考2019]皮配(01背包)

啊啊啊边界判错了搞死我了QAQ

这题是一个想起来很休闲写起来很恶心的背包

对于\(k=0\)的情况,可以发现选阵营和选派系是独立的,对选城市选阵营和学校选派系分别跑一遍01背包就行了

对于\(k>0\)的情况,设\(f[i][0/1][j][k]\)表示对于第\(i\)个有限制的学校,该学校选择\(0/1\)阵营时,\(C0\)阵营有\(j\)人,\(D0\)派系有\(k\)人的方案数

转移要分类讨论,有点麻烦,看代码吧

// luogu-judger-enable-o2
#include <bits/stdc++.h>
#define N 1010
#define M 2505
#define R register
#define rep(i,x,y) for(i=x;i<=y;++i)
#define des(i,x,y) for(i=x;i>=y;--i) 
#define mod 998244353
using namespace std;

int f[2][2][M][M],pref0[M],preg0[M],f0[M],g0[M];
int C0,C1,D0,D1,c0[N],tmp2[N],sum2=0;
int cnts[2],cntc[2],n,c;
bool visct[N]; 

inline int add(int x,int y){ return (x+y>=mod)?x+y-mod:x+y; }
inline int sub(int x,int y){ return (x-y<0)?x-y+mod:x-y; }
struct hhw{ int b,s,id,lim; } tmp[N],a0[N],ak[N];
inline void rd(int &x){
    char c=getchar();R int y=0;
    while(c<'0'||c>'9') c=getchar();
    while(c>='0' && c<='9') y=y*10+c-'0',c=getchar();
    x=y;
}
inline void solve0(){
    R int i,j;
    f0[0]=pref0[0]=1;
    rep(i,1,cntc[0])
        des(j,C0-c0[i],0)
            f0[j+c0[i]]=add(f0[j+c0[i]],f0[j]);
    rep(i,1,C0) pref0[i]=add(pref0[i-1],f0[i]);
    g0[0]=preg0[0]=1;
    rep(i,1,cnts[0])
        des(j,D0-a0[i].s,0)
            g0[j+a0[i].s]=add(g0[j+a0[i].s],g0[j]);     
    rep(i,1,D0) preg0[i]=add(preg0[i-1],g0[i]); 
}

inline bool cmp(const hhw &x,const hhw &y){ return x.b<y.b; }

inline void addf(int a,int b,int c,int d,int x){
    if(c>C0||d>D0) return;
    f[a][b][c][d]=add(f[a][b][c][d],x);
}

inline void solvek(){
    f[0][0][0][0]=1;
    R int i,j,k,o,pos=1;
    sort(ak+1,ak+cnts[1]+1,cmp);
    rep(i,1,cnts[1]){
        rep(o,0,1)
            des(j,C0,0)
                des(k,sum2,0){//不能直接写D0!!!要把有限制的学校的人数统计出来!!! 
                    f[pos][o][j][k]=0;
                    if(o==0){
                        if(ak[i].lim^0){
                            if(ak[i].b==ak[i-1].b)
                                addf(pos,o,j,k+ak[i].s,f[pos^1][o][j][k]);                              
                            else{
                                addf(pos,o,j+tmp2[ak[i].b],k+ak[i].s,f[pos^1][o][j][k]);
                                addf(pos,o,j+tmp2[ak[i].b],k+ak[i].s,f[pos^1][o^1][j][k]);                              
                            }
                        }
                        if(ak[i].lim^1){
                            if(ak[i].b==ak[i-1].b)
                                addf(pos,o,j,k,f[pos^1][o][j][k]);                              
                            else{
                                addf(pos,o,j+tmp2[ak[i].b],k,f[pos^1][o][j][k]);
                                addf(pos,o,j+tmp2[ak[i].b],k,f[pos^1][o^1][j][k]);                              
                            }                           
                        }
                    } else{
                        if(ak[i].lim^2){
                            if(ak[i].b==ak[i-1].b)
                                addf(pos,o,j,k+ak[i].s,f[pos^1][o][j][k]);                              
                            else{
                                addf(pos,o,j,k+ak[i].s,f[pos^1][o][j][k]);
                                addf(pos,o,j,k+ak[i].s,f[pos^1][o^1][j][k]);                                
                            }
                        }
                        if(ak[i].lim^3){
                            if(ak[i].b==ak[i-1].b)
                                addf(pos,o,j,k,f[pos^1][o][j][k]);                              
                            else{
                                addf(pos,o,j,k,f[pos^1][o][j][k]);
                                addf(pos,o,j,k,f[pos^1][o^1][j][k]);                                
                            }                           
                        }                       
                    }           
                }
        pos^=1;             
    }
}


inline int qwq(int x,int y){ return y>=0?sub(pref0[x],pref0[y]):pref0[x]; }
inline int owo(int x,int y){ return y>=0?sub(preg0[x],preg0[y]):preg0[x]; }

inline void solve(){
    R int i,x,y,k,j,sum=0,ans=0;
    rd(n),rd(c);
    rd(C0),rd(C1),rd(D0),rd(D1);
    rep(i,1,n){
        rd(tmp[i].b),rd(tmp[i].s); 
        tmp[i].id=i,tmp[i].lim=-1;
        sum+=tmp[i].s; 
        tmp2[tmp[i].b]+=tmp[i].s;
    }
    rd(k);
    rep(i,1,k){
        rd(x),rd(y),tmp[x].lim=y;
        visct[tmp[x].b]=1;sum2+=tmp[x].s;
    }
    rep(i,1,n){
        if(tmp[i].lim==-1) a0[++cnts[0]]=tmp[i];
        else ak[++cnts[1]]=tmp[i];
    }
    rep(i,1,c){
        if(!visct[i] && tmp2[i]) c0[++cntc[0]]=tmp2[i];
        else cntc[1]++;
    }
    if(sum-C0>C1 || sum-D0>D1){
        printf("0\n");
        return;
    }
    solve0(),solvek(); 
    rep(i,0,C0)
        rep(j,0,D0){
            if(sum-C1-i<=C0-i && sum-D1-j<=D0-j){
                int x=add(f[cnts[1]&1][0][i][j],f[cnts[1]&1][1][i][j]);
                ans=add(ans,1ll*x*qwq(C0-i,sum-C1-i-1)%mod*owo(D0-j,sum-D1-j-1)%mod);               
            }
            f[0][0][i][j]=f[0][1][i][j]=f[1][0][i][j]=f[1][1][i][j]=0; 
        }
    printf("%d\n",ans);
    rep(i,1,c) tmp2[i]=visct[i]=0;
    cntc[1]=cntc[0]=cnts[1]=cnts[0]=sum=sum2=0;
    memset(f0,0,sizeof(int)*(C0+1));
    memset(g0,0,sizeof(int)*(D0+1));
}

int main(){     
    R int t;rd(t);
    while(t--) solve();
}

猜你喜欢

转载自www.cnblogs.com/PsychicBoom/p/10719552.html
今日推荐