2020牛客暑期多校训练营(第三场) Operating on the Tree

原题
题目描述
此问题是由问题G(Operating on a Graph )启发的。 因此,您需要阅读它的声明才能解决此问题。
您将得到一棵具有n个顶点的树。 假设p是从0到n-1的排列。 我们定义函数f(p)如下:
假设给定的树是问题G的输入图,而p是输入运算符序列。 f(p)是满足条件的操作数:执行第i个操作时,至少有一个顶点属于o[i]组。
令S为从0到n-1的所有可能排列的集合。 请计算
在这里插入图片描述
样例
输入

3
4
0 1 2
4
0 1 1
2
0

输出

48
60
2

思路
根据题意,我们先把点分为两类:一类是对答案有贡献的点,也就是成功的点,我们称之为好点,标记为0。一类是对答案没有贡献的点,也就是失败的点,我们称之为坏点。在坏点中,还可以分为两类,一类是已经失败的的点,标记为1。一类是即将失败的点,标记为2。
而且,我们还可以出以下几个规律:
1 1、 坏节点和即将失败的点一定是在好节点的旁边
2 2、 好节点和好节点不可能相邻
这样就可以用树形DP来做了
dp [ [ i ] ] [ [ j ] ] [ [ k ] ] 分别表示子树树节点,根节点的类别,子树里有多少点比根节点大。
具体的转移方程请看代码。
代码

#include <bits/stdc++.h>
const int N=2010,mod=998244353;
std::vector<int> e[N];
int comb[N][N],dp[N][3][N],dp1[N][3][N],tmp[3][N],tmp1[3][N],sz[N],n,t,i,j,x;
void add(int &u, int v){u+=v;u-=u>=mod?mod:0;}
void dfs(int u)
{
    sz[u]=dp[u][0][0]=dp[u][2][0]=1;
    for(int k=0;k<e[u].size();k++)
	{
    	int v=e[u][k];dfs(v);
        for(i=0;i<3;i++)for(j=1;j<sz[v];j++)add(dp[v][i][j],dp[v][i][j-1]),add(dp1[v][i][j],dp1[v][i][j-1]);
        for(i=0;i<sz[u];i++)
            for(j=0;j<=sz[v];j++)
			{
                int coe=1ll*comb[i+j][i]*comb[sz[u]-1-i+sz[v]-j][sz[v]-j]%mod;
                for(int t1=0;t1<3;t1++)
                    for(int t2=0;t2<3;t2++)
					{
                        int cnt=j?dp[v][t2][j-1]:0;
                        int cnt1=j?dp1[v][t2][j-1]:0;
                        int coe1=1ll*coe*dp[u][t1][i]%mod*cnt%mod;
                        int base=coe*(1ll*dp[u][t1][i]*cnt1%mod+1ll*dp1[u][t1][i]*cnt%mod)%mod;
						if(t1==0){if(t2==1)add(tmp[0][i+j],coe1),add(tmp1[0][i+j],base);}
                        else if(t1==1){if(t2==0||t2==1)add(tmp[1][i+j],coe1),add(tmp1[1][i+j],base);} 
                        else if(t1==2)
						{
                            if(t2==0)add(tmp[1][i+j],coe1),add(tmp1[1][i+j],base);
                            else if(t2==1)add(tmp[2][i+j],coe1),add(tmp1[2][i+j],base);
                        }
                        cnt=dp[v][t2][sz[v]-1]-cnt;
                        cnt+=cnt<0?mod:0;
                        cnt1=dp1[v][t2][sz[v]-1]-cnt1;
                        cnt1+=cnt1<0?mod:0;
                        coe1=1ll*coe*dp[u][t1][i]%mod*cnt%mod;
                        base=coe*(1ll*dp[u][t1][i]*cnt1%mod+1ll*dp1[u][t1][i]*cnt%mod)%mod;
                        if(t1==0){if(t2==1||t2==2)add(tmp[0][i+j],coe1),add(tmp1[0][i+j],base);}
                        else if(t1==1){if(t2==0||t2==1)add(tmp[1][i+j],coe1),add(tmp1[1][i+j],base);}
                        else if(t1==2){if(t2==0||t2==1)add(tmp[2][i+j],coe1),add(tmp1[2][i+j],base);}
                    }
            }
        sz[u]+=sz[v];
        for(i=0;i<sz[u];i++)for(j=0;j<3;j++)dp[u][j][i]=tmp[j][i],dp1[u][j][i]=tmp1[j][i],tmp[j][i]=0,tmp1[j][i]=0;
    }
    for(i=0;i<sz[u];i++)add(dp1[u][0][i], dp[u][0][i]);
}
void solve()
{
    scanf("%d",&n);
    for(i=0;i<=n;i++)e[i].clear(),memset(dp[i],0,sizeof(dp[i])), memset(dp1[i],0,sizeof(dp1[i]));
    for(i=2;i<=n;i++)scanf("%d",&x),e[++x].push_back(i);
    dfs(1);int ans=0;
    for(i=0;i<n;i++) add(ans,dp1[1][0][i]),add(ans,dp1[1][1][i]);
    printf("%d\n",ans);
}
int main()
{
    for(i=0;i<N;i++){comb[i][0]=1;for(j=1;j<=i;j++)comb[i][j]=(comb[i-1][j-1]+comb[i-1][j])%mod;}
    for(scanf("%d",&t);t--;)solve();
}

如果代码还是看不懂,可以看一下这边有注释的代码。这里还有几个细节:
1 1、 用杨辉三角算组合数可以省时间
2 2、 把取余转化为减法可以节省时间
出题人一般不会卡这一些时间

END

猜你喜欢

转载自blog.csdn.net/bbbll123/article/details/107546922