Codechef CHEFFIB 点分树套树状数组

题意

有一棵n个节点的树,初始每个节点的权值均为0。要求资瓷q个操作:
1 u m a b表示对于任意一个节点v,若 d i s ( u , v ) m ,则节点v的权值加上以a和b为开头的斐波那契数列的第 d i s ( u , v ) 项。
2 u询问节点u的权值
n , q 300000

分析

考虑先把点分树建出来,然后在点分树上进行修改和询问,每次只考虑经过分治中心的链带来的贡献。
在修改的时候,我们可以把斐波那契数列往后推到某一项,使得分治中心的位置恰好是第一项,那么现在就只剩下一个距离的限制,用树状数组来维护系数的前缀和即可。
但在对某一个分治中心求贡献时,会把该点所在子树的贡献也算进去,所以我们还要对分治中心的每棵子树开一棵树状数组,然后容斥一下就好了。
时间复杂度 O ( n l o g 2 n )

代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define mp(x,y) std::make_pair(x,y)

typedef long long LL;
typedef std::pair<int,int> pi;

const int N=300005;
const int MOD=1000000007;

int n,m,last[N],f[N],cnt,s1,s2,sz,size[N],sum,root,pre[N],a[N],tot,bit1[N*40],bit2[N*40],dep[30][N],bel[30][N],dis[N];
bool vis[N];
struct edge{int to,next;}e[N*2];
struct Matrix{int a[3][3];}po[N];
pi id[30][N];

int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

void addedge(int u,int v)
{
    e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
    e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}

void mul(Matrix &c,Matrix a,Matrix b)
{
    memset(c.a,0,sizeof(c.a));
    for (int i=1;i<=2;i++)
        for (int k=1;k<=2;k++)
            for (int j=1;j<=2;j++)
                (c.a[i][j]+=(LL)a.a[i][k]*b.a[k][j]%MOD)%=MOD;
}

void get_root(int x,int fa)
{
    size[x]=1;f[x]=0;
    for (int i=last[x];i;i=e[i].next)
    {
        if (e[i].to==fa||vis[e[i].to]) continue;
        get_root(e[i].to,x);
        size[x]+=size[e[i].to];
        f[x]=std::max(f[x],size[e[i].to]);
    }
    f[x]=std::max(f[x],sum-size[x]);
    if (!root||f[x]<f[root]) root=x;
}

void get(int x,int fa,int rt,int d)
{
    dep[dis[rt]][x]=d;size[x]=1;a[++tot]=x;
    for (int i=last[x];i;i=e[i].next)
        if (e[i].to!=fa&&!vis[e[i].to]) get(e[i].to,x,rt,d+1),size[x]+=size[e[i].to];
}

void solve(int x,int d)
{
    vis[x]=1;size[x]=1;dis[x]=d;
    for (int i=last[x];i;i=e[i].next)
    {
        if (vis[e[i].to]) continue;
        tot=0;get(e[i].to,x,x,1);
        size[x]+=size[e[i].to];
        for (int j=1;j<=tot;j++) bel[d][a[j]]=e[i].to;
        id[d][e[i].to]=mp(sz+1,sz+size[e[i].to]+1);
        sz+=size[e[i].to]+1;
    }
    id[d][x]=mp(sz+1,sz+size[x]);
    sz+=size[x];
    for (int i=last[x];i;i=e[i].next)
    {
        if (vis[e[i].to]) continue;
        root=0;sum=size[e[i].to];get_root(e[i].to,x);
        pre[root]=x;
        solve(root,d+1);
    }
}

void ins(int l,int r,int x,int a,int b)
{
    int L=r-l+1;x=r-(std::min(x,L-1)+l)+1;
    while (x<=L)
    {
        (bit1[r-x+1]+=a)%=MOD;
        (bit2[r-x+1]+=b)%=MOD;
        x+=x&(-x);
    }
}

void find(int l,int r,int x)
{
    int L=r-l+1;
    if (x>=L) return;
    x=r-(x+l)+1;
    while (x)
    {
        (s1+=bit1[r-x+1])%=MOD;
        (s2+=bit2[r-x+1])%=MOD;
        x-=x&(-x);
    }
}

void modify(int x,int d,int a,int b)
{
    ins(id[dis[x]][x].first,id[dis[x]][x].second,d,a,b);
    int y=pre[x];
    while (y)
    {
        int L=dep[dis[y]][x];
        if (L<=0) {y=pre[y];continue;}
        ins(id[dis[y]][y].first,id[dis[y]][y].second,d-L,((LL)a*po[L-1].a[1][2]+(LL)b*po[L-1].a[2][2])%MOD,((LL)a*po[L].a[1][2]+(LL)b*po[L].a[2][2])%MOD);
        ins(id[dis[y]][bel[dis[y]][x]].first,id[dis[y]][bel[dis[y]][x]].second,d-L,((LL)a*po[L-1].a[1][2]+(LL)b*po[L-1].a[2][2])%MOD,((LL)a*po[L].a[1][2]+(LL)b*po[L].a[2][2])%MOD);
        y=pre[y];
    }
}

int query(int x)
{
    s1=s2=0;
    find(id[dis[x]][x].first,id[dis[x]][x].second,0);
    int ans=s1,y=pre[x];
    while (y)
    {
        int L=dep[dis[y]][x];
        s1=s2=0;
        find(id[dis[y]][y].first,id[dis[y]][y].second,L);
        (ans+=((LL)s1*po[L-1].a[1][2]+(LL)s2*po[L-1].a[2][2])%MOD)%=MOD;
        s1=s2=0;
        find(id[dis[y]][bel[dis[y]][x]].first,id[dis[y]][bel[dis[y]][x]].second,L);
        (ans+=MOD-((LL)s1*po[L-1].a[1][2]+(LL)s2*po[L-1].a[2][2])%MOD)%=MOD;
        y=pre[y];
    }
    return ans;
}

void clear()
{
    for (int i=1;i<=sz;i++) bit1[i]=bit2[i]=0;
    cnt=sz=0;
    for (int i=1;i<=n;i++) last[i]=vis[i]=0;
}

int main()
{
    int T=read();
    while (T--)
    {
        n=read();m=read();
        clear();
        for (int i=1;i<n;i++)
        {
            int x=read(),y=read();
            addedge(x,y);
        }
        po[1].a[1][2]=po[1].a[2][1]=po[1].a[2][2]=po[0].a[1][1]=po[0].a[2][2]=1;
        for (int i=2;i<=n;i++) mul(po[i],po[i-1],po[1]);
        root=0;sum=n;get_root(1,0);
        solve(1,0);
        while (m--)
        {
            int op=read();
            if (op==1)
            {
                int u=read(),d=read(),a=read(),b=read();
                modify(u,d,a,b);
            }
            else
            {
                int x=read();
                printf("%d\n",query(x));
            }
        }
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_33229466/article/details/80512943