原题
题目描述
此问题是由问题G(Operating on a Graph )启发的。 因此,您需要阅读它的声明才能解决此问题。
您将得到一棵具有n个顶点的树。 假设p是从0到n-1的排列。 我们定义函数f(p)如下:
假设给定的树是问题G的输入图,而p是输入运算符序列。 f(p)是满足条件的操作数:执行第i个操作时,至少有一个顶点属于o[i]组。
令S为从0到n-1的所有可能排列的集合。 请计算
样例
输入
3
4
0 1 2
4
0 1 1
2
0
输出
48
60
2
思路
根据题意,我们先把点分为两类:一类是对答案有贡献的点,也就是成功的点,我们称之为好点,标记为0。一类是对答案没有贡献的点,也就是失败的点,我们称之为坏点。在坏点中,还可以分为两类,一类是已经失败的的点,标记为1。一类是即将失败的点,标记为2。
而且,我们还可以出以下几个规律:
坏节点和即将失败的点一定是在好节点的旁边
好节点和好节点不可能相邻
这样就可以用树形DP来做了
dp
i
j
k
分别表示子树树节点,根节点的类别,子树里有多少点比根节点大。
具体的转移方程请看代码。
代码
#include <bits/stdc++.h>
const int N=2010,mod=998244353;
std::vector<int> e[N];
int comb[N][N],dp[N][3][N],dp1[N][3][N],tmp[3][N],tmp1[3][N],sz[N],n,t,i,j,x;
void add(int &u, int v){u+=v;u-=u>=mod?mod:0;}
void dfs(int u)
{
sz[u]=dp[u][0][0]=dp[u][2][0]=1;
for(int k=0;k<e[u].size();k++)
{
int v=e[u][k];dfs(v);
for(i=0;i<3;i++)for(j=1;j<sz[v];j++)add(dp[v][i][j],dp[v][i][j-1]),add(dp1[v][i][j],dp1[v][i][j-1]);
for(i=0;i<sz[u];i++)
for(j=0;j<=sz[v];j++)
{
int coe=1ll*comb[i+j][i]*comb[sz[u]-1-i+sz[v]-j][sz[v]-j]%mod;
for(int t1=0;t1<3;t1++)
for(int t2=0;t2<3;t2++)
{
int cnt=j?dp[v][t2][j-1]:0;
int cnt1=j?dp1[v][t2][j-1]:0;
int coe1=1ll*coe*dp[u][t1][i]%mod*cnt%mod;
int base=coe*(1ll*dp[u][t1][i]*cnt1%mod+1ll*dp1[u][t1][i]*cnt%mod)%mod;
if(t1==0){if(t2==1)add(tmp[0][i+j],coe1),add(tmp1[0][i+j],base);}
else if(t1==1){if(t2==0||t2==1)add(tmp[1][i+j],coe1),add(tmp1[1][i+j],base);}
else if(t1==2)
{
if(t2==0)add(tmp[1][i+j],coe1),add(tmp1[1][i+j],base);
else if(t2==1)add(tmp[2][i+j],coe1),add(tmp1[2][i+j],base);
}
cnt=dp[v][t2][sz[v]-1]-cnt;
cnt+=cnt<0?mod:0;
cnt1=dp1[v][t2][sz[v]-1]-cnt1;
cnt1+=cnt1<0?mod:0;
coe1=1ll*coe*dp[u][t1][i]%mod*cnt%mod;
base=coe*(1ll*dp[u][t1][i]*cnt1%mod+1ll*dp1[u][t1][i]*cnt%mod)%mod;
if(t1==0){if(t2==1||t2==2)add(tmp[0][i+j],coe1),add(tmp1[0][i+j],base);}
else if(t1==1){if(t2==0||t2==1)add(tmp[1][i+j],coe1),add(tmp1[1][i+j],base);}
else if(t1==2){if(t2==0||t2==1)add(tmp[2][i+j],coe1),add(tmp1[2][i+j],base);}
}
}
sz[u]+=sz[v];
for(i=0;i<sz[u];i++)for(j=0;j<3;j++)dp[u][j][i]=tmp[j][i],dp1[u][j][i]=tmp1[j][i],tmp[j][i]=0,tmp1[j][i]=0;
}
for(i=0;i<sz[u];i++)add(dp1[u][0][i], dp[u][0][i]);
}
void solve()
{
scanf("%d",&n);
for(i=0;i<=n;i++)e[i].clear(),memset(dp[i],0,sizeof(dp[i])), memset(dp1[i],0,sizeof(dp1[i]));
for(i=2;i<=n;i++)scanf("%d",&x),e[++x].push_back(i);
dfs(1);int ans=0;
for(i=0;i<n;i++) add(ans,dp1[1][0][i]),add(ans,dp1[1][1][i]);
printf("%d\n",ans);
}
int main()
{
for(i=0;i<N;i++){comb[i][0]=1;for(j=1;j<=i;j++)comb[i][j]=(comb[i-1][j-1]+comb[i-1][j])%mod;}
for(scanf("%d",&t);t--;)solve();
}
如果代码还是看不懂,可以看一下这边有注释的代码。这里还有几个细节:
用杨辉三角算组合数可以省时间
把取余转化为减法可以节省时间
出题人一般不会卡这一些时间