ZJOI2017线段树

ZJOI2017线段树

题意:

给你一颗广义线段树,太长了,自己去看。

题解:

​ 直接上zkw那一套,把闭区间换成开区间,就是把取\([l,r]\),变成取\([l-1,l-1],[r+1,r+1]\)两个端点,往跳,如果\([l-1,l-1]\)往上跳到某一层时,它是它父亲的左儿子,那它的兄弟就是区间中的点。

​ 答案就是(\(u\)是询问的点,\(v\)是区间中的点):
\[ Ans=\sum_{v}dep[v]+dep[u]\times |\{v\}|-2\times sum \]
\(sum\)\(u\)\(\{v\}\)中点的\(lca\)的深度和。然后这有\(3\)种情况,我们把\([l-1,l-1]\)取到的点看做左半部分,\([r+1,r+1]\)取到的看做右半部分,只看左半部分。

  1. \([l-1,l-1]\)\(u\)\(lca\)在所有左半部分点的上方,\(sum\)易得。

  2. \([l-1,l-1]\)\(u\)\(lca\)在所有左半部分点的下方,\(sum\)易得。

  3. \([l-1,l-1]\)\(u\)\(lca\)在一部分左半部分点的上方,相当于把所有点分成上下两个部分,按1、2做。

    注意\(lca\)是区间中的点的情况和\(l=1\)\(r=n\)的情况。
    倍增实现。

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define fo(i,l,r) for(int i=l;i<=r;i++)
#define of(i,l,r) for(int i=l;i>=r;i--)
#define fe(i,u) for(int i=head[u];i;i=e[i].next)
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pii;
#define P(a,b) make_pair(a,b)
inline void open(const char *s)
{
    #ifndef ONLINE_JUDGE
    char str[20];
    sprintf(str,"in%s.txt",s);
    freopen(str,"r",stdin);
//  sprintf(str,"out%s.txt",s);
//  freopen(str,"w",stdout);
    #endif
}
inline int rd()
{
    static int x,f;
    x=0;f=1;
    char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return f>0?x:-x;
}
const int N=200010,NN=N<<1;
int n,m,rt,dep[NN],id[N],ch[NN][2],tim=0;
int bin[20],fa[NN][20],siz[NN][2][20];
ll sum[NN][2][20],ans=0;

void pre(int &u,int l,int r,int fat)
{
    u=++tim;dep[u]=dep[fat]+1;
    if(l==r)return id[l]=u,void();
    int mid=rd();
    pre(ch[u][0],l,mid,u);
    pre(ch[u][1],mid+1,r,u);
}
void dfs(int u,int fat)
{
    if(!u)return;
    fa[u][0]=fat;
    siz[u][0][0]=u==ch[fat][1];
    siz[u][1][0]=u==ch[fat][0];
    sum[u][0][0]=u==ch[fat][1]?dep[ch[fat][0]]:0;
    sum[u][1][0]=u==ch[fat][0]?dep[ch[fat][1]]:0;
    fo(i,1,17){
        fa[u][i]=fa[fa[u][i-1]][i-1];if(!fa[u][i])break;
        siz[u][0][i]=siz[u][0][i-1]+siz[fa[u][i-1]][0][i-1];
        siz[u][1][i]=siz[u][1][i-1]+siz[fa[u][i-1]][1][i-1];
        sum[u][0][i]=sum[u][0][i-1]+sum[fa[u][i-1]][0][i-1];
        sum[u][1][i]=sum[u][1][i-1]+sum[fa[u][i-1]][1][i-1];
    }
    dfs(ch[u][0],u);dfs(ch[u][1],u);
}

inline int getlca(int x,int y)
{
    if(x==y)return x;
    if(dep[x]<dep[y])swap(x,y);
    int d=dep[x]-dep[y];
    fo(i,0,17)if(bin[i]&d)x=fa[x][i];
    if(x==y)return x;
    of(i,17,0)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}

inline pii get(int x,int lca,int t)
{
    int d=dep[x]-dep[lca]-1,sz=0;
    ll res=0;if(d<0)return P(0,0);
    fo(i,0,17)if(d&bin[i]){
        res+=sum[x][t][i];
        sz+=siz[x][t][i];
        x=fa[x][i];
    }
    return P(res,sz);
}

inline void gaogao(int u,int x,int L,int t)
{
    static pii res;
    int lca=getlca(u,x);
    if(dep[lca]<=dep[L])ans-=2ll*get(x,L,t).second*dep[lca];
    else{
        res=get(x,lca,t);ans-=2ll*res.second*dep[lca];
        if(u==lca&&ch[u][t]&&getlca(x,ch[u][t])==u)ans-=2ll*dep[u];
        else if(ch[lca][t]&&getlca(u,ch[lca][t])!=lca)ans-=2ll*dep[lca]+2ll;
        res=get(lca,L,t);ans-=2ll*res.first-2ll*res.second;
    }
}

inline void gao()
{
    static int u,ql,qr,lca;static pii res;
    u=rd();ql=rd();qr=rd();ans=0;
    if(ql==1&&qr==n)return void(printf("%d\n",dep[u]-1));
    if(ql==1){
        lca=getlca(id[1],id[qr+1]);
        gaogao(u,id[qr+1],lca,0);
        ans-=2ll*dep[getlca(ch[lca][0],u)];
        
        res=get(id[qr+1],fa[lca][0],0);
        ans+=res.first+(ll)dep[u]*res.second;
        return void(printf("%lld\n",ans));
    }
    if(qr==n){
        lca=getlca(id[n],id[ql-1]);
        gaogao(u,id[ql-1],lca,1);
        ans-=2ll*dep[getlca(ch[lca][1],u)];
        
        res=get(id[ql-1],fa[lca][0],1);
        ans+=res.first+(ll)dep[u]*res.second;
        return void(printf("%lld\n",ans));
    }
    lca=getlca(id[ql-1],id[qr+1]);
    gaogao(u,id[ql-1],lca,1);
    gaogao(u,id[qr+1],lca,0);
    
    res=get(id[ql-1],lca,1);
    ans+=res.first+(ll)dep[u]*res.second;
    
    res=get(id[qr+1],lca,0);
    ans+=res.first+(ll)dep[u]*res.second;
    printf("%lld\n",ans);
}

int main()
{
    bin[0]=1;fo(i,1,17)bin[i]=bin[i-1]<<1;
    n=rd();pre(rt,1,n,0);
    dfs(rt,0);
    for(m=rd();m--;)gao();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/JackyhhJuRuo/p/9837142.html