[luogu1600]NOIp2016D1T2 天天爱跑步

题目链接:
luogu1600

谨以此题纪念那段年少无知但充满趣味的恬淡时光

附上一位dalao的博客链接:https://www.luogu.org/blog/user26242/ke-pa-di-tian-tian-ai-pao-bu

我写这道题的时候脑袋里一直想的是他还没AFO的时候的那段日子,真是快乐啊,虽然我那时什么都不会

好了废话结束

他写的是最为常见的差分+\(LCA\)+桶的写法,在此不再赘述

其实这是一道线段树合并的基础

常规套路:我们将路径\((s,t)\)拆成\((s,lca)\)\((lca,t)\),在这两段路径中

1)如果在\((s,lca)\)上的一点\(i\)看到,则\(w_i=dep_s-dep_i\),即\(w_i+dep_i=dep_s\)

2)如果在\((lca,t)\)上的一点\(i\)看到,则\(w_i=dep_s-dep_{lca}+dep_i-dep_{lca}\),即\(w_i-dep_i=dep_s-dep_{lca}*2\)

也就是说我们对于树上的某点\(i\),需要统计它的子树中满足上两个式子中某一个的点的个数

这个可以用权值线段树的合并维护

同时考虑到该路径不会对\(lca\)以上部分的答案产生影响,故考虑树上差分;具体的,在\(s\)\(t\)处放上\(+1\)标记,而在\(fa_{lca}\)放上两个\(-1\)标记

代码注意事项

1)第二个式子中可能会出现负数,因此将值域整体向右平移\(n\)个单位

2)当这条路径的\(lca\)是一个可以被统计到答案的点时,它会被统计两次(带回发现它对两个式子均成立)。这可以在打\(tag\)时直接判断

代码如下(常数很大,luogu不吸氧\(5317ms\)

#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<map>
#include<set>
using namespace std;
#define lowbit(x) (x)&(-x)
#define rep(i,a,b) for (int i=a;i<=b;i++)
#define per(i,a,b) for (int i=a;i>=b;i--)
#define maxd 1000000007
typedef long long ll;
const int N=100000;
const double pi=acos(-1.0);
struct edgenode{
    int to,nxt;
}sq[600200];
int n,m,all=0,head[300300],w[300300],dep[300300],fa[300300][20],ans[300300];

int read()
{
    int x=0,f=1;char ch=getchar();
    while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
    while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

struct segment_tree{
    struct node{
        int cnt,lson,rson;
    }seg[10000800];
    int tot,root[300100];
    
    void insert(int &id,int l,int r,int pos,int val)
    {
        if (!id) id=(++tot);seg[id].cnt+=val;
        if (l==r) return;
        int mid=(l+r)>>1;
        if (pos<=mid) insert(seg[id].lson,l,mid,pos,val);
        else insert(seg[id].rson,mid+1,r,pos,val);
    }
    
    int merge(int x,int y,int l,int r)
    {
        if ((!x) || (!y)) return x+y;
        seg[x].cnt+=seg[y].cnt;
        int mid=(l+r)>>1;
        seg[x].lson=merge(seg[x].lson,seg[y].lson,l,mid);
        seg[x].rson=merge(seg[x].rson,seg[y].rson,mid+1,r);
        return x;
    }
    
    int query(int id,int l,int r,int pos)
    {
        if (l==r) return seg[id].cnt;
        int mid=(l+r)>>1;
        if (pos<=mid) return query(seg[id].lson,l,mid,pos);
        else return query(seg[id].rson,mid+1,r,pos);
    }
}seg1,seg2;

void add(int u,int v)
{
    all++;sq[all].to=v;sq[all].nxt=head[u];head[u]=all;
}

int query_lca(int u,int v)
{
    if (dep[u]<dep[v]) swap(u,v);
    int tmp=dep[u]-dep[v];
    rep(i,0,19)
        if ((tmp>>i)&1) u=fa[u][i];
    if (u==v) return u;
    per(i,19,0)
        if (fa[u][i]!=fa[v][i]) {u=fa[u][i];v=fa[v][i];}
    return fa[u][0];
}

void dfs1(int u,int fu)
{
    dep[u]=dep[fu]+1;fa[u][0]=fu;
    rep(i,1,19) fa[u][i]=fa[fa[u][i-1]][i-1];
    int i;
    for (i=head[u];i;i=sq[i].nxt)
    {
        int v=sq[i].to;
        if (v==fu) continue;
        dfs1(v,u);
    }
}

void dfs2(int u,int fu)
{
    int i;
    for (i=head[u];i;i=sq[i].nxt)
    {
        int v=sq[i].to;
        if (v==fu) continue;
        dfs2(v,u);
        seg1.root[u]=seg1.merge(seg1.root[u],seg1.root[v],0,n);
        seg2.root[u]=seg2.merge(seg2.root[u],seg2.root[v],0,n*2);
    }
    if ((w[u]+dep[u]>=0) && (w[u]+dep[u]<=n)) 
        ans[u]+=seg1.query(seg1.root[u],0,n,w[u]+dep[u]);
    if ((w[u]-dep[u]>=-n) && (w[u]-dep[u]<=n))
        ans[u]+=seg2.query(seg2.root[u],0,2*n,w[u]-dep[u]+n);
}

int main()
{
    n=read();m=read();
    seg1.tot=0;seg2.tot=0;
    rep(i,1,n-1)
    {
        int u=read(),v=read();
        add(u,v);add(v,u);
    }
    dfs1(1,0);
    rep(i,1,n) w[i]=read();
    rep(i,1,m)
    {
        int s=read(),t=read(),lca=query_lca(s,t),fal=fa[lca][0];
        seg1.insert(seg1.root[s],0,n,dep[s],1);
        seg1.insert(seg1.root[fal],0,n,dep[s],-1);
        seg2.insert(seg2.root[t],0,n*2,dep[s]-2*dep[lca]+n,1);
        seg2.insert(seg2.root[fal],0,n*2,dep[s]-2*dep[lca]+n,-1);
        if (w[lca]+dep[lca]==dep[s]) ans[lca]--;
    }
    dfs2(1,0);
    rep(i,1,n) printf("%d ",ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/zhou2003/p/10759686.html