树上差分 JLOI 2014 松鼠的新家/luogu 3128 最大流/BJOI 2018 求和

版权声明:未经本蒟蒻同意,请勿转载本蒻博客 https://blog.csdn.net/wddwjlss/article/details/82080915

差分:
一个序列a,每次修改给出一段区间 [ l , r ] ,使得区间内的所有数 + x ,最后求经过 m 次修改后的数列。

我们新开一个差分数组 w ,对于每一次修改我们让 w [ l ] + = x , w [ r + 1 ] = x ,最后对 w 数组求前缀和得到 s u m 数组,让原先的序列中的每一个 a i 分别加上它对应的 s u m i , 最后得到的就是经过 m 次修改后的数列。

树上差分:
m 次操作,每次从点 u 沿树上路径到点 v ,经过的路径上的每个点的权值 + x (包括两个端点)。

我们新开一个差分数组 w ,对于每次操作,我们让 w [ u ] + = x ,     w [ v ] + = x ,     w [ l c a ( u , v ) ] = x ,     w [ f [ l c a ( u , v ) ] ] = x 。然后 d f s 求以每个点为根的子树的 w 值的和,即 w [ u ] + = w [ v ] 。最后求得的 w 数组就是经过 m 次操作后每个点的权值。

前两道题是裸题,第三道题给出树后, m 次询问,每次给出一组 i , j , k ,询问的是从点 i 到点 j 的路径上所有节点深度的 k 次方和。

我们通过 d f s v a l 数组, v a l [ i ] [ k ] 表示 i 1 点路径上所有点深度的 k 次方之和。对于每一次询问,答案就是 v a l [ i ] [ k ] + v a l [ j ] [ k ] v a l [ l c a ( i , j ) ] [ k ] v a l [ f [ l c a ( i , j ) ] ] [ k ]

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int mod=998244353;
struct node
{
    int next,to;
}e[1001000];
int n,m,num,f[505000][23],head[1001000],dep[1001000];
ll val[1001000][51],mi[1001000];
bool book[1001000];
void add(int from,int to)
{
    e[++num].next=head[from];
    e[num].to=to;
    head[from]=num;
}
void dfs(int x,int fa)
{
    dep[x]=dep[fa]+1;
    f[x][0]=fa;
    for(int i=1;(1<<i)<=dep[x];++i)
        f[x][i]=f[f[x][i-1]][i-1];
    for(int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa)
            continue;
        dfs(v,x);
    }
}
void get_val(int x)
{
    book[x]=1;
    for(int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(book[v])
            continue;
        for(int j=1;j<=50;++j)
            mi[j]=(mi[j-1]%mod*dep[v]%mod)%mod;
        for(int j=1;j<=50;++j)
            val[v][j]=(val[x][j]%mod+mi[j]%mod)%mod;
        get_val(v);
    }
}
int lca(int x,int y)
{
    if(dep[x]<dep[y])
        swap(x,y);
    for(int i=20;i>=0;--i)
        if(dep[x]-dep[y]>=(1<<i))
            x=f[x][i];
    if(x==y)
        return x;
    for(int i=20;i>=0;--i)
        if(dep[x]>=(1<<i)&&f[x][i]!=f[y][i])
        {
            x=f[x][i];
            y=f[y][i];
        }
    return f[x][0];
}
int main()
{
    cin>>n;
    mi[0]=1;
    for(int i=1;i<=n-1;++i)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x); 
    }
    dep[0]=-1;
    dfs(1,0);
    get_val(1);
    cin>>m;
    for(int i=1;i<=m;++i)
    {
        int u,v,d;
        scanf("%d%d%d",&u,&v,&d);
        printf("%lld\n",(val[u][d]%mod+val[v][d]%mod+2*mod-val[lca(u,v)][d]%mod-val[f[lca(u,v)][0]][d]%mod)%mod);
    }
    return 0;
} 

猜你喜欢

转载自blog.csdn.net/wddwjlss/article/details/82080915