bzoj1112: [POI2008]砖块Klo(splay)

题面在这里

做法

枚举每长度为 k 的段寻找中位数即可。splay维护。

代码

=> 主要是想说这一点,由于计算的必要,相同的数不能合并到一个节点,否则之后调用 sum[ch[x][0]]/sum[ch[x][1]] 的时候会漏算和节点 x 相同的数。

#include<bits/stdc++.h>
#define rep(i,x,y) for (int i=(x); i<=(y); i++)
#define ll long long
#define ld long double
#define inf 1000000000
#define INF 1000000000000000000ll
using namespace std;
ll read(){
    char ch=getchar(); ll x=0; int op=1;
    for (; !isdigit(ch); ch=getchar()) if (ch=='-') op=-1;
    for (; isdigit(ch); ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    return x*op;
}
#define N 100005
int n,m,a[N],data[N],fa[N],ch[N][2],siz[N],rt,tot; ll ans,sum[N];
void up(int x){
    sum[x]=sum[ch[x][0]]+sum[ch[x][1]]+data[x];
    siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1;
}
void rot(int x){
    int y=fa[x],z=fa[y],f=ch[y][1]==x;
    ch[y][f]=ch[x][f^1]; if (ch[x][f^1]) fa[ch[x][f^1]]=y;
    fa[x]=z; if (z) ch[z][ch[z][1]==y]=x;
    fa[y]=x; ch[x][f^1]=y; up(y); up(x);
}
void splay(int x,int tp){
    while (fa[x]!=tp){
        int y=fa[x],z=fa[y];
        if (z!=tp) rot((ch[z][0]==y)==(ch[y][0]==x)?y:x);
        rot(x);
    }
    if (!tp) rt=x;
}
void insert(int val){
    int x=rt;
    if (!rt){
        rt=x=++tot;
        ch[x][0]=ch[x][1]=fa[x]=0;
        data[x]=sum[x]=val; siz[x]=1;
        return;
    }
    while (x){
        int &y=ch[x][val>data[x]];
        if (!y){
            y=++tot;
            ch[y][0]=ch[y][1]=0; fa[y]=x;
            data[y]=sum[y]=val; siz[y]=1;
            x=y; break;
        }
        x=y;
    }
    splay(x,0);
}
int find(int val){
    int x=rt;
    while (ch[x][val>data[x]] && data[x]!=val) x=ch[x][val>data[x]];
    splay(x,0); return x;
}
int getpre(int val){
    int x=find(val); if (data[x]<val) return x;
    x=ch[x][0];
    while (ch[x][1]) x=ch[x][1];
    return x;
}
int getnxt(int val){
    int x=find(val); if (data[x]>val) return x;
    x=ch[x][1];
    while (ch[x][0]) x=ch[x][0];
    return x;
}
void delet(int val){
    int x=getpre(val),y=getnxt(val);
    splay(x,0); splay(y,x);
    int &z=ch[y][0];
    data[z]=sum[z]=siz[z]=0,z=0,splay(y,0);
}
int getkth(int k){
    int x=rt;
    while (1){
        if (k<=siz[ch[x][0]]) x=ch[x][0];
        else if (k>siz[ch[x][0]]+1) k-=siz[ch[x][0]]+1,x=ch[x][1];
        else return x;
    }
}
int main(){
    n=read(),m=read();
    rep (i,1,n) a[i]=read();
    insert(-inf); insert(inf);
    rep (i,1,m-1) insert(a[i]);
    ans=INF;
    rep (i,m,n){
        insert(a[i]);
        int tmp=getkth((m+1)/2+1);//返回节点编号
        splay(tmp,0);
        ans=min(ans,(ll)data[tmp]*((m+1)/2-1)-(sum[ch[tmp][0]]+inf)+(sum[ch[tmp][1]]-inf)-(ll)data[tmp]*(m-(m+1)/2));
        delet(a[i-m+1]);
    }
    cout<<ans<<endl;
    return 0;
}

猜你喜欢

转载自blog.csdn.net/bestfy/article/details/80107261
今日推荐