[2020牛客算法竞赛入门课第九节习题] 珂朵莉的数列 离散化+树状数组

题目链接:珂朵莉的数列

题意

n × ( n + 1 ) 2 {\frac{n\times(n+1)}2} 个子区间,求出她们各自的逆序对个数,然后加起来输出。

题解

首先我们来看如何求一段序列中所有的逆序对数,求逆序对,我们可以转化为前i个数里,比a[i]大的有多少个。我们可以像桶排序一样,维护一个长度为 m a x ( a i ) {max(a_i)} 的树状数组,先统计树状数组内当前有多少个数比a[i]大,然后将a[i]更新即可。

知道了如何求逆序对数,这道题我们已经解决了一半,但另一半的难点在于如何求出 n × ( n + 1 ) 2 {\frac{n\times(n+1)}2} 个区间内的逆序对,很明显有些逆序对会被重复计算,所以我们只要设计一个算法,既能统计逆序对个数又能统计逆序对在区间内出现的次数,此题就迎刃而解了。

我们可以通过排列组合发现一个很简单的规律:
假如(i , j)为逆序对时,(i,j)被计算的次数为i*(n-j+1),我们来看这个逆序对的左右两个区间分别为(1,i)和(j,n),易知这两个左右区间的长度为i和(n-j+1),排列组合可知(i,j)这个逆序对被计算的次数为i*(n-j+1)。
那么我们来看 ( i 1 , j ) , ( i 2 , j ) . . . . . ( i x , j ) {(i_1,j),(i_2,j).....(i_x,j)} 为逆序对时,对答案的贡献为 i 1 ( n j + 1 ) + i 2 ( n j + 1 ) + . . . + i x ( n j + 1 ) = s u m ( i ) ( n j + 1 ) {i_1*(n-j+1)+i_2*(n-j+1)+...+i_x*(n-j+1)=sum(i)*(n-j+1)} ,sum(i)为大于j的数的下标之和。

所以我们维护一个下标为a[i]、存的值为i的树状数组,我们可以通过计算sum(n)-sum(a[i])来得出大于i的下标之和,那么对答案的贡献就为 ( s u m ( n ) s u m ( a [ i ] ) ) ( n i + 1 ) {(sum(n)-sum(a[i]))*(n-i+1)} ,然后将a[i]更新即可。

注意!!! 本题a[i]≤1e9,所以无法直接维护下标为a[i]的树状数组,需要进行离散化。还有答案可能会爆long long,需要__int128来存储。
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<bitset>
#include<cassert>
#include<cctype>
#include<cmath>
#include<cstdlib>
#include<ctime>
#include<deque>
#include<iomanip>
#include<list>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
using namespace std;
//extern "C"{void *__dso_handle=0;}
typedef long long ll;
typedef long double ld;
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define pll pair<ll,ll>
#define lowbit(x) x&-x

const double PI=acos(-1.0);
const double eps=1e-6;
const ll mod=1e9+7;
const int inf=0x3f3f3f3f;
const int maxn=1e6+10;
const int maxm=100+10;
#define ios ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);


inline void print(__int128 x)
{
	if(x<0)
	{
		putchar('-');
		x=-x;
	}
	if(x>9) print(x/10);
	putchar(x%10+'0');
}

ll n,c[maxn];
ll a[maxn],b[maxn];
void add(int x,int num)
{
	if(x==0) return ;
	while(x<=n)
	{
		c[x]+=num;
		x+=lowbit(x);
	}
}

ll sum(int x)
{
	ll ans=0;
	while(x>0)
	{
		ans+=c[x];
		x-=lowbit(x);
	}
	return ans;
}

int main()
{
	scanf("%lld",&n);
	__int128 ans=0;
	for(int i=1;i<=n;i++)
	{ 
		scanf("%lld",&a[i]);
		b[i]=a[i];
	}
	sort(b+1, b+1+n);
	for(int i=1;i<=n;i++)
		a[i]=lower_bound(b+1, b+1+n, a[i])-b;
	for(int i=1;i<=n;i++)
	{
		ans+=(sum(n)-sum(a[i]))*(n-i+1);
		add(a[i],i);
	}
	print(ans);
}

猜你喜欢

转载自blog.csdn.net/weixin_44235989/article/details/107859530