题意:题意简单的说就是给一个数列,对于每一个数,用在它前面出现的数与它做差,找到对它来说最小的差,答案就是把每一个数的最小差相加。第一个数的最小差是它本身。
分析:这几天在练Splay(感觉自己太菜了),用splay来解决这个问题,维护一颗平衡树,然后依次插入,每次插入找到它的前驱和后继,得到每个数的贡献,最后相加算出答案。最最最简单的Splay应用,然而我还是调了很久的bug,不过没有照着板子抄还是有进步。这几天在网上看大佬写主席树,树链剖分,感觉自己真的太菜了。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define lch tr[x].ch[0]
#define rch tr[x].ch[1]
using namespace std;
const int N = 40000+5;
int n,rt, tot;
int a[N];
struct node
{
int v,ch[2],fa,num;
void set(int vv, int f)
{
v = vv, fa = f;
ch[0] = ch[1] = 0;
num = 1;
}
}tr[N];
void print()
{
for(int x = 1; x <= n; x++)
{
cout <<tr[x].v <<" ss " <<tr[lch].v <<" "<<tr[rch].v << endl;
}
cout <<" -----------------------------" << endl;
}
void rotate(int x, int d)
{
int y = tr[x].fa;
tr[y].ch[!d] = tr[x].ch[d];
tr[tr[x].ch[d]].fa = y;
int z = tr[y].fa;
if(z)
{
tr[z].ch[tr[z].ch[1] == y] =x;
}
tr[x].fa = z;
tr[x].ch[d] = y;
tr[y].fa = x;
}
void splay(int x, int to)
{
while(tr[x].fa != to)
{
int y = tr[x].fa;
if(tr[y].fa == to)
{
rotate(x,tr[y].ch[0] == x);
}
else
{
int z = tr[y].fa;
int d = (tr[z].ch[0] == y);
if(tr[y].ch[d] == x)
{
rotate(x,!d),rotate(x,d);
}
else
{
rotate(y,d); rotate(x,d);
}
}
}
if(to == 0) rt = x;
}
void insert(int v)
{
if(rt == 0)
{
tr[rt = tot ++].set(v,0);
}
else
{
int x = rt, y = 0;
while(x)
{
if(tr[x].v == v)
{
tr[x].num ++;
break;
}
else if(v < tr[x].v) y = x, x = lch;
else y = x, x = rch;
}
if(x == 0)
{
tr[y].ch[v < tr[y].v ? 0 : 1] = tot;
x = tot;
tr[tot++].set(v,y);
}
splay(x,0);
}
}
int suc()
{
int x = rt;
if(rch)
{
x = rch;
while(lch) x = lch;
return tr[x].v;
}
else return -1;
}
int pre()
{
int x = rt;
if(lch)
{
x = lch;
while(rch) x = rch;
return tr[x].v;
}
else return -1;
}
int get()
{
int x = rt;
int ret = tr[x].v;
if(tr[x].num > 1) return 0;
int pr = pre(), su = suc();
if(pr != -1 && su != -1)
{
int a = pr - tr[x].v; a = a > 0 ? a : -a;
int b = su - tr[x].v; b= b > 0 ? b : -b;
ret = min(a,b);
}
else if(pr != -1)
{
int a = pr - tr[x].v; a = a > 0 ? a : -a;
ret = a;
}
else if(su != -1)
{
int b = su - tr[x].v; b= b > 0 ? b : -b;
ret = b;
}
return ret;
}
int main()
{
while(~scanf("%d",&n))
{
rt = 0; tot = 1;
long long ans = 0;
for(int i = 0; i < n; i++)
{
scanf("%d",&a[i]);
insert(a[i]);
ans += get();
}
printf("%lld\n",ans);
}
}