2020牛客暑期多校训练营(第四场)A Ancient Distance —— 整除分块类似,线段树,想法,有丶东西

This way

题意:

现在有一棵树,假设你现在可以设置树中k个点为关键点,使得最后每个点到根的路径中距离它最近的关键点的距离的最大值最小。问你k从1到n的所有答案的和。

题解1:

这道题有点难,牛客的标准解法我先不做,因为这个代码写起来很长,然后我就去找有没有比较短一点的做法,于是我就找到了一个很厉害的做法,分块上做线段树,这个时间复杂度应该也是 O ( n l o g 2 n ) O(nlog^2n) 级别的。
首先我们要知道,当答案一定的时候,是会有多种k的,比如说一条长度为4的链,放2个关键点和3个关键点的答案是一样的,都是1.于是利用这个特性,在这个上面做线段树,每次暴力的查看答案为mid的时候 ,要放多少个点。
lenl表示当前答案的左界,lenr表示当前答案的右界
kl表示当前放的关键点数量的左界,kr表示当前放的关键点数量的右界。
最后如果kl=kr,那么就表示长度为lenl的需要的关键点的数量是kl。
看起来这个代码是 O ( n 2 l o g n ) O(n^2logn) 的,但是用到了答案不同时k的不连续性,将其对k进行了不可描述的分块处理,于是降下了时间。
我的理解是这样的,应该是对的吧,太强了QAQ

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=2e5+5;
vector<int>vec[N];
int dep[N],tmp[N],fa[N],pos[N],tim;
void dfs(int x){
    pos[++tim]=x;
    for(auto i:vec[x]){
        dep[i]=dep[x]+1;
        dfs(i);
    }
}
ll ans[N];
int n;
void build(int lenl,int lenr,int kl,int kr){
    if(lenl>lenr||kl>kr)return ;
    if(kl==kr){
        for(int i=lenl;i<=lenr;i++)ans[i]=kl;
        return ;
    }
    int mid=lenl+lenr>>1;
    ans[mid]=0;
    for(int i=1;i<=n;i++)tmp[i]=dep[i];
    for(int i=n;i;i--){
        int x=pos[i];
        if(tmp[x]==dep[x]+mid||x==1)
            ans[mid]++,tmp[x]=-1;
        tmp[fa[x]]=max(tmp[fa[x]],tmp[x]);
    }
    build(lenl,mid-1,ans[mid],kr);
    build(mid+1,lenr,kl,ans[mid]);
}
int main()
{
    while(~scanf("%d",&n)){
        tim=0;
        for(int i=1;i<=n;i++)
            vec[i].clear();
        for(int i=2;i<=n;i++)
            scanf("%d",&fa[i]),vec[fa[i]].push_back(i);
        dfs(1);
        int mx=0;
        for(int i=1;i<=n;i++)mx=max(mx,dep[i]);
        build(0,mx,1,n);
        ll sum=0;
        for(ll i=1;i<=mx;i++)
            sum+=i*(ans[i-1]-ans[i]);
        printf("%lld\n",sum);
    }
    return 0;
}

题解2:

最终我还是去做了牛客给的标准题解。我就简略的说一下
大致是从小到大枚举答案,也就是最小值最大,然后用这棵树的dfs序构造线段树。枚举到当前的答案的时候,找到最深的点,然后用倍增往上找i步,然后将这个点的子树的区间删除,直到根被删除为止。
我用del数组维护每个区间是否被删除。但是有一个很奇怪的问题就是我如果push_down的话,会出错。我觉得可能是在update中的往上传值的时候出现了问题吧,并且这道题其实不需要向下更新,因为不会有访问下面区间的机会。
但是我还是想要知道该怎么样才能加上push_down,如果有大佬知道的话指点一下谢谢了。
在每次做完之后,再重复一遍删的位置,将del数组恢复。

#include<bits/stdc++.h>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
#define ll long long
#define pa pair<int,int>
const int N=2e5+5;
pa mx[N*4];
int del[N*4],f[N*4];
int sta[N],en[N],pos[N];
vector<int>vec[N];
int dep[N],fa[N][25],tim,n;
void dfs(int x){
    sta[x]=++tim;
    pos[tim]=x;
    for(auto i:vec[x]){
        dep[i]=dep[x]+1;
        dfs(i);
        en[x]=en[i];
    }
    if(!vec[x].size())en[x]=sta[x];
}
/*
void push_down(int root){
    if(f[root]==-1)return ;
    del[root<<1]=del[root<<1|1]=f[root];
    f[root<<1]=f[root<<1|1]=f[root];
    f[root]=-1;
}
*/
void build(int l,int r,int root){
    f[root]=-1;
    del[root]=0;
    if(l==r){
        mx[root]={pos[l],dep[pos[l]]};
        return ;
    }
    int mid=l+r>>1;
    build(l,mid,root<<1);
    build(mid+1,r,root<<1|1);
    if(mx[root<<1].second>mx[root<<1|1].second)
        mx[root]=mx[root<<1];
    else
        mx[root]=mx[root<<1|1];
}
void update(int l,int r,int root,int ql,int qr,int op){
    if(l>=ql&&r<=qr){
        if(op==1)
            del[root]=f[root]=1;
        else
            del[root]=f[root]=0;
        return ;
    }
    //push_down(root);
    int mid=l+r>>1;
    if(mid>=ql)
        update(l,mid,root<<1,ql,qr,op);
    if(mid<qr)
        update(mid+1,r,root<<1|1,ql,qr,op);
    if(del[root<<1]&&del[root<<1|1])
        del[root]=1;
    else if(del[root<<1])
        mx[root]=mx[root<<1|1],del[root]=0;
    else if(del[root<<1|1])
        mx[root]=mx[root<<1],del[root]=0;
    else{
        if(mx[root<<1].second>mx[root<<1|1].second)
            mx[root]=mx[root<<1];
        else
            mx[root]=mx[root<<1|1];
        del[root]=0;
    }
}
void deal(){
    for(int j=1;(1<<j)<=n;j++)
        for(int i=1;i<=n;i++)
            fa[i][j]=fa[fa[i][j-1]][j-1];
}
int finds(int x,int step){
    int d=max(dep[x]-step,1);
    for(int i=20;i>=0;i--)
        if(dep[fa[x][i]]>=d)
            x=fa[x][i];
    return x;
}
vector<int>v;
ll ans[N];
int main()
{
    while(~scanf("%d",&n)){
        tim=0;
        for(int i=0;i<=n;i++){
            vec[i].clear(),ans[i]=0;
            for(int j=0;j<=20;j++)
                fa[i][j]=0;
        }
        for(int i=2;i<=n;i++)
            scanf("%d",&fa[i][0]),vec[fa[i][0]].push_back(i);
        dep[1]=1;
        dfs(1);
        deal();
        int mm=0;
        for(int i=1;i<=n;i++)mm=max(mm,dep[i]);
        build(1,n,1);
        ans[0]=n;
        for(int i=1;i<mm;i++){
            v.clear();
            while(!del[1]){
                pa u=mx[1];
                int f=finds(u.first,i);
                update(1,n,1,sta[f],en[f],1);
                v.push_back(f);
                ans[i]++;
            }
            for(int j:v)
                update(1,n,1,sta[j],en[j],2);
        }
        ll sum=0;
        for(ll i=1;i<mm;i++)
            sum+=i*(ans[i-1]-ans[i]);
        printf("%lld\n",sum);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/tianyizhicheng/article/details/107512243