HDU 6795 Little W and Contest (并查集)

题意:给定n个点,权值只有1和2,互不相连。从中选择3个点,满足3个点的权值之和不少于5,且3个点之间互不相连,计算出不同的选择方案的数量。接着加入n−1条边,每次将两个连通块相连,然后输出当前状态下,不同方案的数量。

题解:并查集
每次减去合并的两个连通块的贡献即可。

下午训练的时候因为在除之前取余wa了好多发…

#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<fstream>
#include<set>
#include<map>
#include<sstream>
#include<iomanip>
#define ll long long
using namespace std;
const int maxn = 1e5 + 5;
const int mod = 1e9 + 7;
int Parent[maxn], Rank[maxn], cnt1[maxn], cnt2[maxn];
int Find(int x) {
    
    
	return Parent[x] == x ? x : Parent[x] = Find(Parent[x]);
}
int Union(int x, int y) {
    
    
	int u, v, root;
	u = Find(x);
	v = Find(y);
	if (u == v) return 0;
	if (Rank[u] <= Rank[v]) {
    
    
		cnt1[v] += cnt1[u];
		cnt2[v] += cnt2[u];
		root = Parent[u] = v;
		if (Rank[u] == Rank[v]) Rank[v]++;
	}
	else {
    
    
		cnt1[u] += cnt1[v];
		cnt2[u] += cnt2[v];
		root = Parent[v] = u;
	}
	return root;
}
int t, n, a[maxn], u, v;
int main() {
    
    
	scanf("%d", &t);
	while (t--) {
    
    
		memset(cnt1, 0, sizeof(cnt1));
		memset(cnt2, 0, sizeof(cnt2));
		memset(Rank, 0, sizeof(Rank));
		scanf("%d", &n);
		int num1 = 0, num2 = 0;
		for (int i = 1; i <= n; i++) {
    
    
			scanf("%d", &a[i]);
			if (a[i] == 1) cnt1[i] = 1, ++num1;
			else cnt2[i] = 1, ++num2;
			Parent[i] = i;
		}
		int ini = (1ll * num2 * (num2 - 1) / 2 % mod * num1 % mod + 
			1ll * num2 * (num2 - 1) * (num2 - 2) / 6 % mod) % mod;
		printf("%d\n", ini);
		for (int i = 1; i < n; i++) {
    
    
			scanf("%d%d", &u, &v);
			int fau = Find(u), fav = Find(v);
			int yu1 = cnt1[fau];
			int yu2 = cnt2[fau];
			int yv1 = cnt1[fav];
			int yv2 = cnt2[fav];
			int rt = Union(u, v);
			if (i >= n - 2) {
    
    
				puts("0");
				continue;
			}
			int n1 = cnt1[rt];
			int n2 = cnt2[rt];
			ini -= 1ll * (n - n1 - n2) * yu2 % mod * yv2 % mod; if (ini < 0) ini += mod;
			ini -= 1ll * yu1 * yv2 % mod * (num2 - n2) % mod; if (ini < 0) ini += mod;
			ini -= 1ll * yv1 * yu2 % mod * (num2 - n2) % mod; if (ini < 0) ini += mod;
			printf("%d\n", ini);
		}
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_43680965/article/details/107647478