版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/chr1991/article/details/80945481
原题链接:
Kingdom Division
- 由于树的层次可能很深,所以这里不能使用递归版的DFS。我使用了BFS。
- BFS确定各结点的父结点和它的孩子数。
- 用逆拓扑排序确定结点的计算顺序。
- same[u][0] 表示u结点颜色为0孩子结点颜色全为1时组合数。 diff[u][0] 表示u结点颜色为0时可行组合数。本结点颜色为0,子结点颜色为1,孙结点颜色全为0是无效组合。反之亦然。由于这里颜色0、1相互对称, same[u][0]=same[u][1]; diff[u][0]=diff[u][1]; 。
- 为排除无效组合,计算 diff[u][0] 时没有乘以 same[u][1] 。
- 本结点为0,孩子结点全为1,是无效组合,因此 diff[u][0] 最后还要减去 same[u][0]
#include <cmath>
#include <cstdio>
#include <cassert>
#include <vector>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
const int MOD = 1e9+7;
int main() {
/* Enter your code here. Read input from STDIN. Print output to STDOUT */
int n = 0; assert(1 == scanf("%d", &n));
vector<vector<int>> adjList(n + 1);
for(int i = 0; i < n - 1; i++){
int u, v; assert(2 == scanf("%d %d", &u, &v));
adjList[u].push_back(v);
adjList[v].push_back(u);
}
vector<vector<long long>> diff(n + 1, vector<long long>(2, 1));
vector<vector<long long>> same(n + 1, vector<long long>(2, 1));
vector<int> parent(n + 1, 0);
vector<int> child_cnt(n + 1, 0);
queue<int> traversal_array;
vector<bool> visited(n + 1, false);
traversal_array.push(1);
visited[1] = true;
while(!traversal_array.empty()){
int tmp = traversal_array.front();
traversal_array.pop();
for(auto vit: adjList[tmp]){
if(!visited[vit]){
visited[vit] = true;
child_cnt[tmp]++;
parent[vit] = tmp;
traversal_array.push(vit);
}
}
}
for(int c_i = 1; c_i <= n; c_i++){
if(!child_cnt[c_i]){
traversal_array.push(c_i);
}
}
while(!traversal_array.empty()){
int tmp = traversal_array.front();
traversal_array.pop();
child_cnt[parent[tmp]]--;
if(child_cnt[parent[tmp]] == 0){
traversal_array.push(parent[tmp]);
}
for(auto vit: adjList[tmp]){
if(vit != parent[tmp]){
same[tmp][1] = same[tmp][1] * diff[vit][0] % MOD;
same[tmp][0] = same[tmp][0] * diff[vit][1] % MOD;
}
}
for(auto vit: adjList[tmp]){
if(vit != parent[tmp]){
diff[tmp][1] = diff[tmp][1] * (diff[vit][0] + diff[vit][1] + same[vit][1]) % MOD;
diff[tmp][0] = diff[tmp][0] * (diff[vit][0] + diff[vit][1] + same[vit][0]) % MOD;
}
}
diff[tmp][0] = (diff[tmp][0] - same[tmp][0] + MOD) % MOD;
diff[tmp][1] = (diff[tmp][1] - same[tmp][1] + MOD) % MOD;
}
printf("%lld\n", (diff[1][0] + diff[1][1])%MOD);
return 0;
}