BZOJ4756:[USACO2017JAN]Promotion Counting

浅谈线段树合并:https://www.cnblogs.com/AKMer/p/10251001.html

题目传送门:https://lydsy.com/JudgeOnline/problem.php?id=4756

对于每个结点用一棵值域线段树维护子树内结点的信息,然后该查询查询该合并合并就好了。

时间复杂度:\(O(nlogn)\)

空间复杂度:\(O(nlogn)\)

代码如下:

#include <cstdio>
#include <algorithm>
using namespace std;

const int maxn=1e5+5;

int n,tot,cnt;
int now[maxn],pre[maxn],son[maxn];
int p[maxn],tmp[maxn],rt[maxn],ans[maxn];

int read() {
    int x=0,f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return x*f;
}

void add(int a,int b) {
    pre[++tot]=now[a];
    now[a]=tot,son[tot]=b;
}

struct segment_tree {
    int tot;
    int sum[maxn*20],ls[maxn*20],rs[maxn*20];

    void update(int p) {
        sum[p]=sum[ls[p]]+sum[rs[p]];
    }

    void change(int &p,int l,int r,int pos) {
        if(!p)p=++tot;
        if(l==r) {sum[p]++;return;}
        int mid=(l+r)>>1;
        if(pos<=mid)change(ls[p],l,mid,pos);
        else change(rs[p],mid+1,r,pos);
        update(p);
    }

    int query(int p,int l,int r,int pos) {
        if(l==r)return 0;
        int mid=(l+r)>>1,res=0;
        if(pos<=mid)res=sum[rs[p]]+query(ls[p],l,mid,pos);
        else res=query(rs[p],mid+1,r,pos);
        return res;
    }

    int merge(int a,int b) {
        if(!a||!b)return a+b;
        if(!ls[a]&&!rs[a]&&!ls[b]&&!rs[b]) {
            sum[a]+=sum[b];return a;
        }
        ls[a]=merge(ls[a],ls[b]);
        rs[a]=merge(rs[a],rs[b]);
        update(a);return a;
    }
}T;

void dfs(int fa,int u) {
    int tmp=0;
    for(int P=now[u],v=son[P];P;P=pre[P],v=son[P])
        dfs(u,v),tmp=T.merge(tmp,rt[v]);
    ans[u]=T.query(tmp,1,cnt,p[u]);
    rt[u]=T.merge(rt[u],tmp);
}

int main() {
    n=read();
    for(int i=1;i<=n;i++)
        tmp[i]=p[i]=read();
    sort(tmp+1,tmp+n+1);
    cnt=unique(tmp+1,tmp+n+1)-tmp-1;
    for(int i=1;i<=n;i++) {
        p[i]=lower_bound(tmp+1,tmp+cnt+1,p[i])-tmp;
        T.change(rt[i],1,cnt,p[i]);
    }
    for(int i=2;i<=n;i++) {
        int f=read();add(f,i);
    }
    dfs(0,1);
    for(int i=1;i<=n;i++)
        printf("%d\n",ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/AKMer/p/10252186.html