[Problem Practice] BZOJ 5293 [BJOI2018] Summation

Description

The master is very interested in sums over trees. He generates a rooted tree and wants to ask, many times, for the sum of the k-th powers of the depths of all nodes on a path, where k may be different for each query. Here the depth of a node is the number of edges on the path from that node to the root. He passed the question on to pupil, but pupil cannot handle such a complicated operation. Can you help him solve it?

Input

The first line contains a positive integer n representing the number of nodes in the tree.

Each of the next n-1 lines contains two positive integers i and j separated by a space, describing an edge of the tree that connects node i and node j.

The next line contains a positive integer m, indicating the number of queries.

Each of the next m lines contains three positive integers i, j, k separated by spaces, asking for the sum of the k-th powers of the depths of all nodes on the path from node i to node j.

Since this result can be very large, output it modulo 998244353.

The nodes of the tree are numbered from 1, where node 1 is the root of the tree.

Output

For each query, output one line containing the answer modulo 998244353.

Constraints: 1 ≤ n, m ≤ 300000, 1 ≤ k ≤ 50.

Sample Input

5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45

Sample Output

33
503245989
Hint
Example Explanation
The following uses d(i) to represent the depth of the ith node.
For the tree in the example, there are d(1)=0, d(2)=1, d(3)=1, d(4)=2, d(5)=2.
So the answer to the first query is (2^5 + 1^5 + 0^5) mod 998244353 = 33
and the answer to the second query is (2^45 + 1^45 + 2^45) mod 998244353 = 503245989.

Solution

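This is an easy problem. Since the exponent k is at most 50, we can simply store, for every exponent, a weight on each node of the tree (its depth raised to that power). I first wrote a heavy-light decomposition (tree chain partition) solution, which passed on Luogu with O2 enabled but ran somewhat slowly, so I rewrote it using prefix sums on the tree (the "difference" technique). Let S_k(x) be the sum of d(y)^k over all nodes y on the path from the root to x. Then the answer for a query (u, v, k) is S_k(u) + S_k(v) - S_k(lca) - S_k(parent(lca)), where lca is the lowest common ancestor of u and v. The lca is subtracted once and its parent's prefix once, so every node of the path is counted exactly once.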

Here is the code for the prefix-sum (difference) solution:

#include<bits/stdc++.h>
#define ui unsigned int
#define ll long long
#define db double
#define ld long double
#define ull unsigned long long
const int MAXN=300000+10;       // node capacity: n <= 300000 plus slack
const int Mod=998244353;        // every answer is reported modulo this prime
int n,q;                        // number of nodes, number of queries
int e;                          // number of directed edges stored so far
int to[MAXN<<1],nex[MAXN<<1];   // edge target / next edge in the same list (both directions stored)
int beg[MAXN];                  // beg[x] = first edge id leaving x, 0 = none
int Jie[MAXN][20];              // Jie[x][j] = 2^j-th ancestor of x (0 past the root)
int dep[MAXN];                  // depth in edges; dep[0] is set to -1 by main so the root gets 0
long long Sum[MAXN][51];        // after init: Sum[x][k] = sum of dep^k along root..x, mod Mod
// Read one (optionally negative) decimal integer from stdin into x.
// Characters before the first digit or '-' are skipped. If the input ends
// before a number is found, x is set to 0 instead of looping forever.
template<typename T> inline void read(T &x)
{
    T data=0,w=1;
    // Keep getchar()'s result in an int so EOF stays distinguishable from
    // every real byte, and stop skipping at EOF: the previous skip loop had no
    // EOF check and never terminated on truncated input.
    int ch=getchar();
    while(ch!=EOF&&ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
    if(ch=='-')w=-1,ch=getchar();
    while(ch>='0'&&ch<='9')data=data*10+(ch-'0'),ch=getchar();
    x=data*w;
}
// Print x in decimal to stdout, followed by the character ch unless ch is '\0'.
template<typename T> inline void write(T x,char ch='\0')
{
    if(x<0)
    {
        putchar('-');
        x=-x;
    }
    // Collect digits least-significant first, then emit them in reverse order.
    char digits[40];
    int len=0;
    do
    {
        digits[len++]=(char)('0'+x%10);
        x/=10;
    }while(x!=0);
    while(len>0)putchar(digits[--len]);
    if(ch!='\0')putchar(ch);
}
// Lower x to y when y is strictly smaller.
template<typename T> inline void chkmin(T &x,T y){if(y<x)x=y;}
// Raise x to y when y is strictly larger.
template<typename T> inline void chkmax(T &x,T y){if(y>x)x=y;}
// Smaller of x and y (y when neither is less).
template<typename T> inline T min(T x,T y){if(x<y)return x;return y;}
// Larger of x and y (y when neither is greater).
template<typename T> inline T max(T x,T y){if(x>y)return x;return y;}
// Add the directed edge x -> y at the head of x's singly linked edge list.
inline void insert(int x,int y)
{
    const int id=++e;
    nex[id]=beg[x];
    to[id]=y;
    beg[x]=id;
}
inline ll qexp(ll a,ll b)
{
    ll res=1;
    while(b)
    {
        if(b&1)res=res*a%Mod;
        a=a*a%Mod;
        b>>=1;
    }
    return res;
}
inline void dfs1(int x,int f)
{
    Jie[x][0]=f;dep[x]=dep[f]+1;
    for(register int i=beg[x];i;i=nex[i])
        if(to[i]==f)continue;
        else dfs1(to[i],x);
    for(register int i=1;i<=50;++i)Sum[x][i]=qexp(dep[x],i);
}
// Second pass, top-down: turn Sum[x][k] (dep[x]^k on entry) into the sum of
// dep^k over the whole root-to-x path by adding the parent's finished prefix.
// Parents are finished before children because the recursion goes downward.
inline void dfs2(int x)
{
    const int parent=Jie[x][0];
    for(int k=1;k<=50;++k)
        Sum[x][k]=(Sum[x][k]+Sum[parent][k])%Mod;
    for(int i=beg[x];i;i=nex[i])
    {
        const int child=to[i];
        if(child!=parent)dfs2(child);
    }
}
// Run both DFS passes from root 1 (virtual parent 0), then build the
// binary-lifting table: Jie[i][j] is the 2^j-th ancestor of i.
inline void init()
{
    dfs1(1,0);
    dfs2(1);
    // Fill every level the table declares. LCA() begins its jumps at level 19,
    // but the old bound (j<19) left that level unset. It only worked because
    // the unset value 0 has dep[0] == -1.
    const int LOG=(int)(sizeof(Jie[0])/sizeof(Jie[0][0]));
    for(int j=1;j<LOG;++j)
        for(int i=1;i<=n;++i)Jie[i][j]=Jie[Jie[i][j-1]][j-1];
}
inline int LCA(int u,int v)
{
    if(dep[u]<dep[v])std::swap(u,v);
    if(dep[u]>dep[v])
        for(register int i=19;i>=0;--i)
            if(dep[Jie[u][i]]>=dep[v])u=Jie[u][i];
    if(u==v)return u;
    for(register int i=19;i>=0;--i)
        if(Jie[u][i]!=Jie[v][i])u=Jie[u][i],v=Jie[v][i];
    return Jie[u][0];
}
// Read the tree, precompute root-prefix power sums, then answer each query as
// Sum[u] + Sum[v] - Sum[lca] - Sum[parent(lca)] (mod Mod).
int main()
{
    read(n);
    for(int edge=0;edge+1<n;++edge)
    {
        int a,b;
        read(a);
        read(b);
        insert(a,b);
        insert(b,a);
    }
    dep[0]=-1; // virtual parent of the root, so the root ends up at depth 0
    init();
    read(q);
    for(;q--;)
    {
        int u,v,k;
        read(u);
        read(v);
        read(k);
        const int w=LCA(u,v);
        const long long both=Sum[u][k]+Sum[v][k];
        const long long overlap=Sum[w][k]+Sum[Jie[w][0]][k];
        // Adding 2*Mod keeps the difference non-negative before reducing.
        write((both-overlap+2LL*Mod)%Mod,'\n');
    }
    return 0;
}

Here is the heavy-light decomposition code (it passes with O2 enabled; I don't know whether it passes without O2 or on other online judges):

#include<bits/stdc++.h>
#define ui unsigned int
#define ll long long
#define db double
#define ld long double
#define ull unsigned long long
const int MAXN=300000+10;                      // node capacity: n <= 300000 plus slack
const int Mod=998244353;                       // every answer is reported modulo this prime
int n,q;                                       // number of nodes, number of queries
int e,cnt;                                     // edges stored so far, last DFS position assigned
int to[MAXN<<1],nex[MAXN<<1],beg[MAXN];        // adjacency lists (both directions stored)
int fa[MAXN],dep[MAXN],size[MAXN],hson[MAXN];  // parent, depth, subtree size, heavy child (dfs1)
int top[MAXN],st[MAXN],ed[MAXN];               // chain head, first/last DFS position (dfs2)
int val[MAXN];                                 // val[pos] = depth of the node placed at DFS position pos
#define Mid ((l+r)>>1)
#define ls rt<<1
#define rs rt<<1|1
#define lson ls,l,Mid
#define rson rs,Mid+1,r
struct Segment_Tree{
    ll Sum[MAXN<<2][51];
    inline ll qexp(ll a,ll b)
    {
        ll res=1;
        while(b)
        {
            if(b&1)res=res*a%Mod;
            a=a*a%Mod;
            b>>=1;
        }
        return res;
    }
    inline void PushUp(int rt)
    {
        for(register int i=1;i<=50;++i)Sum[rt][i]=(Sum[ls][i]+Sum[rs][i])%Mod;
    }
    inline void Build(int rt,int l,int r)
    {
        if(l==r)
            for(register int i=1;i<=50;++i)Sum[rt][i]=qexp(val[l],i);
        else
        {
            Build(lson);Build(rson);
            PushUp(rt);
        }
    }
    inline ll Query(int rt,int l,int r,int L,int R,int k)
    {
        if(L<=l&&r<=R)return Sum[rt][k];
        else
        {
            ll res=0;
            if(L<=Mid)(res+=Query(lson,L,R,k))%=Mod;
            if(R>Mid)(res+=Query(rson,L,R,k))%=Mod;
            return res;
        }
    }
};
Segment_Tree T;
#undef Mid
#undef ls
#undef rs
#undef lson
#undef rson
// Read one (optionally negative) decimal integer from stdin into x.
// Characters before the first digit or '-' are skipped. If the input ends
// before a number is found, x is set to 0 instead of looping forever.
template<typename T> inline void read(T &x)
{
    T data=0,w=1;
    // Keep getchar()'s result in an int so EOF stays distinguishable from
    // every real byte, and stop skipping at EOF: the previous skip loop had no
    // EOF check and never terminated on truncated input.
    int ch=getchar();
    while(ch!=EOF&&ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
    if(ch=='-')w=-1,ch=getchar();
    while(ch>='0'&&ch<='9')data=data*10+(ch-'0'),ch=getchar();
    x=data*w;
}
// Print x in decimal to stdout, followed by the character ch unless ch is '\0'.
template<typename T> inline void write(T x,char ch='\0')
{
    if(x<0)
    {
        putchar('-');
        x=-x;
    }
    // Collect digits least-significant first, then emit them in reverse order.
    char digits[40];
    int len=0;
    do
    {
        digits[len++]=(char)('0'+x%10);
        x/=10;
    }while(x!=0);
    while(len>0)putchar(digits[--len]);
    if(ch!='\0')putchar(ch);
}
// Lower x to y when y is strictly smaller.
template<typename T> inline void chkmin(T &x,T y){if(y<x)x=y;}
// Raise x to y when y is strictly larger.
template<typename T> inline void chkmax(T &x,T y){if(y>x)x=y;}
// Smaller of x and y (y when neither is less).
template<typename T> inline T min(T x,T y){if(x<y)return x;return y;}
// Larger of x and y (y when neither is greater).
template<typename T> inline T max(T x,T y){if(x>y)return x;return y;}
// Add the directed edge x -> y at the head of x's singly linked edge list.
inline void insert(int x,int y)
{
    const int id=++e;
    nex[id]=beg[x];
    to[id]=y;
    beg[x]=id;
}
// First heavy-light pass from node x with parent f: record parent, depth and
// subtree size, and pick the child with the largest subtree as the heavy child.
// NOTE(review): recursion depth equals the tree height (up to n = 300000 on a
// path-shaped tree); this may need a large stack on some judges.
inline void dfs1(int x,int f)
{
    fa[x]=f;
    dep[x]=dep[f]+1;
    size[x]=1;
    int heaviest=0;
    for(int i=beg[x];i;i=nex[i])
    {
        const int y=to[i];
        if(y==f)continue;
        dfs1(y,x);
        size[x]+=size[y];
        if(size[y]>heaviest)
        {
            heaviest=size[y];
            hson[x]=y;
        }
    }
}
// Second heavy-light pass: give x the next DFS position and record its chain
// head tp. The heavy child is visited first so each chain occupies consecutive
// positions. Every light child starts a new chain headed by itself.
inline void dfs2(int x,int tp)
{
    st[x]=++cnt;
    top[x]=tp;
    val[cnt]=dep[x];
    if(hson[x]!=0)dfs2(hson[x],tp);
    for(int i=beg[x];i;i=nex[i])
    {
        const int y=to[i];
        if(y!=fa[x]&&y!=hson[x])dfs2(y,y);
    }
    ed[x]=cnt;
}
// Prepare all query structures for the tree rooted at node 1.
inline void init()
{
    dfs1(1,0);      // parents, depths, subtree sizes, heavy children (virtual parent 0)
    dfs2(1,1);      // heavy-first DFS order; the root heads its own chain
    T.Build(1,1,n); // segment tree over the depths laid out in DFS order
}
inline ll Getans(int u,int v,int k)
{
    ll res=0;
    while(top[u]!=top[v])
    {
        if(dep[top[u]]<dep[top[v]])std::swap(u,v);
        (res+=T.Query(1,1,n,st[top[u]],st[u],k))+=Mod;
        u=fa[top[u]];
    }
    (res+=T.Query(1,1,n,min(st[v],st[u]),max(st[v],st[u]),k))%=Mod;
    return res;
}
// Read the tree, build the heavy-light decomposition and segment tree, then
// answer each path query.
int main()
{
    read(n);
    for(int remaining=n-1;remaining>0;--remaining)
    {
        int a,b;
        read(a);
        read(b);
        insert(a,b);
        insert(b,a);
    }
    dep[0]=-1; // virtual parent of the root, so the root ends up at depth 0
    init();
    read(q);
    for(;q--;)
    {
        int u,v,k;
        read(u);
        read(v);
        read(k);
        const long long ans=Getans(u,v,k);
        write(ans,'\n');
    }
    return 0;
}

Guess you like

Origin http://43.154.161.224:23101/article/api/json?id=325044745&siteId=291194637