【题解】Luogu P5327 [ZJOI2019]语言

原题传送门

看到这种树上统计点对个数的题一般是线段树合并,这题也不出意外

先对这棵树进行树剖,对于每次普及语言,在\(x,y\)两点的线段树上的\(x,y\)两位置打\(+1\)标记,在点\(fa[lca(x,y)]\)的线段树上\(x,y\)两位置打\(-2\)标记

线段树中该维护三个东西:

1.dfs序最小的\(lp\)

2.dfs序最大的\(rp\)

3.线段树中所有被打标机的点到根节点路径的并的节点个数\(sum\)

我们进行搜索并从下向上的进行线段树合并,对于每个节点,对答案的贡献就是该节点线段树的根节点的\(sum-dep[lca(lp,rp)]+1-1\)(珂以画图理解一下)

交上去,发现wa了,因为每条路径的贡献被我们算了\(2\)次,所以除以2即可

#include <bits/stdc++.h>
#define N 100005 
#define ll long long 
#define getchar nc
using namespace std;
inline char nc(){
    static char buf[100000],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
    register int x=0,f=1;register char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    return x*f;
}
inline void write(register ll x)
{
    if(!x)putchar('0');if(x<0)x=-x,putchar('-');
    static int sta[20];register int tot=0;
    while(x)sta[tot++]=x%10,x/=10;
    while(tot)putchar(sta[--tot]+48);
}
int n,m;
ll ans;
struct edge{
    int to,next;
}e[N<<1];
int head[N],cnt=0;
inline void add(register int u,register int v)
{
    e[++cnt]=(edge){v,head[u]};
    head[u]=cnt;
}
int size[N],fa[N],son[N],dep[N],top[N],dfn[N],tim,idfn[N];
inline void dfs1(register int x)
{
    size[x]=1;
    for(register int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa[x])
            continue;
        dep[v]=dep[x]+1;
        fa[v]=x;
        dfs1(v);
        size[x]+=size[v];
        if(size[v]>size[son[x]])
            son[x]=v;
    }
}
inline void dfs2(register int x,register int t)
{
    dfn[x]=++tim;
    idfn[tim]=x;
    top[x]=t;
    if(son[x])
        dfs2(son[x],t);
    for(register int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v!=fa[x]&&v!=son[x])
            dfs2(v,v);
    }
}
inline int getlca(register int x,register int y)
{
    int fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(dep[fx]<dep[fy])
        {
            x^=y^=x^=y;
            fx^=fy^=fx^=fy;
        }
        x=fa[fx];
        fx=top[x];
    }
    return dep[x]<dep[y]?x:y;
}
int root[N],tot;
struct node{
    int ls,rs,cnt,lp,rp,sum;
}tr[N*60];
inline void pushup(register int x)
{
    tr[x].lp=tr[tr[x].ls].lp?tr[tr[x].ls].lp:tr[tr[x].rs].lp;
    tr[x].rp=tr[tr[x].rs].rp?tr[tr[x].rs].rp:tr[tr[x].ls].rp;
    tr[x].sum=tr[tr[x].ls].sum+tr[tr[x].rs].sum;
    if(tr[tr[x].ls].rp&&tr[tr[x].rs].lp)
        tr[x].sum-=dep[getlca(tr[tr[x].ls].rp,tr[tr[x].rs].lp)];
}
inline void modify(register int &x,register int l,register int r,register int pos,register int val)
{
    if(!x)
        x=++tot;
    if(l==r)
    {
        tr[x].cnt+=val;
        if(tr[x].cnt)
            tr[x].lp=tr[x].rp=pos,tr[x].sum=dep[pos];
        else
            tr[x].lp=tr[x].rp=tr[x].sum=0;
        return;
    }
    int mid=l+r>>1;
    if(dfn[pos]<=mid)
        modify(tr[x].ls,l,mid,pos,val);
    else
        modify(tr[x].rs,mid+1,r,pos,val);
    pushup(x);
}
inline void merge(register int &a,register int b,register int l,register int r)
{
    if(!a||!b)
    {
        a=a+b;
        return;
    }
    if(l==r)
    {
        tr[a].cnt+=tr[b].cnt;
        if(tr[a].cnt)
            tr[a].lp=tr[a].rp=idfn[l],tr[a].sum=dep[idfn[l]];
        else
            tr[a].lp=tr[a].rp=tr[a].sum=0;
        return;
    }
    int mid=l+r>>1;
    merge(tr[a].ls,tr[b].ls,l,mid);
    merge(tr[a].rs,tr[b].rs,mid+1,r);
    pushup(a);
}
inline void dfs(register int x)
{
    for(register int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa[x])
            continue;
        dfs(v);
        merge(root[x],root[v],1,tim);
    }
    if(tr[root[x]].lp&&tr[root[x]].rp)
        ans+=tr[root[x]].sum-dep[getlca(tr[root[x]].lp,tr[root[x]].rp)];
}
int main()
{
    n=read(),m=read();
    for(register int i=1;i<n;++i)
    {
        int u=read(),v=read();
        add(u,v),add(v,u);
    }
    dfs1(1);
    dfs2(1,0);
    for(register int i=1;i<=m;++i)
    {
        int x=read(),y=read();
        int lca=getlca(x,y);
        modify(root[x],1,tim,x,1);
        modify(root[x],1,tim,y,1);
        modify(root[y],1,tim,x,1);
        modify(root[y],1,tim,y,1);
        if(lca>1)
        {
            modify(root[fa[lca]],1,tim,x,-2);
            modify(root[fa[lca]],1,tim,y,-2);
        }
    }
    dfs(1);
    write(ans>>1);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/yzhang-rp-inf/p/11223298.html