P2234 [HNOI2002]营业额统计(Splay树)题解

思路:用Splay树维护已出现过的营业额。对每个新值,查找其前驱(不大于它的最大值)和后继(不小于它的最小值),二者与它的差的绝对值取较小者,即为当天的最小波动值,累加即可。

代码:

#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
// Shorthands for the left / right child of node n in the pool t[].
// They expand at the point of use, so it is fine that t is declared below.
#define LS(n) t[(n)].ch[0]
#define RS(n) t[(n)].ch[1]
using namespace std;
typedef long long ll;
// "Infinity" sentinel (about 1.06e9); main() inserts +INF and -INF into the tree.
const int INF = 0x3f3f3f3f;
// Capacity of the node pool t[]. 32767 is presumably the problem's upper bound
// on n; the +10 slack also covers the two sentinel nodes.
const int maxn = 32767 + 10;
int n;     // number of days, read in main()
int cnt;   // number of nodes allocated so far in t[] (index 0 is never allocated)
int root;  // index of the current root; 0 when the tree is empty

// One Splay-tree node. Nodes live in the global pool t[] and are referred to by
// index; a child or parent link equal to 0 means "none" (index 0 is never
// handed out by insert()).
struct splay{
    int ch[2];  // ch[0] = left child, ch[1] = right child (0 = none)
    int size;   // total multiplicity of the subtree rooted here (kept by up())
    int cnt;    // number of copies of val stored in this node
    int val;    // the key
    int fa;     // parent index (0 for the root)
}t[maxn];

// Fast reader for one (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Returns 0 if EOF is reached before any digit.
int gi(){
    int ans = 0, f = 1;
    // Keep getchar()'s result as int: stored in a char, EOF cannot be told
    // apart from a real byte and the skip loop below would never end.
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9')){
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}

// Prints the keys of the subtree rooted at x in in-order (sorted) order,
// each followed by a space. A key is printed once regardless of its cnt.
// Not called from main(); presumably a debugging aid.
void out(int x){
    if(!x) return;   // empty subtree; also makes out(root) safe on an empty tree
    out(t[x].ch[0]);
    printf("%d ", t[x].val);
    out(t[x].ch[1]);
}

// Returns 1 if x is recorded as its parent's right child, otherwise 0.
int get(int x){
    const int parent = t[x].fa;
    return (t[parent].ch[1] == x) ? 1 : 0;
}

// Recomputes x's subtree size from its two children plus its own multiplicity.
void up(int x){
    const int l = t[x].ch[0];
    const int r = t[x].ch[1];
    t[x].size = t[x].cnt + t[l].size + t[r].size;
}

// Single rotation: lifts x one level above its parent, preserving BST order,
// then refreshes the sizes of the two nodes whose subtrees changed.
// Writes into the null node 0 are skipped so that t[0] keeps ch == {0,0} and
// fa == 0; otherwise code that later walks from index 0 (e.g. pre() with no
// neighbour, or out()/find() on an empty tree) would wander into the live tree.
void rotate(int x){
    int fa = t[x].fa, gfa = t[fa].fa;
    int d1 = get(x), d2 = get(fa);
    int inner = t[x].ch[d1^1];      // subtree that switches from x to fa
    t[fa].ch[d1] = inner;
    if(inner) t[inner].fa = fa;
    if(gfa) t[gfa].ch[d2] = x;      // when gfa == 0, x becomes a root
    t[x].fa = gfa;
    t[fa].fa = x;
    t[x].ch[d1^1] = fa;
    up(fa); up(x);                  // fa is now below x, so update it first
}

// Rotates x upward until its parent is `goal`; with goal == 0, x becomes the
// new root and the global `root` is updated. When x and its parent hang on the
// same side, the parent is rotated first (zig-zig); otherwise x is rotated
// twice (zig-zag).
void splay(int x,int goal){
    for(int p = t[x].fa; p != goal; p = t[x].fa){
        const int g = t[p].fa;
        if(g != goal){
            const bool sameSide = (get(x) == get(p));
            rotate(sameSide ? p : x);
        }
        rotate(x);
    }
    if(!goal) root = x;
}

