[nowcoder5668J]Operating on the Tree

考虑令$a_{i}$为i的位置,$p_{i}=0/1$表示第i个点的贡献,那么$p_{x}=0$当且仅当存在与其相邻的点$y$满足$a_{y}<a_{x}$且$p_{y}=1$
树形dp,定义状态$g[k][j][0/1/2]$表示以$k$为根的子树中选择了j个点,$p_{k}=1$或$p_{k}=0$或还没有参与,$f[k][j][0/1/2]$表示这样的$cost$之和
暴力转移即可(转移方程详见代码),注意:1.要乘上组合数表示不同子树之间的顺序;2.要另开一个数组来讨论;3.讨论时要对儿子和根的位置关系分类
  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define N 2005
  4 #define mod 998244353
  5 struct ji{
  6     int nex,to;
  7 }edge[N];
  8 int E,t,n,x,fac[N],inv[N],head[N],sz[N],g[N][N][3],f[N][N][3];
  9 int c(int n,int m){
 10     return 1LL*fac[n]*inv[m]%mod*inv[n-m]%mod;
 11 }
 12 void add(int x,int y){
 13     edge[E].nex=head[x];
 14     edge[E].to=y;
 15     head[x]=E++;
 16 }
 17 void dfs(int k){
 18     sz[k]=g[k][0][0]=g[k][0][2]=1;
 19     for(int i=head[k];i!=-1;i=edge[i].nex){
 20         int u=edge[i].to;
 21         dfs(u);
 22         for(int j=1;j<sz[u];j++)
 23             for(int p=0;p<3;p++){
 24                 g[u][j][p]=(g[u][j][p]+g[u][j-1][p])%mod;
 25                 f[u][j][p]=(f[u][j][p]+f[u][j-1][p])%mod;
 26             }
 27         for(int j=0;j<sz[k];j++)
 28             for(int jj=0;jj<=sz[u];jj++){
 29                 int C=1LL*c(j+jj,j)*c(sz[k]-j-1+sz[u]-jj,sz[u]-jj)%mod;
 30                 for(int p1=0;p1<3;p1++)
 31                     for(int p2=0;p2<3;p2++){
 32                         int s1=0,s2=0;
 33                         if (jj){
 34                             s1=1LL*g[u][jj-1][p2]*g[k][j][p1]%mod*C%mod;
 35                             s2=(1LL*g[u][jj-1][p2]*f[k][j][p1]+1LL*f[u][jj-1][p2]*g[k][j][p1])%mod*C%mod;
 36                         }
 37                         if ((p1==0)&&(p2==1)){
 38                             g[0][j+jj][0]=(g[0][j+jj][0]+s1)%mod;
 39                             f[0][j+jj][0]=(f[0][j+jj][0]+s2)%mod;
 40                         }
 41                         if ((p1==1)&&(p2<2)||(p1==2)&&(p2==0)){
 42                             g[0][j+jj][1]=(g[0][j+jj][1]+s1)%mod;
 43                             f[0][j+jj][1]=(f[0][j+jj][1]+s2)%mod;
 44                         }
 45                         if ((p1==2)&&(p2==1)){
 46                             g[0][j+jj][2]=(g[0][j+jj][2]+s1)%mod;
 47                             f[0][j+jj][2]=(f[0][j+jj][2]+s2)%mod;
 48                         }
 49                         if (!jj)s1=g[u][sz[u]-1][p2];
 50                         else s1=(g[u][sz[u]-1][p2]-g[u][jj-1][p2]+mod)%mod;
 51                         if (!jj)s2=f[u][sz[u]-1][p2];
 52                         else s2=(f[u][sz[u]-1][p2]-f[u][jj-1][p2]+mod)%mod;
 53                         s2=(1LL*s1*f[k][j][p1]+1LL*s2*g[k][j][p1])%mod*C%mod;
 54                         s1=1LL*s1*g[k][j][p1]%mod*C%mod;
 55                         if ((p1==0)&&(p2)){
 56                             g[0][j+jj][0]=(g[0][j+jj][0]+s1)%mod;
 57                             f[0][j+jj][0]=(f[0][j+jj][0]+s2)%mod;
 58                         }
 59                         if ((p1==1)&&(p2<2)){
 60                             g[0][j+jj][1]=(g[0][j+jj][1]+s1)%mod;
 61                             f[0][j+jj][1]=(f[0][j+jj][1]+s2)%mod;
 62                         }
 63                         if ((p1==2)&&(p2<2)){
 64                             g[0][j+jj][2]=(g[0][j+jj][2]+s1)%mod;
 65                             f[0][j+jj][2]=(f[0][j+jj][2]+s2)%mod;
 66                         }
 67                     }
 68             }
 69         sz[k]+=sz[u];
 70         for(int j=0;j<sz[k];j++)
 71             for(int p=0;p<3;p++){
 72                 g[k][j][p]=g[0][j][p];
 73                 f[k][j][p]=f[0][j][p];
 74                 g[0][j][p]=f[0][j][p]=0;
 75             }
 76     }
 77     for(int i=0;i<sz[k];i++)f[k][i][0]=(f[k][i][0]+g[k][i][0])%mod;
 78 }
 79 int main(){
 80     fac[0]=inv[0]=inv[1]=1;
 81     for(int i=1;i<N-4;i++)fac[i]=1LL*fac[i-1]*i%mod;
 82     for(int i=2;i<N-4;i++)inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod;
 83     for(int i=2;i<N-4;i++)inv[i]=1LL*inv[i-1]*inv[i]%mod;
 84     scanf("%d",&t);
 85     while (t--){
 86         scanf("%d",&n);
 87         E=0;
 88         memset(head,-1,4*(n+1));
 89         for(int i=1;i<=n;i++){
 90             memset(g[i],0,sizeof(g[i]));
 91             memset(f[i],0,sizeof(f[i]));
 92         }
 93         for(int i=2;i<=n;i++){
 94             scanf("%d",&x);
 95             add(x+1,i);
 96         }
 97         dfs(1);
 98         int ans=0;
 99         for(int i=0;i<n;i++)ans=(ans+0LL+f[1][i][0]+f[1][i][1])%mod;
100         printf("%d\n",ans);
101     }
102 }
View Code

猜你喜欢

转载自www.cnblogs.com/PYWBKTDA/p/13397751.html