树形DP-Apple Tree

版权声明:未经博主同意,不可转载 https://blog.csdn.net/pythonbanana/article/details/88187930

题目传送门
(详细思路)思路参见:https://www.cnblogs.com/fightfordream/p/6653890.htm
这道题想了好久才想明白。。。。。。做这种难度的题对自己还是很有收获的。
/*
个人理解 :首先 这道题要每个节点的权值是x的倍数。因为每个节点都要相等,所以有倍数关系,那么
当前节点u的权值必然是其子节点倍数的最小公倍数。比如其子节点为2n == 4m == 3k;而sum = 2n+3k+4m
所以每个子节点的权值是lcm(2,3,4)的最小公倍数的倍数。这道题用的就是cnt数组记录
还要记录每个节点当前的最大权值。所以用mx。
这道题递归求解,所以分析问题时,只要考虑当前节点即可(举一个3层的树分析),其他的递归求解。
*/

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<set>
#include<stack>
#include<queue>
#include<map>
#include<cstring>
#include<string>
#include<cmath>

using namespace std;

typedef long long LL;

#define INF 0x3f3f3f3f
#define PI acos(-1.0)
#define pii pair<int,int>
#define all(x) x.begin(),x.end()
#define mem(a,b) memset(a,b,sizeof(a))
#define per(i,a,b) for(int i = a;i <= b;++i)
#define rep(i,a,b) for(int i = a;i >= b;--i)
const int maxn = 1e5;
int n = 0,m = 0;
int w[maxn+10];
vector<int> vt[maxn+10];
LL mx[maxn+10];//mx[i]表示节点i权值的最大值 
LL cnt[maxn+10];//cnt[i]表示节点i的权值必须是cnt[i]的倍数 
/*
个人理解 :首先 这道题要每个节点的权值是x的倍数。因为每个节点都要相等,所以有倍数关系,那么
当前节点u的权值必然是其子节点倍数的最小公倍数。比如其子节点为2n == 4m == 3k;而sum = 2n+3k+4m
所以每个子节点的权值是lcm(2,3,4)的最小公倍数的倍数。这道题用的就是cnt数组记录 
还要记录每个节点当前的最大权值。所以用mx。
这道题递归求解,所以分析问题时,只要考虑当前节点即可(举一个3层的树分析),其他的递归求解。 
*/
LL gcd(LL a,LL b){
	return b == 0 ? a : (gcd(b,a%b));
}
LL lcm(LL a,LL b){
	return a/gcd(a,b) * b;
}
void dfs(int curr,int fa){
	int num = 0;
	int size = vt[curr].size();
	for(int i = 0;i <= size-1;++i){
		int v = vt[curr][i];
		if(v == fa){
			continue;
		}
		dfs(v,curr);
		if(num == 0){//刚开始赋值 
			mx[curr] = mx[v];
				cnt[curr] = cnt[v];	
		}else{
			if(cnt[curr] < 1e14){//有溢出的风险 
				cnt[curr] = lcm(cnt[curr],cnt[v]);
			}	
			mx[curr] = min(mx[curr],mx[v]) / cnt[curr] * cnt[curr];
		}
		++num;
	}
	if(num == 0){
		mx[curr] = w[curr];	cnt[curr] = 1;
	}else{
		mx[curr] *= num;	
		if(cnt[curr] < 1e14){
			cnt[curr] *= num;
		}
	}
}
int main(){
	while(~scanf("%d",&n)){
		LL ans = 0;
		per(i,1,n){
			scanf("%d",&w[i]);
			ans += w[i];
		} 
		per(i,1,n-1){
			int u = 0,v = 0;
			scanf("%d %d",&u,&v);
			vt[u].push_back(v);vt[v].push_back(u);
		}	
		dfs(1,-1);
		printf("%I64d\n",ans - mx[1]);
	}
	
	return 0;
}

猜你喜欢

转载自blog.csdn.net/pythonbanana/article/details/88187930