There are n points on a coordinate axis OX. The i-th point is located at the integer point x_i and has a speed v_i. It is guaranteed that no two points occupy the same coordinate. All n points move with constant speed; the coordinate of the i-th point at the moment t (t can be non-integer) is calculated as x_i + t⋅v_i.
Consider two points i and j. Let d(i,j) be the minimum possible distance between these two points over any possible moments of time (even non-integer). It means that if two points i and j coincide at some moment, the value d(i,j) will be 0.
Your task is to calculate the value ∑_{1≤i<j≤n} d(i,j) (the sum of minimum distances over all pairs of points).
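To make the definition of d(i,j) concrete, here is a minimal O(n^2) reference sketch. It assumes the following observation about a single pair: if the point that starts on the left is strictly faster than the one on the right, it eventually catches up and the minimum distance is 0; otherwise the gap never shrinks and the minimum is the initial gap at t = 0. The function name `bruteForceSum` is hypothetical and is not part of the original post. It is far too slow for n = 2⋅10^5 and is only meant for checking small cases.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the original post): O(n^2) reference answer.
long long bruteForceSum(const std::vector<long long>& x,
                        const std::vector<long long>& v) {
    long long total = 0;
    const std::size_t n = x.size();
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = i + 1; j < n; ++j) {
            // Identify which point of the pair starts on the left.
            const std::size_t l = (x[i] < x[j]) ? i : j;
            const std::size_t r = (l == i) ? j : i;
            // A strictly faster left point catches the right one (distance 0).
            // Otherwise the gap never shrinks, so the minimum is at t = 0.
            if (v[l] <= v[r]) total += x[r] - x[l];
        }
    }
    return total;
}
```

Calling `bruteForceSum({1, 3, 2}, {-100, 2, 3})` corresponds to the first example in the statement.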
Input
The first line of the input contains one integer n (2 ≤ n ≤ 2⋅10^5), the number of points.
The second line of the input contains n integers x_1, x_2, …, x_n (1 ≤ x_i ≤ 10^8), where x_i is the initial coordinate of the i-th point. It is guaranteed that all x_i are distinct.
The third line of the input contains n integers v_1, v_2, …, v_n (−10^8 ≤ v_i ≤ 10^8), where v_i is the speed of the i-th point.
Output
Print one integer: the value ∑_{1≤i<j≤n} d(i,j) (the sum of minimum distances over all pairs of points).
Examples
input
3
1 3 2
-100 2 3
output
3
input
5
2 1 4 3 5
2 2 2 3 4
output
19
input
2
2 1
-3 0
output
0
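As a quick sanity check, the first example can be worked by hand. This assumes the pairwise rule that a pair contributes its initial gap unless the point starting on the left is strictly faster, in which case the two points meet and the pair contributes 0. Points are numbered in input order.

```latex
\begin{aligned}
d(1,2) &= |1 - 3| = 2 && (v_1 = -100 \le v_2 = 2) \\
d(1,3) &= |1 - 2| = 1 && (v_1 = -100 \le v_3 = 3) \\
d(2,3) &= 0 && (\text{point 3 starts at } 2 < 3 \text{ and } v_3 = 3 > v_2 = 2) \\
\sum_{1 \le i < j \le 3} d(i,j) &= 2 + 1 + 0 = 3
\end{aligned}
```

This matches the expected output of 3.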
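This is a textbook Fenwick tree (binary indexed tree) problem. Sort the points by coordinate. For a pair with x_i < x_j the distance can only shrink when v_i > v_j, so d(i,j) = x_j − x_i if v_i ≤ v_j and 0 otherwise. Processing the points in increasing order of x, a Fenwick tree indexed by compressed speed returns, for each point, the count and the coordinate sum of the earlier points whose speed is no greater than its own.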
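#include <bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pli pair<long long,int>
#define fi first
#define se second
#define lowbit(x) ((x)&(-(x)))
using namespace std;
const int N = 2e5 + 10;
pii a[N];                      // a[i] = (x, v), sorted by x
pli c[N];                      // Fenwick tree over speed ranks: (sum of x, point count)
int b[N], tot, n;              // b: sorted distinct speeds, tot: number of distinct speeds
unordered_map<int, int> num;   // speed -> rank in 1..tot
ll ans;
// Sum of x and count over all inserted points whose speed rank is <= x.
pli ask(int x) {
    pli res = {0, 0};
    for (; x; x -= lowbit(x))
        res.fi += c[x].fi, res.se += c[x].se;
    return res;
}
// Insert a point with speed rank x and coordinate y.
void add(int x, int y) {
    for (; x <= n; x += lowbit(x))
        c[x].fi += y, c[x].se++;
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i].fi);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i].se), b[i] = a[i].se;
    sort(a + 1, a + n + 1);    // order the points by coordinate
    // Coordinate-compress the speeds to ranks 1..tot.
    sort(b + 1, b + 1 + n);
    tot = unique(b + 1, b + 1 + n) - (b + 1);
    for (int i = 1; i <= tot; i++) num[b[i]] = i;
    add(num[a[1].se], a[1].fi);
    for (int i = 2; i <= n; i++) {
        // Earlier points (smaller x) with speed <= v_i never get closer to point i,
        // so each of them contributes x_i - x_j.
        pli res = ask(num[a[i].se]);
        ans += 1ll * a[i].fi * res.se - res.fi;
        add(num[a[i].se], a[i].fi);
    }
    printf("%lld\n", ans);
    return 0;
}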
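For testing against the samples without going through standard input, the same idea can be wrapped in a function. The sketch below is a hypothetical variant and not the author's code: the name `sumOfMinDistances` and the two separate Fenwick arrays `sumX` and `cnt` are illustrative choices, and it looks up speed ranks with `std::lower_bound` on the sorted distinct speeds instead of an `unordered_map`, which avoids depending on hash behavior.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical function-style variant (names are illustrative, not from the post).
long long sumOfMinDistances(const std::vector<int>& xs, const std::vector<int>& vs) {
    const int n = static_cast<int>(xs.size());
    std::vector<std::pair<int, int>> pts(n);  // (x, v)
    for (int i = 0; i < n; ++i) pts[i] = {xs[i], vs[i]};
    std::sort(pts.begin(), pts.end());        // increasing x

    // Coordinate-compress the speeds: sorted list of distinct values.
    std::vector<int> speeds(vs);
    std::sort(speeds.begin(), speeds.end());
    speeds.erase(std::unique(speeds.begin(), speeds.end()), speeds.end());
    const int m = static_cast<int>(speeds.size());

    // Two 1-indexed Fenwick trees over speed ranks: sum of x and point count.
    std::vector<long long> sumX(m + 1, 0), cnt(m + 1, 0);

    long long ans = 0;
    for (const auto& p : pts) {
        const int x = p.first;
        const int rank = static_cast<int>(
            std::lower_bound(speeds.begin(), speeds.end(), p.second) - speeds.begin()) + 1;
        // Query: earlier points (smaller x) with speed <= current speed.
        long long s = 0, c = 0;
        for (int k = rank; k > 0; k -= k & -k) { s += sumX[k]; c += cnt[k]; }
        ans += static_cast<long long>(x) * c - s;
        // Insert the current point.
        for (int k = rank; k <= m; k += k & -k) { sumX[k] += x; cnt[k] += 1; }
    }
    return ans;
}
```

Both the sort and the per-point Fenwick operations are O(n log n) overall, so this fits the stated limit of n ≤ 2⋅10^5.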