tree - dp - 长链剖分

题目大意:
给你一颗树,点有点权,对所有三元组(x,y,z),满足dis(x,y)=dis(y,z)=dis(x,z),统计a(x)a(y)+a(x)a(z)+a(y)a(z)的和。n<=100000。
题解:
条件等价于存在一个中心点。
枚举三个点的LCA,然后劈成两半,一半是链,一半是Y倒过来写,发现二者能合并当且仅当链长等于Y倒过来写的下面长度减去上面长度,而这二者不超过子树深度,因此长链剖分即可。

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define lint long long
#define mod 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
const int N=100010;
struct edges{
    int to,pre;
}e[N<<1];int h[N],etop,a[N],l[N],son[N];lint ans=0;
inline int add_edge(int u,int v) { return e[++etop].to=v,e[etop].pre=h[u],h[u]=etop; }
inline int inn()
{
    int x,ch;while((ch=gc)<'0'||ch>'9');
    x=ch^'0';while((ch=gc)>='0'&&ch<='9')
        x=(x<<1)+(x<<3)+(ch^'0');return x;
}
int getl(int x,int fa=0)
{
    l[x]=1,son[x]=0;
    for(int i=h[x],y;i;i=e[i].pre)
        if((y=e[i].to)^fa)
        {
            l[x]=max(l[x],getl(y,x)+1);
            if(l[y]>l[son[x]]) son[x]=y;
        }
    return l[x];
}
inline int *arr(int n) { int *p=new int[n];return memset(p,0,sizeof(int)*n),p; }
#define P(x) (x>=mod?x-=mod:0)
int dfs(int x,int fa,int *Ax,int *Bx,int *Cx,int *Dx)
{
    if(son[x]) dfs(son[x],x,Ax-1,Bx-1,Cx+1,Dx+1),ans+=(Ax[0]+(lint)a[x]*Bx[0])%mod,P(ans);
    Cx[0]=a[x],Dx[0]=1;
    for(int i=h[x],y;i;i=e[i].pre)
        if((y=e[i].to)!=fa&&e[i].to!=son[x])
        {
            int *Ay=arr(l[y]*2+1)+l[y],*By=arr(l[y]*2+1)+l[y],
                *Cy=arr(l[y]+1),*Dy=arr(l[y]+1);
            dfs(y,x,Ay,By,Cy,Dy);
            rep(d,0,l[y])
            {
                if(d) ans+=((lint)Dx[d-1]*Ay[d]+(lint)Cx[d-1]*By[d])%mod,P(ans);
                ans+=((lint)Ax[d+1]*Dy[d]+(lint)Bx[d+1]*Cy[d])%mod,P(ans);
            }
            rep(d,0,l[y]-1) Ax[d]+=Ay[d+1],Bx[d]+=By[d+1],P(Ax[d]),P(Bx[d]);
            rep(d,1,l[y]) Ax[d]+=(lint)Cx[d]*Cy[d-1]%mod,P(Ax[d]),
                          Bx[d]+=((lint)Cx[d]*Dy[d-1]+(lint)Dx[d]*Cy[d-1])%mod,P(Bx[d]);
            rep(d,0,l[y]) Cx[d+1]+=Cy[d],Dx[d+1]+=Dy[d],P(Cx[d+1]),P(Dx[d+1]);
        }
    return 0;
}
int main()
{
    int n=inn(),u,v;n=inn();
    rep(i,1,n-1) u=inn(),v=inn(),add_edge(u,v),add_edge(v,u);
    rep(i,1,n) a[i]=inn(),(a[i]>=mod?a[i]%=mod:0);
    getl(1);
    int *A=arr(l[1]*2+1)+l[1],*B=arr(l[1]*2+1)+l[1],
        *C=arr(l[1]+1),*D=arr(l[1]+1);
    return dfs(1,0,A,B,C,D),!printf("%d\n",int(ans));
}

猜你喜欢

转载自blog.csdn.net/Mys_C_K/article/details/84305389