D - Sum of Maximum Weights
Time Limit: 2 sec / Memory Limit: 1024 MB
Score : 400 points
Problem Statement
We have a tree with N N N vertices numbered 1 , 2 , … , N 1,2,…,N 1,2,…,N.
The i i i-th edge ( 1 ≤ i ≤ N − 1 1≤i≤N−1 1≤i≤N−1) connects Vertex u i u_i ui and Vertex v i v_i vi and has a weight
w i w_i wi.
For different vertices u u u and v v v, let f ( u , v ) f(u,v) f(u,v) be the greatest weight of an edge contained in the shortest path from Vertex u u u to Vertex v v v.
Find
∑ i = 1 N − 1 ∑ j = i + 1 N f ( i , j ) \sum_{i=1}^{N-1} \sum_{j=i+1}^{N}f(i,j) i=1∑N−1j=i+1∑Nf(i,j)
Constraints
- 2 ≤ N ≤ 1 0 5 2≤N≤10^5 2≤N≤105
- $1≤u_i,v_i≤N $
- 1 ≤ w i ≤ 1 0 7 1≤w_i≤10^7 1≤wi≤107
- The given graph is a tree.
- All values in input are integers.
Input
Input is given from Standard Input in the following format:
N
u 1 v 1 w 1 u_1 \quad v_1\quad w_1 u1v1w1
⋮
u N − 1 v N − 1 w N − 1 u_N−1\quad v_N−1\quad w_N−1 uN−1vN−1wN−1
Output
Print the answer.
Sample Input 1
3
1 2 10
2 3 20
Sample Output 1
50
We have f ( 1 , 2 ) = 10 f(1,2)=10 f(1,2)=10, f ( 2 , 3 ) = 20 f(2,3)=20 f(2,3)=20, and f ( 1 , 3 ) = 20 f(1,3)=20 f(1,3)=20, so we should print their sum, or 50 50 50.
Sample Input 2
5
1 2 1
2 3 2
4 2 5
3 5 14
Sample Output 2
76
思路
每条边至少会被计算一次(即两个点之间直接连接来计算的)
所以我们得到结果需要计算对于每一条边,会有多少个点对代价是这条边,就是这条边左边连的集合中的点数乘以右边连的集合中的点数。之后可以删掉这个边,然后看其中剩下他们最大的边。继续这样分析?但是每次都需要维护出最大边两端连的点数,而且每个点连的点数在变化,直接爆搜一定会超时。
对于删边的操作我们可以变成加边来做。这样我们就能用并查集维护出一条边上两端点数了。删边是选最大边,那么加边就是选最小边开始加。
将每个边从小到大排列,然后用并查集一个个合并,对于一个集合里的所有的点到另一个集合的路径中最大的就是当前正要合并的边,因为该合并方法合并的情况下,每个集合中的边的权值都小于连接两个集合的边(即当前所要合并的边)。
代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10;
struct node{
int a, b, w;
bool operator <(const node& b)const{
return w < b.w;
}
}ed[N];
int f[N], sz[N];
int find(int x){
return f[x] = (f[x] == x ? x : find(f[x]));
}
int main(){
int n;
cin >> n;
for(int i = 1; i < n; i++){
cin >> ed[i].a >> ed[i].b >> ed[i].w;
}
for(int i = 1; i <= n; i++) f[i] = i, sz[i] = 1;
sort(ed + 1, ed + n);
ll res = 0;
for(int i = 1; i < n; i++){
int a = find(ed[i].a)
int b = find(ed[i].b);
res += 1ll * sz[a] * sz[b] * ed[i].w;
f[a] = b, sz[b] += sz[a];
}
cout << res << "\n";
return 0;
}