hdu1055 Color a tree(贪心☆☆☆☆☆)

思路来源

来源①:https://blog.csdn.net/gatieme/article/details/49202739

来源②:https://www.cnblogs.com/dramstadt/p/3201984.html

题意

给你一棵树,一上来可以染根节点。

对于其他的点i,染i时必须先染i的父节点。

每个点i对应一个权值c[i],

从t=0开始染色,染i的花费=c[i]*当前时间t

现给定父子关系和根节点编号,

不妨设点i在s[i]秒被染色,染色时间为1(这里记为t[i])

即求min(\sum c[i]*s[i])

题解

显然,我们染了一个点之后,

如果它的后继点是权值最大的点,

我们立刻染权值最大的点,是最明智的选择。

这就意味着,不管何时染权值最大的点的前驱,

染完之后下一秒都该染权值最大的点,

即在染色序列中,权值最大的点是和它的前驱挨在一起的。

我们把这两个点绑定在一起,合二为一。

记点i在s[i]秒被染色,染色时间为t[i],权值为c[i],

考虑三个点的情形x y z

其中x是y的前驱,z可以直接染。

这样有两种染色方式,

c[x]+2c[y]+3c[z]①

c[z]+2c[x]+3c[y]②

在任意一种染色方式中,y都在x后染,

设第s[x]秒染x,cost为s[x]·c[x]

则(s[x]+1)秒染y,cost为(s[x]+1)·c[y],在这里t[x]=1,

如果我们将x和y绑定在一起的话直接求s[x]*(c[x]+c[y]),会少算1·c[y]

因此,我们每一次合并,就把染父亲节点所需时间*子节点的权值(这里是t[x]*c[y])加到sum里。

此外,若先不考虑随时间增长的花费,

它们都至少有c[x]+c[y]+c[z]的基础花费,我们在最初的时候把这些也加到sum里。

而选择任意一种方式,都会有y在x后染而带来的附加花费c[y]

这样最基础的花费就是c[x]+2c[y]+c[z]

比较基础花费和①、②的区别,发现一个多2*c[z],另一个多(c[x]+c[y])

假设这里c[x]+c[y]<2*c[z],

那么我们就应该选策略②,让z先染,

其实质是c[z]>(c[x]+c[y])/2,即单染z的时间比染x、y的平均时间长,

因此,我们优先染那些平均单点时间长的点。

事实上,由于c[z]大,先将z向根合并,再将(xy)合点向根合并,就达到了先染z的目的。

而实际由于先合并z的时候,根节点里只有一个点,

再合并xy的时候,根节点里有两个点,

所以对答案的贡献,大的权值*1+小的权值*2,一定比反过来更优。

这实际上,就是第一秒染z,第二秒染xy合点(由于sum里加过c[y]其实是第二秒染x第三秒染y)的等价意义。

所以,开一个结构体,代表节点/合并后的节点

记录一个c,是合并点的总权值\sum c[i]

记录一个t,是合并点的染点总时间\sum t[i]

每次遍历选择,v=c/t最大的,即平均时间最大的点开始染。

怎么叫染了这个点呢?把它和它的父节点合并。

代表染完它的祖先节点之后,立刻染这个点。

代码

#include <iostream>
#include <algorithm> 
#include <cstring>
#include <cstdio>
#include <cmath>
#include <set>
#include <map>
#include <vector>
#include <stack>
#include <queue>
#include <functional>
const double INF=0x3f3f3f3f;
const int maxn=1e5+10; 
const int mod=1e9+7;
const int MOD=998244353;
const double eps=1e-7;
typedef long long ll;
#define vi vector<int> 
#define si set<int>
#define pii pair<double,int> 
#define pi acos(-1.0)
#define pb push_back
#define mp make_pair
#define lowbit(x) (x&(-x))
#define sci(x) scanf("%d",&(x))
#define scll(x) scanf("%lld",&(x))
#define sclf(x) scanf("%lf",&(x))
#define pri(x) printf("%d",(x))
#define rep(i,j,k) for(int i=j;i<=k;++i)
#define per(i,j,k) for(int i=j;i>=k;--i)
#define mem(a,b) memset(a,b,sizeof(a)) 
using namespace std;
int n,r,sum,a,b;
struct node
{
	double v;
	int c;//c总 
	int t;//t总 
	int par;
};
node ans[1005]; 
int main()
{ 
    while(~scanf("%d%d",&n,&r)&&n+r)
    {
    	mem(ans,0);
    	sum=0;
    	rep(i,1,n)
    	{
    		scanf("%d",&ans[i].c);
    		ans[i].t=1;
    		ans[i].v=ans[i].c;
    		sum+=ans[i].c;
    	}
    	rep(i,1,n-1)
    	{
    		scanf("%d%d",&a,&b);
    		ans[b].par=a;
    	}
    	rep(i,1,n-1)//需要合并n-1次 
    	{
    		int pos=0;
    		double maxv=-1;
    		rep(j,1,n)
    		{
    			if(j==r)continue;
    			if(ans[j].v>maxv)
    			{
    				pos=j;
    				maxv=ans[j].v;
    			}
    		}
    		ans[pos].v=0;//不影响后续操作  
    		int u=ans[pos].par;
    		sum+=ans[pos].c*ans[u].t;
    		ans[u].c+=ans[pos].c;
    		ans[u].t+=ans[pos].t;
    		ans[u].v=ans[u].c*1.0/ans[u].t;
			rep(j,1,n)
    		{
    			if(ans[j].par==pos)ans[j].par=u;
    		}
    	}
    	printf("%d\n",sum);
    }
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Code92007/article/details/82936663