bzoj1812 [IOI2005]riv河流

题目链接

problem

给出一棵树,每个点有点权,每条边有边权。0号点为根,每个点的代价是这个点的点权\(\times\)该点到根路径上的边权和。
现在可以选择最多K个点。使得每个点的代价变为:这个点的点权\(\times\)改点到最近的被选中的一个祖先的边权和。
问所有点的代价和最小为多少。

solution

\(g[i][j]\)表示以i为根的子树,强制选i,最大的贡献(这里的贡献是指比什么也不选所减少的代价。)

最终答案肯定就是初始代价-g[0][k]

考虑怎么维护出\(g\)。用\(f[i][j]\)表示以\(i\)为根的子树,\(i\)可选可不选。然后树形背包一下就可以求出g。

考虑怎么维护f。每当枚举到一个根的时候,就重新dfs一遍这棵子树,初始f[x][0]=w[x]*dep[u]。dep[u]表示从枚举的根到0号点的距离。然后同样方法背包一遍,就可以维护处\(f\)

把j写成k调了一上午。。。。/自闭

code

/*
* @Author: wxyww
* @Date:   2019-12-21 10:08:12
* @Last Modified time: 2019-12-21 11:04:12
*/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#include<ctime>
using namespace std;
typedef long long ll;
const int N = 110;
ll read() {
    ll x = 0,f = 1;char c = getchar();
    while(c < '0' || c > '9') {
        if(c == '-') f = -1; c = getchar();
    }
    while(c >= '0' && c <= '9') {
        x = x * 10 + c - '0'; c = getchar();
    }
    return x * f;
}
int siz[N],f[N][55],g[N][N],dep[N],w[N],n,K;
struct node {
    int v,nxt;
}e[N];
int head[N],ejs;
void add(int u,int v) {
    e[++ejs].v = v;e[ejs].nxt = head[u];head[u] = ejs;
}
void dp(int u,int W) {
    f[u][0] = W * w[u];
    for(int i = head[u];i;i = e[i].nxt) {
        int v = e[i].v;
        dp(v,W);
        for(int j = min(K,siz[u]);j >= 0;--j) {
            for(int k = 0;k <= min(j,siz[v]);++k) {
                f[u][j] = max(f[u][j],f[v][k] + f[u][j - k]);
            }
        }
    }
    for(int i = 1;i <= K;++i) f[u][i] = max(f[u][i],g[u][i]);//在算上强制选的答案
}
void dfs(int u) {
    siz[u] = 1;
    for(int i = head[u];i;i = e[i].nxt) {
        dep[e[i].v] += dep[u];
        dfs(e[i].v);
        siz[u] += siz[e[i].v];
    }

    g[u][1] = dep[u] * w[u];
    
    memset(f,0,sizeof(f));
    
    for(int i = head[u];i;i = e[i].nxt) {
        int v = e[i].v;
        dp(v,dep[u]);
        for(int j = min(K,siz[u]);j >= 1;--j) {
            for(int k = 0;k < j;++k) {
                g[u][j] = max(g[u][j],g[u][j - k] + f[v][k]);
            }
        }
    }
    // if(u == 1) cout<<g[1][1]<<endl;
}
int main() {
    // freopen("1.in","r",stdin);
    n = read(),K = read();
    ++K;
    for(int i = 1;i <= n;++i) {
        w[i] = read();int u = read();add(u,i);
        dep[i] = read();
    }
    dfs(0);
    int ans = 0;
    for(int i = 1;i <= K;++i) ans = max(ans,g[0][i]);
    // cout<<g[2][1];
    for(int i = 1;i <= n;++i) ans -= dep[i] * w[i];
    cout<<-ans<<endl;
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/wxyww/p/bzoj1812.html