CodeForces 434E. Furukawa Nagisa's Tree(点分治+容斥+数论)

题目链接

图片加载不出来的看这里
先简述一下题意:
给出一棵树,树上各点都有点权。
定义 S ( s , t ) 为树上 s t 的简单路径的序列。
s 位置的点点权 z 0 ,他到 t 的路径上的下一个点点权为 z 1 ,以此类推……再设一下这条路径的长度为 l
然后定义 G ( S ( s , t ) ) = z 0 k 0 + z 1 k 1 + z 2 k 2 + . . . + z l 1 k l 1
如果一条路径满足 G ( S ( s , t ) ) x ( m o d y ) 那么这条边就是好边,反之如果不满足,这就是条坏边。
定义好三元组 ( x , y , z ) ( x , y ) , ( y , z ) , ( x , z ) 都是好边或坏边的三元组。
求这棵树上有多少对好三元组?(三元组中的点可以重复)
保证y是质数

这题非常的绕……
所以,首先我们先不考虑好边坏边怎么在时限内求,先把怎么判断三元组弄清楚
思考一下,很显然针对三元组来说,只会有这八种情况(0是坏边,1是好边)

首先先考虑直接判断好三元组,你会发现,我们必须要知道三条边的信息才能判断,反正我比较菜这个东西我是想不到 O ( n l o g n ) 的做法,所以要在思考一下别的方法,我们再看坏三元组的特征
坏三元组和好三元组的区别显然是有点连着一条好边和一条坏边,如下图红的点

那么统计一下,假设好的入边为 i n 1 ,坏的为 i n 0 ,同样,假设好的出边为 o u t 1 ,坏的为 o u t 0 ,在一番数数之后,我们得出图中各类点连得边的数量为

in0 out1 2
in1 out0 2
in1 in0 4
out1 out0 4

因为有两个点,所以每个点的贡献为1/2
即一个点会产生 i n 0 i n 1 2 + o u t 0 o u t 1 2 + i n 0 o u t 1 + i n 1 o u t 0 个坏三元组
因为一个坏三元组里有两个这样的点,所以答案要除二
这样我们就把坏三元组的个数弄到了一个点上,这里的复杂度降到了 O ( n )
显然根据容斥的思想减掉坏三元组就是好三元组的数量

好的,接着开始接过之前的话题——怎么在时限内求出一个点的 i n 1 i n 0 o u t 1 o u t 0
因为三元组中的点可以重复,所以每个点都有 n 条出边, n 条入边。
所以 i n 0 = n i n 1 , o u t 0 = n o u t 1
至于 i n 1 o u t 1
这个是可以用点分治搞得
好的先进入数论环节
首先在点分的时候推两个距离, v u = ( p r e v u k + a [ n o w ] ) 表示从下到上的距离1, v d = ( p r e v d + a [ n o w ] k d e e p ) 表示从上到下的距离2

那么显然当一条边是好边的时候存在:
v u + v d k l e n [ v u ] x ( m o d y )
感觉 l e n [ v u ] 这个东西比较尴尬,我们移下项
v d ( x v u ) / k l e n [ v u ] ( m o d y )
其中右项就可以边dfs边处理了,相当于减一下再乘个逆元,逆元表可以和幂次表一起打,反正y一定是质数的话一趟费马小定理就可以了。
那么只要存在 v d ( x v u ) / k l e n [ v u ] ( m o d y ) 就可以组成一条好边
我们可以把所有得到的vd和右边这个值全部sort一下,然后尺取求一下相等的对数,反正就基本能点分出来 i n 1 o u t 1 了……

之后就按上面的方法计算答案了,总三元组个数为 n 3 ,答案就是 n 3 a n s / 2

代码如下:

#include<map>
#include<set>
#include<queue>
#include<cmath>
#include<cstdio>
#include<string>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;

long long cnt,vis[100010],fa[100010],sz[100010],in0[100010],in1[100010],out0[100010],out1[100010],tm[100010],inv[100010],a[100010],x,y,k,n,ans;
vector<int> g[100010];

struct point
{
    long long val,id;
} tmp1[100010],tmp2[100010];

int cmp(point a,point b)
{
    return a.val<b.val;
}

long long kasumi(long long a,long long b)
{
    long long ans=1;
    while(b)
    {
        if(b&1)
        {
            ans=ans*a%y;
        }
        a=a*a%y;
        b>>=1;
    }
    return ans;
}

int get_sz(int now,int f)
{
    sz[now]=1;
    fa[now]=f;
    for(int i=0; i<g[now].size(); i++)
    {
        if(g[now][i]==f||vis[g[now][i]]) continue;
        get_sz(g[now][i],now);
        sz[now]+=sz[g[now][i]];
    }
}

