【题解】CF915F:木の不均衡値

原题传送门
∑ i = 1 n ∑ j = 1 n I ( i , j ) \sum_{i=1}^{n}\sum_{j=1}^{n}I(i,j) Σi = 1n個Σj = 1n個i j
= ∑ i = 1 n ∑ j = 1 n M ax(i、j)− ∑ i = 1 n ∑ j = 1 n M in(i、j)= \ sum_ {i = 1} ^ {n} \ sum_ {j = 1} ^ {n} Max(i、j)-\ sum_ {i = 1} ^ {n} \ sum_ {j = 1} ^ {n} Min(i、j)=Σi = 1n個Σj = 1n個M a x i j Σi = 1n個Σj = 1n個M i n i j
2つの部分の値を個別に計算し、ユニオンチェックセットを考慮して
マージします。エッジの場合、Max(i、j)Max(i、j)を計算する場合、このエッジには2つのポイントに接続された2つのウェイトがあります。M a x i j),那么这条边的权值是两个点中大的那个权值,按照边的权值从小到大排序,然后并查集合并,同时记录两个连通块分别各有多少点
对于求 M a x ( i , j ) Max(i,j) Max(i,j)的情况,因为从小到大枚举,保证这条即将连上去的边是当前最大的,所以当前所有有关的路径上的最大值都是自身,所以直接用乘法原理计算就好了

M i n ( i , j ) Min(i,j) Min(i,j)同理,边的权值是两个点中权值较小的那个,然后从大到小排序,差不多的再做一遍

Code:

#include <bits/stdc++.h>
#define maxn 1000010
#define LL long long
using namespace std;
struct Line{
    
    
	int x, y;
	LL l1, l2;
}line[maxn];
LL num[maxn];
int f[maxn], n, a[maxn];

inline int read(){
    
    
	int s = 0, w = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
	for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
	return s * w;
}

bool cmp1(Line x, Line y){
    
     return x.l2 < y.l2; }
bool cmp2(Line x, Line y){
    
     return x.l1 > y.l1; }
int getfa(int k){
    
     return f[k] == k ? k : f[k] = getfa(f[k]); }

int main(){
    
    
	n = read();
	for (int i = 1; i <= n; ++i) a[i] = read();
	for (int i = 1; i < n; ++i){
    
    
		int x = read(), y = read();
		line[i].x = x, line[i].y = y;
		line[i].l1 = min(a[x], a[y]), line[i].l2 = max(a[x], a[y]);
	}
	LL ans = 0;
	sort(line + 1, line + n, cmp1);
	for (int i = 1; i <= n; ++i) f[i] = i, num[i] = 1;
	for (int i = 1; i < n; ++i){
    
    
		int x = line[i].x, y = line[i].y, s1 = getfa(x), s2 = getfa(y);
	//	printf("%d %d %lld\n", x, y, line[i].l2);
		if (s1 == s2) continue;
		ans += num[s1] * num[s2] * line[i].l2;
		f[s2] = s1, num[s1] += num[s2];
	}
	sort(line + 1, line + n, cmp2);
	for (int i = 1; i <= n; ++i) f[i] = i, num[i] = 1;
	for (int i = 1; i < n; ++i){
    
    
		int x = line[i].x, y = line[i].y, s1 = getfa(x), s2 = getfa(y);
	//	printf("%d %d %lld\n", x, y, line[i].l1);
		if (s1 == s2) continue;
		ans -= num[s1] * num[s2] * line[i].l1;
		f[s2] = s1, num[s1] += num[s2];
	}
	printf("%lld\n", ans);
	return 0;
}

おすすめ

転載: blog.csdn.net/ModestCoder_/article/details/108523040