详细讲解Codeforces Round #624 (Div. 3) F. Moving Points

 题意:给定n个点的初始坐标x和速度v(保证n个点的初始坐标互不相同), d(i,j)是第i个和第j个点之间任意某个时刻的最小距离,求出n个点中任意一对点的d(i,j)的总和。

题解:可以理解,两个点中初始坐标较小的点的速度更大时,总有一个时刻后面的点会追上前面的点,d(i,j) =0。

   否则,即后面的点的速度 <= 前面的点的速度时,两点之间的距离只会越来越大,d(i,j) = abs(xi - xj) (初始距离)。

可以用直线来辅助理解:x = xi + v*t,横轴为t,纵轴为x,若两直线交点t值大于等于0,则d(i,j) = 0。否则交点t值为负 或者两直线平行时,d(i,j)=初始距离。

     所以,立即会想到对点按初始坐标排序,遍历每个点,计算出前面点中 速度小于等于 当前点的 所有点与当前点的初始距离总和。n可达2*10^5,需要找O(n log(n))的算法。

  若当前点下标为i,前面所有速度不大于当前点的点下标为j1,j2,...,相当于求(x[i]-x[j1])+(x[i]-x[j2])+(x[i]-x[j3])... = num * x[i] - sum(x[j])。即需要使用一个数据结构来维护前面速度较小点的数量 和 初始距离x的总和。

  最佳选择就是树状数组,按初始坐标递增的顺序依次添加点的信息,一个树状数组记录小于等于当前速度的点的个数,另一个记录这些点的初始距离总和。

  由于速度范围比较大,需要进行离散化处理,即把n个速度离散成n_个下标,树状数组正好对应这个下标。

详细见代码和注释如下:

 1 #include<cstdio>
 2 #include<utility>      //pair
 3 #include<algorithm>    //sort
 4 #include<vector>      //lower_bound, unique
 5 using namespace std;
 6 
 7 const int maxn = 2e5 + 2;
 8 pair<int, int>a[maxn];    //存所有点的(初始坐标,速度)
 9 int v[maxn], n;            //所有点的速度,点的个数
10 
11 long long s1[maxn], s2[maxn];    //两个树状数组
12 void add(int i, int x) {
13     while (i <= n) {
14         s1[i]++;        //s1存个数,每次增加1
15         s2[i] += x;        //s2存初始坐标x的总和,每次增加x
16         i += i & (-i);
17     }
18 }
19 
20 long long getSum(long long s[], int i) {
21     long long res(0);
22     while (i > 0) {
23         res += s[i];
24         i -= i & (-i);
25     }
26     return res;
27 }
28 
29 int main() {
30     scanf("%d", &n);
31     for (int i = 1; i <= n; i++)scanf("%d", &a[i].first);
32     for (int i = 1; i <= n; i++) {
33         scanf("%d", v + i);
34         a[i].second = v[i];
35     }
36     sort(a + 1, a + n + 1);
37     sort(v + 1, v + n + 1);
38     int *vend = unique(v + 1, v + n + 1);    //速度离散化为 v[1]到v[vend-1 - v]
39     long long ans = 0LL;
40     for (int i = 1; i <= n; i++) {    //按x递增顺序向树状数组中添加点的信息,初始两个树状数组都为空
41         //得到x第i小的点 的v值 在v数组中对应的下标pos
42         int pos = lower_bound(v + 1, vend, a[i].second) - v;
43         //得到速度小于等于a[i].second的点的个数和x总和
44         long long sum1 = getSum(s1, pos), sum2 = getSum(s2, pos);
45         ans += sum1 * a[i].first - sum2;    //num * x[i] - num个x[j]的和。(对所有的x[j]<x[i])
46         add(pos, a[i].first);
47     }
48     printf("%lld", ans);
49     return 0;
50 }   

猜你喜欢

转载自www.cnblogs.com/zsh-notes/p/12374411.html