int get_zx(int now,int f)
{
    if(sz[now]==1) return now;
    int maxson=-1,son;
    for(int i=0; i<g[now].size(); i++)
    {
        if(g[now][i]==f||vis[g[now][i]]) continue;
        if(maxson<sz[g[now][i]])
        {
            maxson=sz[g[now][i]];
            son=g[now][i];
        }
    }
    int zx=get_zx(son,now);
    while(sz[zx]<2*(sz[now]-sz[zx])) zx=fa[zx];
    return zx;
}

int get_val(int now,int f,int dep,long long vu,long long vd)
{
    vu=(vu*k+a[now])%y;
    if(dep) vd=(vd+a[now]*tm[dep-1])%y;
    tmp1[++cnt].id=now;
    tmp1[cnt].val=(x-vu+y)*inv[dep+1]%y;
    tmp2[cnt].id=now;
    tmp2[cnt].val=vd;
    for(int i=0; i<g[now].size(); i++)
    {
        if(g[now][i]==f||vis[g[now][i]]) continue;
        get_val(g[now][i],now,dep+1,vu,vd);
    }
}

int calc(int now,int kd,int dep,long long vu,long long vd)
{
    int posa,posb,tot;
    cnt=0;
    vu=(vu*k+a[now])%y;
    if(dep) vd=(vd+a[now]*tm[dep-1])%y;
    tmp1[++cnt].id=now;
    tmp1[cnt].val=(x-vu+y)*inv[dep+1]%y;
    tmp2[cnt].id=now;
    tmp2[cnt].val=vd;
    for(int i=0; i<g[now].size(); i++)
    {
        if(!vis[g[now][i]])
        {
            get_val(g[now][i],now,dep+1,vu,vd);
        }
    }
    sort(tmp1+1,tmp1+cnt+1,cmp);
    sort(tmp2+1,tmp2+cnt+1,cmp);
    tot=0;
    for(posa=posb=1; posa<=cnt; posa++)
    {
        while(posb<=cnt&&tmp2[posb].val<=tmp1[posa].val)
        {
            if(posb==1||tmp2[posb].val!=tmp2[posb-1].val) tot=0;
            tot++;
            posb++;
        }
        if(posb!=1&&tmp2[posb-1].val==tmp1[posa].val) out1[tmp1[posa].id]+=tot*kd;
    }
    tot=0;
    for(posa=posb=1; posb<=cnt; posb++)
    {
        while(posa<=cnt&&tmp1[posa].val<=tmp2[posb].val)
        {
            if(posa==1||tmp1[posa].val!=tmp1[posa-1].val) tot=0;
            tot++;
            posa++;
        }
        if(posa!=1&&tmp1[posa-1].val==tmp2[posb].val) in1[tmp2[posb].id]+=tot*kd;
    }
}

int solve(int now)
{
    vis[now]=1;
    calc(now,1,0,0,0);
    for(int i=0; i<g[now].size(); i++)
    {
        if(vis[g[now][i]]) continue;
        calc(g[now][i],-1,1,a[now],0);
    }
    for(int i=0 ;i<g[now].size(); i++)
    {
        if(vis[g[now][i]]) continue;
        get_sz(g[now][i],0);
        int zx=get_zx(g[now][i],0);
        solve(zx);
    }
}

int main()
{
    scanf("%lld%lld%lld%lld",&n,&y,&k,&x);
    for(int i=1; i<=n; i++)
    {
        scanf("%lld",&a[i]);
    }
    tm[0]=1;
    inv[0]=1;
    long long invk=kasumi(k,y-2);
    for(int i=1; i<=n; i++)
    {
        tm[i]=tm[i-1]*k%y;
        inv[i]=inv[i-1]*invk%y;
    }
    int from,to;
    for(int i=1; i<n; i++)
    {
        scanf("%d%d",&from,&to);
        g[from].push_back(to);
        g[to].push_back(from);
    }
    get_sz(1,0);
    int zx=get_zx(1,0);
    solve(zx);
    for(int i=1; i<=n; i++)
    {
        in0[i]=n-in1[i];
        out0[i]=n-out1[i];
        ans+=2*in0[i]*in1[i]+2*out0[i]*out1[i]+in1[i]*out0[i]+in0[i]*out1[i];
    }
    printf("%lld\n",1ll*n*n*n-ans/2);
}

猜你喜欢

转载自blog.csdn.net/qq_21829533/article/details/82108504
今日推荐