洛谷p3384 树链剖分板子题

下午本来准备学完树链剖分 过完入门题 再去学下莫队的
结果搞了一下午
第一遍代码找不到bug 只好重新写 写完头都是晕的
贴一个别人的教学视频 讲的非常详细
之前也是看这个up的视频 才把主席树基本搞清楚的
https://www.bilibili.com/video/BV1Qt411u77f?from=search&seid=2991831755792450979

因为dfs序和树链剖分啥的视频里面都讲得非常的详细
我就不在这里赘述了直接上代码

#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <math.h>
#include <string.h>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <set>
#include <utility>
#define pi 3.1415926535898
#define ll long long
#define lson rt<<1
#define rson rt<<1|1
#define eps 1e-6
#define ms(a,b) memset(a,b,sizeof(a))
#define legal(a,b) a&b
#define print1 printf("111\n")
using namespace std;
const int maxn = 2e5+10;
const int inf = 0x1f1f1f1f;
const ll llinf =1e17+10;
const int mod = 2333;

int n,m,r,p;
//链式前向星建图
int len,first[maxn];
struct node
{
    int to,next;
}e[maxn];
//记录输入的数据 和 dfs序上对应的值
int a[maxn],wt[maxn];
//线段树建树所用的数组
int t[maxn<<2],lz[maxn<<2],lens[maxn<<2];
//下面的数组用于树链剖分
//数组从左往右分别记录重儿子结点,dfs序,父亲节点,节点深度,子树节点个数,链的头结点
int son[maxn],id[maxn],fa[maxn],dep[maxn],siz[maxn],top[maxn];
int res=0,cnt;

inline void add(int x,int y)
{
    e[len].to=y;
    e[len].next=first[x];
    first[x]=len++;
}

inline void pushdown(int rt)
{
    if(lz[rt])
    {
        lz[lson]+=lz[rt];
        lz[rson]+=lz[rt];
        t[lson]+=lz[rt]*lens[lson];
        t[rson]+=lz[rt]*lens[rson];
        t[lson]%=p;
        t[rson]%=p;
        lz[rt]=0;
    }

}

inline void pushup(int rt)
{
    t[rt]=(t[lson]+t[rson])%p;
}

inline void build(int rt,int l,int r)
{
    lens[rt]=r-l+1;
    if(l==r)
    {
        t[rt]=wt[l];
        if(t[rt]>p)t[rt]%=p;
        return;
    }
    int mid=(l+r)>>1;
    build(lson,l,mid);
    build(rson,mid+1,r);
    pushup(rt);
}

inline void query(int rt,int l,int r,int L,int R)
{
    if(L<=l&&r<=R)
    {
        res+=t[rt];
        res%=p;
        return;
    }
    pushdown(rt);
    int mid=(l+r)>>1;
    if(L<=mid)
        query(lson,l,mid,L,R);
    if(R>mid)
        query(rson,mid+1,r,L,R);
}

inline void updata(int rt,int l,int r,int L,int R,int k)
{
    if(L<=l&&r<=R)
    {
        lz[rt]+=k;
        t[rt]+=k*lens[rt];
        return;
    }
    pushdown(rt);
    int mid=(l+r)>>1;
    if(L<=mid)
        updata(lson,l,mid,L,R,k);
    if(R>mid)
        updata(rson,mid+1,r,L,R,k);
    pushup(rt);
}

inline int qrange(int x,int y)
{
    int ans=0;
    while(top[x]!=top[y])//如果两个节点不在同一个链上 就进行以下操作
    {
        if(dep[top[x]]<dep[top[y]])swap(x,y);//先移动深度大一些的点
        res=0;
        query(1,1,n,id[top[x]],id[x]);
        ans+=res;
        ans%=p;
        x=fa[top[x]];
    }
    //在同一个链上的时候从上到下用dfs序遍历
    if(dep[x]>dep[y])swap(x,y);
    res=0;
    query(1,1,n,id[x],id[y]);
    ans+=res;
    return ans%p;
}

//修改操作也跟上面的查询操作相似
inline void updrange(int x,int y,int k)
{
    k%=p;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        updata(1,1,n,id[top[x]],id[x],k);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    updata(1,1,n,id[x],id[y],k);
}

inline int qson(int x)
{
    res=0;
    query(1,1,n,id[x],id[x]+siz[x]-1);
    return res;
}

inline void updson(int x,int k)
{
    updata(1,1,n,id[x],id[x]+siz[x]-1,k);
}

//第一遍dfs处理出各个点的深度,父亲节点,儿子个数,还有重儿子结点
inline void dfs1(int x,int f)
{
    dep[x]=dep[f]+1;
    fa[x]=f;
    siz[x]=1;
    int maxsize=-1;
    for(int i=first[x];i!=-1;i=e[i].next)
    {
        int to=e[i].to;
        if(to==f)continue;
        dfs1(to,x);
        siz[x]+=siz[to];
        if(siz[to]>maxsize)
            son[x]=to,maxsize=siz[to];
    }
}

//第二次dfs记录dfs序 但要先遍历重链 在遍历轻链
inline void dfs2(int x,int t)
{
    id[x]=++cnt;
    wt[cnt]=a[x];
    top[x]=t;
    if(!son[x])return;
    dfs2(son[x],t);
    for(int i=first[x];i!=-1;i=e[i].next)
    {
        int to=e[i].to;
        if(to==fa[x]||to==son[x])continue;
        dfs2(to,to);
    }
}

int main()
{
    ms(first,-1);
    scanf("%d%d%d%d",&n,&m,&r,&p);
    for(int i=1;i<=n;i++)
        scanf("%d",&a[i]);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs1(r,0);
    dfs2(r,r);
    build(1,1,n);
    while(m--)
    {
        int k,x,y,z;
        scanf("%d",&k);
        if(k==1)
        {
            scanf("%d%d%d",&x,&y,&z);
            updrange(x,y,z);
        }
        if(k==2)
        {
            scanf("%d%d",&x,&y);
            printf("%d\n",qrange(x,y));
        }
        if(k==3)
        {
            scanf("%d%d",&x,&y);
            updson(x,y);
        }
        if(k==4)
        {
            scanf("%d",&x);
            printf("%d\n",qson(x));
        }
    }
}

猜你喜欢

转载自blog.csdn.net/daydreamer23333/article/details/107366128