There are n points on a coordinate axis OX. The i-th point is located at the integer point xi and has a speed vi. It is guaranteed that no two points occupy the same coordinate. All n points move with the constant speed, the coordinate of the i-th point at the moment t (t can be non-integer) is calculated as xi+t⋅vi.
Consider two points i and j. Let d(i,j) be the minimum possible distance between these two points over any possible moments of time (even non-integer). It means that if two points i and j coincide at some moment, the value d(i,j) will be 0.
Your task is to calculate the value ∑1≤i<j≤n d(i,j) (the sum of minimum distances over all pairs of points).
Input
The first line of the input contains one integer n (2≤n≤2⋅105) — the number of points.
The second line of the input contains n integers x1,x2,…,xn (1≤xi≤108), where xi is the initial coordinate of the i-th point. It is guaranteed that all xi are distinct.
The third line of the input contains n integers v1,v2,…,vn (−108≤vi≤108), where vi is the speed of the i-th point.
Output
Print one integer — the value ∑1≤i<j≤n d(i,j) (the sum of minimum distances over all pairs of points).
Examples
inputCopy
3
1 3 2
-100 2 3
outputCopy
3
inputCopy
5
2 1 4 3 5
2 2 2 3 4
outputCopy
19
inputCopy
2
2 1
-3 0
outputCopy
0
题意:
在一个坐标轴上,给你n个点,每个点都有两个属性一个是xi(代表位置),一个是vi代表每秒速度。 每个点移动 xi+t*vi,d(i,j)表示i到j的最短距离。’
问这个公式的最小值是多少。
解析:
xi<xj 且 vi>vj 那么d(i,j)一定等于0,因为在某一时刻可以相遇 这样的情况对答案的贡献为0
xi<xj 且 vi<=vj 那么d(i,j)=abs(xi-xj)。因为不管什么时候i永远无法遇到j 这样的情况对答案的贡献为 abs(xi-xj);
那么现在我假设 位于xi左边的有 x0,x1,x2,x3,x4…xn 且速度都小于 vi
对答案的贡献为 (xi-x0)+(xi-x1)+(xi-x2)+(xi-x3)+(xi-x4)+…+(xi-xn)
整合一下就是 nxi-(x0+x1+x2+x3+x4+…+xn)
这具备了前缀和。所以我们用树状数组维护。
c[0][x]维护x左边数出现的个数
c[1][x]维护x左边数的总和
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e5+100;
ll c[2][N];
ll b[N];
int n;
struct node
{
ll x,v;
}a[N];
bool cmp(const node &a,const node &b)
{
return a.x<b.x;
}
int lowbit(int x)
{
return x&(-x);
}
ll sum(int x,int k)
{
ll res=0;
while(k)
{
res=res+c[x][k];
k-=lowbit(k);
}
return res;
}
ll add(int x,int val)
{
while(x<=N)
{
c[0][x]++;
c[1][x]+=val;
x+=lowbit(x);
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%lld",&a[i].x);
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i].v);
b[i]=a[i].v;
}
sort(a+1,a+1+n,cmp);
sort(b+1,b+1+n);
ll m=unique(b+1,b+1+n)-b-1;
ll ans=0;
// for(int i=1;i<=n;i++) cout<<b[i]<<endl;
for(int i=1;i<=n;i++)
{
int x=lower_bound(b+1,b+1+m,a[i].v)-b;
// cout<<x<<endl;
ans=ans+(a[i].x*sum(0,x))-sum(1,x);
add(x,a[i].x);
}
cout<<ans<<endl;
}