ACWING115. 给树染色(贪心)

一颗树有 n 个节点,这些节点被标号为:1,2,3…n,每个节点 i 都有一个权值 A[i]。

现在要把这棵树的节点全部染色,染色的规则是:

根节点R可以随时被染色;对于其他节点,在被染色之前它的父亲节点必须已经染上了色。

每次染色的代价为T*A[i],其中T代表当前是第几次染色。

求把这棵树染色的最小总代价。

输入格式
第一行包含两个整数 n 和 R ,分别代表树的节点数以及根节点的序号。

第二行包含 n 个整数,代表所有节点的权值,第 i 个数即为第 i 个节点的权值 A[i]。

接下来n-1行,每行包含两个整数 a 和 b ,代表两个节点的序号,两节点满足关系: a 节点是 b 节点的父节点。

除根节点外的其他 n-1 个节点的父节点和它们本身会在这 n-1 行中表示出来。

同一行内的数用空格隔开。

输出格式
输出一个整数,代表把这棵树染色的最小总代价。

数据范围
1≤n≤1000,
1≤A[i]≤1000
输入样例:
5 1
1 2 1 2 4
1 2
1 3
2 4
3 5
输出样例:
33

思路:
可以确定的思路是除根外全局最大的点如果能取是一定要取的。
对于两条能取的数链a1,a2,a3,a4… an 和 b1,b2,b3,b4…bm.
先取链1再取链2的花费是:
( k + 1 ) a 1 + . . . + ( k + n ) a n + ( k + n + 1 ) b 1 + . . . + ( k + n + m ) b m (k+1)*a1 +...+(k+n)*an+(k+n+1)*b1+...+(k+n+m)*bm

先取链2再取链1的花费是:
( k + 1 ) b 1 + . . . + ( k + m ) b m + ( k + m + 1 ) a 1 + . . . + ( k + m + n ) a n (k+1)*b1+...+(k+m)*bm+(k+m+1)*a1+...+(k+m+n)*an

作差得: a / n b / m ∑a/n-∑b/m
那么链条可以合并成一个数,权值为总和除以数量。

那么我们从最大权值的点开始合并,依靠链表得到取数的顺序(维护一条链的末尾位置)。最后从根节点开始遍历获得花费。

使用了map优化寻找最大节点的过程,复杂度 O(nlogn)。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <map>

using namespace std;

typedef long long ll;
const int maxn = 1005;
int vis[maxn],num[maxn],fa[maxn];
int lst[maxn],nex[maxn];//nex表示这个节点对应的下一个节点,lst表示这个链表最后一个节点
double c[maxn],val[maxn],V[maxn];;

struct Node
{
    int id;
    int v,_num;
    Node(){}
    Node(int id,int v,int _num): id(id),v(v),_num(_num){}
    bool operator < (const Node &rhs)const
    {
        if(v * rhs._num == rhs.v * _num)return id < rhs.id;
        return v * rhs._num < rhs.v * _num;//如果没有重载,map就会忽视掉结构题里的这个数
    }
};
map<Node,int>mp;

int main()
{
    int n,r;scanf("%d%d",&n,&r);
    
    for(int i = 1;i <= n;i++)
    {
        scanf("%lf",&c[i]);
        V[i] = val[i] = c[i];
        if(i != r)
        {
            mp[Node(i,(int)c[i],1)] = 1;
        }
        num[i] = 1;
        lst[i] = i;nex[i] = i;
    }
    for(int i = 1;i < n;i++)
    {
        int x,y;scanf("%d%d",&x,&y);
        fa[y] = x;
    }
    
    map<Node,int>::iterator it;
    for(int i = 1;i < n;i++)
    {
        it = mp.end();
        --it;Node tmpk = it -> first;
        mp.erase(it);
        
        int k = tmpk.id;
        int f = fa[k];
        while(vis[f])
        {
            f = fa[f];fa[k] = f;
        }
        
        nex[lst[f]] = k;
        lst[f] = lst[k];

        if(f != r)
        {
            it = mp.find(Node(f,(int)val[f],num[f]));
            mp.erase(it);
        }
        num[f] += num[k];
        val[f] += val[k];
        c[f] = val[f] / num[f];
        if(f != r)
            mp[Node(f,(int)val[f],num[f])] = 1;
        vis[k] = 1;
    }
    
    ll ans = 0;
    
    for(int i = 1;i <= n;i++)
    {
        ans += (ll)V[r] * i;
        r = nex[r];
    }
    printf("%lld\n",ans);
    return 0;
}


发布了756 篇原创文章 · 获赞 27 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/tomjobs/article/details/104525394
115