图片加载不出来的看这里
先简述一下题意:
给出一棵树,树上各点都有点权。
定义
为树上
的简单路径的序列。
设
位置的点点权为
,他到
的路径上的下一个点点权为
,以此类推……再设一下这条路径的长度为
。
然后定义
。
如果一条路径满足
那么这条边就是好边,反之如果不满足,这就是条坏边。
定义好三元组
为
都是好边或坏边的三元组。
求这棵树上有多少对好三元组?(三元组中的点可以重复)
保证y是质数
这题非常的绕……
所以,首先我们先不考虑好边坏边怎么在时限内求,先把怎么判断三元组弄清楚
思考一下,很显然针对三元组来说,只会有这八种情况(0是坏边,1是好边)
首先先考虑直接判断好三元组,你会发现,我们必须要知道三条边的信息才能判断,反正我比较菜这个东西我是想不到
的做法,所以要在思考一下别的方法,我们再看坏三元组的特征
坏三元组和好三元组的区别显然是有点连着一条好边和一条坏边,如下图红的点
那么统计一下,假设好的入边为 ,坏的为 ,同样,假设好的出边为 ,坏的为 ,在一番数数之后,我们得出图中各类点连得边的数量为
in0 out1 2
in1 out0 2
in1 in0 4
out1 out0 4
因为有两个点,所以每个点的贡献为1/2
即一个点会产生
个坏三元组
因为一个坏三元组里有两个这样的点,所以答案要除二
这样我们就把坏三元组的个数弄到了一个点上,这里的复杂度降到了
显然根据容斥的思想减掉坏三元组就是好三元组的数量
好的,接着开始接过之前的话题——怎么在时限内求出一个点的
、
、
、
因为三元组中的点可以重复,所以每个点都有
条出边,
条入边。
所以
至于
、
这个是可以用点分治搞得
好的先进入数论环节
首先在点分的时候推两个距离,
表示从下到上的距离1,
表示从上到下的距离2
那么显然当一条边是好边的时候存在:
感觉
这个东西比较尴尬,我们移下项
其中右项就可以边dfs边处理了,相当于减一下再乘个逆元,逆元表可以和幂次表一起打,反正y一定是质数的话一趟费马小定理就可以了。
那么只要存在
就可以组成一条好边
我们可以把所有得到的vd和右边这个值全部sort一下,然后尺取求一下相等的对数,反正就基本能点分出来
、
了……
之后就按上面的方法计算答案了,总三元组个数为 ,答案就是
代码如下:
#include<map>
#include<set>
#include<queue>
#include<cmath>
#include<cstdio>
#include<string>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
long long cnt,vis[100010],fa[100010],sz[100010],in0[100010],in1[100010],out0[100010],out1[100010],tm[100010],inv[100010],a[100010],x,y,k,n,ans;
vector<int> g[100010];
struct point
{
long long val,id;
} tmp1[100010],tmp2[100010];
int cmp(point a,point b)
{
return a.val<b.val;
}
long long kasumi(long long a,long long b)
{
long long ans=1;
while(b)
{
if(b&1)
{
ans=ans*a%y;
}
a=a*a%y;
b>>=1;
}
return ans;
}
int get_sz(int now,int f)
{
sz[now]=1;
fa[now]=f;
for(int i=0; i<g[now].size(); i++)
{
if(g[now][i]==f||vis[g[now][i]]) continue;
get_sz(g[now][i],now);
sz[now]+=sz[g[now][i]];
}
}
int get_zx(int now,int f)
{
if(sz[now]==1) return now;
int maxson=-1,son;
for(int i=0; i<g[now].size(); i++)
{
if(g[now][i]==f||vis[g[now][i]]) continue;
if(maxson<sz[g[now][i]])
{
maxson=sz[g[now][i]];
son=g[now][i];
}
}
int zx=get_zx(son,now);
while(sz[zx]<2*(sz[now]-sz[zx])) zx=fa[zx];
return zx;
}
int get_val(int now,int f,int dep,long long vu,long long vd)
{
vu=(vu*k+a[now])%y;
if(dep) vd=(vd+a[now]*tm[dep-1])%y;
tmp1[++cnt].id=now;
tmp1[cnt].val=(x-vu+y)*inv[dep+1]%y;
tmp2[cnt].id=now;
tmp2[cnt].val=vd;
for(int i=0; i<g[now].size(); i++)
{
if(g[now][i]==f||vis[g[now][i]]) continue;
get_val(g[now][i],now,dep+1,vu,vd);
}
}
int calc(int now,int kd,int dep,long long vu,long long vd)
{
int posa,posb,tot;
cnt=0;
vu=(vu*k+a[now])%y;
if(dep) vd=(vd+a[now]*tm[dep-1])%y;
tmp1[++cnt].id=now;
tmp1[cnt].val=(x-vu+y)*inv[dep+1]%y;
tmp2[cnt].id=now;
tmp2[cnt].val=vd;
for(int i=0; i<g[now].size(); i++)
{
if(!vis[g[now][i]])
{
get_val(g[now][i],now,dep+1,vu,vd);
}
}
sort(tmp1+1,tmp1+cnt+1,cmp);
sort(tmp2+1,tmp2+cnt+1,cmp);
tot=0;
for(posa=posb=1; posa<=cnt; posa++)
{
while(posb<=cnt&&tmp2[posb].val<=tmp1[posa].val)
{
if(posb==1||tmp2[posb].val!=tmp2[posb-1].val) tot=0;
tot++;
posb++;
}
if(posb!=1&&tmp2[posb-1].val==tmp1[posa].val) out1[tmp1[posa].id]+=tot*kd;
}
tot=0;
for(posa=posb=1; posb<=cnt; posb++)
{
while(posa<=cnt&&tmp1[posa].val<=tmp2[posb].val)
{
if(posa==1||tmp1[posa].val!=tmp1[posa-1].val) tot=0;
tot++;
posa++;
}
if(posa!=1&&tmp1[posa-1].val==tmp2[posb].val) in1[tmp2[posb].id]+=tot*kd;
}
}
int solve(int now)
{
vis[now]=1;
calc(now,1,0,0,0);
for(int i=0; i<g[now].size(); i++)
{
if(vis[g[now][i]]) continue;
calc(g[now][i],-1,1,a[now],0);
}
for(int i=0 ;i<g[now].size(); i++)
{
if(vis[g[now][i]]) continue;
get_sz(g[now][i],0);
int zx=get_zx(g[now][i],0);
solve(zx);
}
}
int main()
{
scanf("%lld%lld%lld%lld",&n,&y,&k,&x);
for(int i=1; i<=n; i++)
{
scanf("%lld",&a[i]);
}
tm[0]=1;
inv[0]=1;
long long invk=kasumi(k,y-2);
for(int i=1; i<=n; i++)
{
tm[i]=tm[i-1]*k%y;
inv[i]=inv[i-1]*invk%y;
}
int from,to;
for(int i=1; i<n; i++)
{
scanf("%d%d",&from,&to);
g[from].push_back(to);
g[to].push_back(from);
}
get_sz(1,0);
int zx=get_zx(1,0);
solve(zx);
for(int i=1; i<=n; i++)
{
in0[i]=n-in1[i];
out0[i]=n-out1[i];
ans+=2*in0[i]*in1[i]+2*out0[i]*out1[i]+in1[i]*out0[i]+in0[i]*out1[i];
}
printf("%lld\n",1ll*n*n*n-ans/2);
}