// Searches from the root for val. Returns the node holding val if present;
// otherwise the last node on the search path, i.e. the node at which the next
// step would leave the tree. Does not splay.
int find(int val){
    int cur = root;
    while(t[cur].val != val){
        const int next = t[cur].ch[t[cur].val < val];
        if(!next) break;
        cur = next;
    }
    return cur;
}

void insert(int val){
    int node = root, fa = 0;
    while(t[node].val != val && node)
        fa = node, node = t[node].ch[t[node].val<val];
    if(node) t[node].cnt++;
    else{
        node = ++cnt;
        if(fa) t[fa].ch[t[fa].val<val] = node;
        t[node].size = t[node].cnt = 1;
        t[node].fa = fa; t[node].val = val;
    }
    splay(node , 0);
}

// Neighbour lookup. Returns an INDEX into t[] (not the value) of:
//   kind == 0: the largest stored value <= val (predecessor),
//   kind == 1: the smallest stored value >= val (successor).
// With strict == true the comparisons become < / >, so val itself is excluded.
// This replaces editing the comparison signs by hand for that use case.
// Returns 0 when no such value exists; main() avoids that case by inserting
// +INF / -INF sentinels. Side effect: the node returned by find(val) is
// splayed to the root. kind is assumed to be 0 or 1.
int pre(int val,int kind,bool strict = false){
    splay(find(val) , 0); int node = root;
    const int v = t[node].val;
    if(kind == 0 && (strict ? v < val : v <= val)) return node;
    if(kind == 1 && (strict ? v > val : v >= val)) return node;
    // Otherwise the answer is the extreme node of the subtree on the `kind` side.
    node = t[node].ch[kind];
    if(!node) return 0;   // nothing on that side; never walk from the null node
    while(t[node].ch[kind^1])
        node = t[node].ch[kind^1];
    return node;
}

void delet(int val){
    int last = pre(val,0), next = pre(val,1);
    splay(last , 0); splay(next , last);
    if(t[t[next].ch[0]].cnt > 1){
        t[t[next].ch[0]].cnt--;
        splay(t[next].ch[0] , 0);
    }
    else t[next].ch[0] = 0;
}

// Returns the k-th smallest stored value (1-based, each copy counted
// separately), or INF when k is outside [1, total size]. Any sentinels stored
// in the tree count as ordinary elements. Does not splay.
int kth(int k){
    int node = root;
    // k <= 0 would otherwise keep descending left into the null node forever.
    if(k < 1 || t[node].size < k) return INF;
    while(true){
        int son = t[node].ch[0];
        if(k <= t[son].size) node = son;
        else if(k > t[son].size + t[node].cnt){
            k -= t[son].size + t[node].cnt;
            node = t[node].ch[1];
        }
        else return t[node].val;
    }
}

// Returns how many stored elements (with multiplicity, sentinels included) are
// strictly smaller than val. Because main() inserts a -INF sentinel, this equals
// the 1-based rank val has (or would have) among the real values.
// Side effect: the node returned by find(val) is splayed to the root.
int get_rank(int val){
    splay(find(val) , 0);
    int less = t[t[root].ch[0]].size;
    // If val is absent, find() may stop at its predecessor; that node's own
    // copies are smaller than val too and must be counted.
    if(t[root].val < val) less += t[root].cnt;
    return less;
}
int main(){
    int a;
    root = cnt = 0;
    int ans = 0;
    insert(INF), insert(-INF);
    scanf("%d%d", &n, &a);
    insert(a);
    ans += a;
    for(int i = 1; i <= n - 1; i++){
        scanf("%d", &a);
        int p1 = t[pre(a, 0)].val;
        int p2 = t[pre(a, 1)].val;
        ans += min(abs(p1 - a), abs(p2 - a));
        insert(a);
    }
    printf("%d\n", ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/KirinSB/p/9842821.html