联考20200525 T3 「雅礼集训 2018 Day7」C

题目传送门

分析:
考虑每个点对答案的贡献,答案是访问次数乘以期望距离
对于每一个点,从它出发跳一步的路径长度期望为该点到其他点距离之和除以\(n\)
这个可以两次遍历求出
关键要求出访问次数
\(f[i][0/1]\)表示图中有\(i\)\(1\)时权值为\(0/1\)的点的期望访问次数
列出方程:

\(f[i][0]=\frac{1}{n}+\frac{i}{n}f[i-1][0]+\frac{n-i-1}{n}f[i+1][0]+\frac{1}{n}f[i+1][1]\)
\(f[i][1]=\frac{1}{n}+\frac{i-1}{n}f[i-1][1]+\frac{1}{n}f[i-1][0]+\frac{n-i}{n}f[i+1][1]\)

注意一下边界特殊情况
如果我们知道\(f[0][0]\)\(f[1][0]\),尝试往前递推,首先通过第二个式子推出\(f[i][1]\),然后用第一个式子推出\(f[i][0]\)
\(f[0][0]=x,f[1][0]=y\),每个\(f\)值都可以表示为\(ax+by+c\),推到\(n-1\)的时候就可以变成一个二元一次方程,直接求解即可
复杂度\(O(n)\)

#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<queue>
#include<algorithm>

#define maxn 200005
#define MOD 1000000007

using namespace std;

inline int getint()
{
	int num=0,flag=1;char c;
	while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
	while(c>='0'&&c<='9')num=num*10+c-48,c=getchar();
	return num*flag;
}

int n,cur;
int fir[maxn],nxt[maxn],to[maxn],cnt;
struct node{
	long long a,b,c;
	friend node operator +(node x,node y)
	{return (node){(x.a+y.a)%MOD,(x.b+y.b)%MOD,(x.c+y.c)%MOD};}
	friend node operator -(node x,node y)
	{return (node){(x.a-y.a+MOD)%MOD,(x.b-y.b+MOD)%MOD,(x.c-y.c+MOD)%MOD};}
	friend node operator +(node x,long long y)
	{return (node){x.a,x.b,(x.c+y)%MOD};}
	friend node operator -(node x,long long y)
	{return (node){x.a,x.b,(x.c-y+MOD)%MOD};}
	friend node operator *(node x,long long y)
	{return (node){x.a*y%MOD,x.b*y%MOD,x.c*y%MOD};}
}f[maxn][2];
long long dis[maxn],sz[maxn],ans;
char ch[maxn];

inline void newnode(int u,int v)
{to[++cnt]=v,nxt[cnt]=fir[u],fir[u]=cnt;}
inline long long ksm(long long num,long long k)
{
	long long ret=1;
	for(;k;k>>=1,num=num*num%MOD)if(k&1)ret=ret*num%MOD;
	return ret;
}

inline void dfs1(int u,int fa,int dpt)
{
	sz[u]=1;
	for(int i=fir[u];i;i=nxt[i])if(to[i]!=fa)
	{
		dis[1]+=dpt+1,dfs1(to[i],u,dpt+1),sz[u]+=sz[to[i]];
	}
}
inline void dfs2(int u,int fa)
{
	for(int i=fir[u];i;i=nxt[i])if(to[i]!=fa)
		dis[to[i]]=dis[u]-sz[to[i]]+n-sz[to[i]],dfs2(to[i],u);
}

int main()
{
	n=getint();scanf("%s",ch+1);
	long long Inv=ksm(n,MOD-2);
	for(int i=1;i<=n;i++)if(ch[i]=='1')cur++;
	for(int i=2;i<=n;i++)
	{
		int u=getint();
		newnode(i,u),newnode(u,i);
	}
	dfs1(1,1,0),dfs2(1,1);
	
	for(int i=1;i<=n;i++)dis[i]=dis[i]*Inv%MOD;
	f[1][0].a=1,f[1][1].b=1;
	for(int i=2;i<n;i++)
	{
		f[i][1]=f[i-1][1]*n-f[i-2][1]*(i-2)-f[i-2][0];
		if(i>2)f[i][1]=f[i][1]-1ll;
		f[i][1]=f[i][1]*ksm(n-i+1,MOD-2);
		f[i][0]=f[i-1][0]*n-f[i-2][0]*(i-1)-f[i][1];
		f[i][0]=f[i][0]-1ll;
		f[i][0]=f[i][0]*ksm(n-i,MOD-2);
	}
	node P1=f[n-1][0],P2=f[n-1][1];P1.b=(P1.b-1+MOD)%MOD,P2.a=(P2.a-1+MOD)%MOD;
	long long X=(P2.b*P1.c%MOD-P1.b*P2.c%MOD+MOD)%MOD;
	long long tmp=(P2.a*P1.b%MOD-P1.a*P2.b%MOD+MOD)%MOD;
	X=X*ksm(tmp,MOD-2)%MOD;
	long long Y=((MOD-P1.a*X%MOD)*ksm(P1.b,MOD-2)%MOD-P1.c*ksm(P1.b,MOD-2)%MOD+MOD)%MOD;
	long long f0=(f[cur][0].a*X%MOD+f[cur][0].b*Y%MOD+f[cur][0].c+Inv)%MOD;
	long long f1=(f[cur][1].a*X%MOD+f[cur][1].b*Y%MOD+f[cur][1].c+Inv)%MOD;
	for(int i=1;i<=n;i++)
	{
		if(ch[i]=='1')ans=(ans+f1*dis[i]%MOD)%MOD;
		else ans=(ans+f0*dis[i]%MOD)%MOD;
	}
	printf("%lld\n",ans);
}

猜你喜欢

转载自www.cnblogs.com/Darknesses/p/12961220.html