UVALive - 6712 lca+dfs序线段树

题意:一棵树q次查询,每次查询给三个不同的点,要求计算到这三个点的比其他两个距离都要小的点数

题解:很明显的lca,倍增的找中点,关键是两个点的中点很好找,但是三个点不好找,我刚开始还准备分类讨论,后来发现巨麻烦,其实可以用线段树来维护算a的答案其实就是a在b下的答案和a在c下的答案的交集,可以用线段树区间求和区间查询做,每次更新完之后复原就不用memset线段树了

//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast,no-stack-protector")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
//#pragma GCC optimize("unroll-loops")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define pi acos(-1.0)
#define ll long long
#define vi vector<int>
#define mod 1000000007
#define C 0.5772156649
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1
#define pil pair<int,ll>
#define pli pair<ll,int>
#define pii pair<int,int>
#define cd complex<double>
#define ull unsigned long long
#define base 1000000000000000000
#define fio ios::sync_with_stdio(false);cin.tie(0)

using namespace std;

const double g=10.0,eps=1e-12;
const int N=100000+10,maxn=5000000+10,inf=0x3f3f3f3f,INF=0x3f3f3f3f3f3f3f3f;

vi v[N];
int dep[N],n,sz[N],fa[20][N];
int le[N],ri[N],id[N],cnt;
void dfs(int u,int f)
{
    le[u]=++cnt;
    id[cnt]=u;
    fa[0][u]=f;
    sz[u]=1;
    for(int i=0;i<v[u].size();i++)
    {
        int x=v[u][i];
        if(x!=f)dep[x]=dep[u]+1,dfs(x,u),sz[u]+=sz[x];
    }
    ri[u]=cnt;
}
int lazy[N<<2],val[N<<2];
void pushdown(int l,int r,int rt)
{
    if(lazy[rt]!=0)
    {
        int m=(l+r)>>1;
        val[rt<<1]+=(m-l+1)*lazy[rt];
        val[rt<<1|1]+=(r-m)*lazy[rt];
        lazy[rt<<1]+=lazy[rt];
        lazy[rt<<1|1]+=lazy[rt];
        lazy[rt]=0;
    }
}
void pushup(int rt)
{
    val[rt]=val[rt<<1]+val[rt<<1|1];
}
void build(int l,int r,int rt)
{
    lazy[rt]=val[rt]=0;
    if(l==r)return ;
    int m=(l+r)>>1;
    build(ls);build(rs);
}
void update(int c,int L,int R,int l,int r,int rt)
{
    if(L<=l&&r<=R)
    {
        val[rt]+=(r-l+1)*c;
        lazy[rt]+=c;
        return ;
    }
    pushdown(l,r,rt);
    int m=(l+r)>>1;
    if(L<=m)update(c,L,R,ls);
    if(m<R)update(c,L,R,rs);
    pushup(rt);
}
int query(int L,int R,int l,int r,int rt)
{
    if(L<=l&&r<=R)return val[rt];
    pushdown(l,r,rt);
    int m=(l+r)>>1,ans=0;
    if(L<=m)ans+=query(L,R,ls);
    if(m<R)ans+=query(L,R,rs);
    return ans;
}
void init()
{
    dep[1]=1;
    cnt=0;
    dfs(1,-1);
    build(1,cnt,1);
    for(int i=1;i<20;i++)
        for(int j=1;j<=n;j++)
            fa[i][j]=fa[i-1][fa[i-1][j]];
}
int lca(int x,int y)
{
    if(dep[x]>dep[y])swap(x,y);
    for(int i=0;i<20;i++)
        if((dep[y]-dep[x])>>i&1)
            y=fa[i][y];
    if(x==y)return x;
    for(int i=19;i>=0;i--)
    {
        if(fa[i][x]!=fa[i][y])
        {
            x=fa[i][x];
            y=fa[i][y];
        }
    }
    return fa[0][x];
}
int go(int u,int dis)
{
    for(int i=19;i>=0;i--)
        if(dis>=(1<<i))
            dis-=(1<<i),u=fa[i][u];
    return u;
}
int solve(int a,int b,int c)
{
    int tle,tri,ans=0;
    if(dep[a]>=dep[b])
    {
        int dis=dep[a]+dep[b]-2*dep[lca(a,b)];
        int x=go(a,dis/2);
        if(dis%2==0)x=go(a,dis/2-1);
        update(1,le[x],ri[x],1,cnt,1);
        tle=le[x],tri=ri[x];
    }
    else
    {
        int dis=dep[a]+dep[b]-2*dep[lca(a,b)];
        int x=go(b,dis/2);
        update(1,1,cnt,1,cnt,1);
        update(-1,le[x],ri[x],1,cnt,1);
        tle=le[x],tri=ri[x];
    }
    if(dep[a]>=dep[c])
    {
        int dis=dep[a]+dep[c]-2*dep[lca(a,c)];
        int x=go(a,dis/2);
        if(dis%2==0)x=go(a,dis/2-1);
        ans=query(le[x],ri[x],1,cnt,1);
    }
    else
    {
        int dis=dep[a]+dep[c]-2*dep[lca(a,c)];
        int x=go(c,dis/2);
        ans=query(1,cnt,1,cnt,1);
        ans-=query(le[x],ri[x],1,cnt,1);
    }
    if(dep[a]>=dep[b])update(-1,tle,tri,1,cnt,1);
    else update(-1,1,cnt,1,cnt,1),update(1,tle,tri,1,cnt,1);
    return ans;
}
int main()
{
    int T;scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++)v[i].clear();
        for(int i=1;i<n;i++)
        {
            int a,b;
            scanf("%d%d",&a,&b);
            v[a].pb(b),v[b].pb(a);
        }
        init();
        int q;scanf("%d",&q);
        while(q--)
        {
            int a,b,c;
            scanf("%d%d%d",&a,&b,&c);
//            printf("%d\n",solve(b,a,c));
            printf("%d %d %d\n",solve(a,b,c),solve(b,a,c),solve(c,a,b));
        }
    }
    return 0;
}
/***********************
1
9
1 2
1 3
1 4
2 5
2 6
2 7
6 8
6 9
2
1 2 8
2 1 4
***********************/
View Code

猜你喜欢

转载自www.cnblogs.com/acjiumeng/p/9013715.html