Codeforces Round #665 (Div. 2) D. Maximum Distributed Tree 题解(贪心+易错)

题目链接

题目大意

给你一课树,要你给每一条边分权值,每条边的权值大于0,他们的乘积等于k,而且要使得n-1条边1的数量尽可能少,定义
f(u,v)为u到v的边权和求 max i = 1 i = n j = 1 j = n f ( i , j ) \max \sum_{i=1}^{i=n}\sum_{j=1}^{j=n} f(i,j)
k为m个质因子的乘积

题目思路

这显然是一个求贡献的裸题,但是里面有易错点
1:sort前不要先取模
2:还有要区分m可能比n-1大(太坑了

代码

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<vector>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<unordered_map>
#define fi first
#define se second
#define debug printf(" I am here\n");
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const ll INF=0x3f3f3f3f3f3f3f3f;
const int maxn=1e5+5,inf=0x3f3f3f3f,mod=1e9+7;
const double eps=1e-10;
int n, m, a[maxn];
ll sz[maxn], p[maxn];
int head[maxn],cnt;
struct node{
    int to, next;
}e[maxn<<1];
void add(int u,int v){
    e[++cnt] = {v, head[u]};
    head[u] = cnt;
}
void dfs(int son,int fa){
    sz[son] = 1;
    for (int i = head[son]; i;i=e[i].next){
        if(e[i].to==fa) continue;
        dfs(e[i].to,son);
        sz[son] += sz[e[i].to];
    }
}
void init(){
    cnt = 0;
    for (int i = 1; i <= n;i++){
        head[i] = sz[i] = 0;
    }
}
signed main(){
    int _;scanf("%d", &_);
    while(_--){
        scanf("%d",&n);
        init();
        for (int i = 1,u,v; i <= n - 1;i++){
            scanf("%d%d", &u, &v);
            add(u, v), add(v, u);
        }
        scanf("%d", &m);
        for (int i = 1; i <= m;i++){
            scanf("%lld", &p[i]);
        }
        sort(p + 1, p + 1 + m);//从大到小
        reverse(p + 1, p + 1 + m);
        dfs(1,1);
        for (int i = 1; i <= n;i++){//先不要取模
            sz[i] = (sz[i]) * (n - sz[i]);
        }
        sort(sz + 1, sz + 1 + n);//从大到小
        reverse(sz + 1, sz + 1 + n);
        ll ans = 0;
        if(n-1>=m){
            for (int i = 1; i <= n-1;i++){
                if(i<=m){
                    ans =(ans+ sz[i]%mod * p[i])%mod;
                }else{
                    ans =(ans+ sz[i])%mod;
                }
            }   
        }else{
            for (int i = 2; i <= m;i++){
                if(i<=m-n+2){
                     p[1] = p[1] * p[i]%mod;
                }else{
                    p[i-(m-n+2)+1] = p[i];
                }
            }
            for (int i = 1; i <= n-1; i++){
                ans = (ans + sz[i] * p[i]) % mod;
            }
        }
        printf("%lld\n", ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/m0_46209312/article/details/108442261