树上背包问题

树上背包问题

一些题目给定了树形结构,在这个树形结构中选取一定数量的点或边(也可能是其他属性),使得某种与点权或者边权相关的花费最大或者最小。解决这类问题,一般要考虑使用树上背包。

算法原理

树上背包,顾名思义,就是在树上做背包问题。一个节点的若干子树可以看作是若干组背包,也就是用树形dp的方式做分组背包问题。一般来说, f ( i , j ) f(i,j) f(i,j)表示以 i i i为根的子树中,在 j j j的容量范围内,最大或者最小可以获得多少收益。根据分组背包的思想,第一维枚举物品(在树上指的是子树),第二维枚举容量,第三维枚举决策(这里指的是给子树分配多少容量)。基本的代码框架如下:

void dfs(int u, int fa)
{
    
    
	for(int i = h[u]; ~i; i = ne[i])
	{
    
    
		int son = e[i];
		if(son == fa) continue;
		dfs(son, u);
		for(int j = m; j >= 0; j --)
			for(int k = 0; k <= j; k ++)
				f[u][j] = max(f[u][j], f[u][j-k] + f[son][k] + val);
	}
}

例题一:有依赖的背包问题

题意

n n n个物品和一个容量是 m m m的背包。物品之间具有依赖关系,且依赖关系组成一棵树的形状。如果选择一个物品,则必须选择它的父节点。
求解将哪些物品装入背包,可使物品总体积不超过背包容量,且总价值最大。输出最大价值。
每件物品的编号是 i i i,体积是 v i v_i vi,价值是 w i w_i wi,依赖的父节点编号是 p i p_i pi。物品的下标范围是 1 … N 1 \dots N 1N

数据范围

1 ≤ n , m ≤ 100 1 \leq n,m \leq 100 1n,m100
1 ≤ v i , w i ≤ 100 1 \leq v_i,w_i \leq 100 1vi,wi100

思路

f ( i , j ) f(i,j) f(i,j)表示选择以 i i i为子树的物品,在容量不超过 j j j时所获得的最大价值。
由于只有选择了根节点,才会继续往下遍历,所以在遍历到 i i i节点时,先考虑一定选上它。
在分组背包部分, j j j的范围为 [ m , v [ i ] ] [m,v[i]] [m,v[i]],否则没有意义,因为连根节点也放不下; k k k的范围 [ 0 , j − v [ i ] ] [0,j-v[i]] [0,jv[i]],当大于 j − v [ i ] j-v[i] jv[i]时分给该子树的容量过多,剩余的容量连根节点的物品都放不下了。
递推式为: f ( i , j ) = m a x ( f ( i , j ) , f ( i , j − k ) + f ( s o n , k ) ) f(i,j) = max(f(i,j), f(i,j - k) + f(son,k)) f(i,j)=max(f(i,j),f(i,jk)+f(son,k))

代码

void dfs(int u)
{
    
    
    for(int i = v[u]; i <= m; i ++) f[u][i] = w[u];
    for(int i = h[u]; ~i; i = ne[i])
    {
    
    
        int son = e[i];
        dfs(son);
        for(int j = m; j >= v[u]; j --)
            for(int k = 0; k <= j - v[u]; k ++)
                f[u][j] = max(f[u][j], f[u][j - k] + f[son][k]);
    }
}

例题二:二叉苹果树

题意

给定一棵二叉树,每条边有边权,保留一定数量的边(其他边删除),使得保留下来的边的边权和最大。

数据范围

1 ≤ n < m ≤ 100 1 \leq n < m \leq 100 1n<m100
w i ≤ 30000 w_i \leq 30000 wi30000

思路

f ( i , j ) f(i,j) f(i,j)表示以 i i i为根的子树中,恰好保留 j j j条边的最大边权和。
若需要选择该子树中的边,则根结点到子树的边一定要选,因此能用上的总边数一定减 1 1 1,总共可以选择 j j j条边时,当前子树son分配的最大边数是 j − 1 j - 1 j1
递推式为, f ( i , j ) = m a x ( f ( i , j ) , f ( i , j − k − 1 ) + f ( s o n , k ) + w [ i ] ) f(i,j) = max(f(i,j), f(i,j-k-1) + f(son, k) + w[i]) f(i,j)=max(f(i,j),f(i,jk1)+f(son,k)+w[i])

代码

void dfs(int u, int fa)
{
    
    
    for(int i = h[u]; ~i; i = ne[i])
    {
    
    
    	int son = e[i];
        if(son == fa) continue;
        dfs(son, u);
        for(int j = m; j >= 1; j -- )
            for(int k = 0; k <= j - 1; k ++ )
                f[u][j] = max(f[u][j], f[u][j - k - 1] + f[son][k] + w[i]);
    }
}

例题三:Factories(2018icpc银川网络赛)

题意

给定一棵树,边有边权。每个叶子节点上最多可以布置一个工厂,总共要布置 k k k个工厂。问怎样布置工厂,使得工厂之间的距离和最小。

数据范围

10 s 10s 10s
2 ≤ n ≤ 1 0 5 2 \leq n \leq 10^5 2n105, 1 ≤ m ≤ 100 1 \leq m \leq 100 1m100
1 ≤ w i ≤ 1 0 5 1 \leq w_i \leq 10^5 1wi105
多组测试数据, n n n总数不超过 1 0 6 10^6 106

思路

直接考虑距离之和非常困难,所以可以考虑每条边被计算了几次(距离和等类似问题很多都是这么考虑的)。不妨设一条边为 i i i,与 i i i相连的子树中有 j j j个工厂,则这条边被计算的次数为 j ∗ ( m − j ) j*(m - j) j(mj)
f ( i , j ) f(i,j) f(i,j)表示以 i i i为根节点的子树中,选择恰好 j j j个叶子节点的距离总和。
递推式为, f ( i , j ) = m i n ( f ( i , j ) , f ( i , j − k ) + f ( s o n , k ) + w [ i ] ∗ j ∗ ( m − j ) ) f(i,j) = min(f(i,j), f(i,j - k) + f(son, k) + w[i] * j * (m - j)) f(i,j)=min(f(i,j),f(i,jk)+f(son,k)+w[i]j(mj))
因为只能分布在叶子节点,因此初始化的时候要注意,如果点 i i i为叶子节点,那么 f ( i , 1 ) = 0 f(i,1) = 0 f(i,1)=0
同时这道题要卡常数,所以要对状态做一个优化,即把无效状态去掉。

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long ll;

const int N = 100003, M = 103;
const ll inf = 1e18;

int n, m;
int h[N], e[2*N], ne[2*N], w[2*N], idx;
int s[N], deg[N];
ll f[N][M];

void add(int a,int b,int c)
{
    
    
    e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx ++;
}

void dfs(int u,int fa)
{
    
    
    for(int i = h[u]; ~i; i = ne[i])
    {
    
    
        int son = e[i];
        if(son == fa) continue;
        dfs(son, u);
        s[u] += s[son];
        for(int j = min(m, s[u]); j >= 1; j --)
            for(int k = 1; k <= min(j, s[son]); k ++)
                f[u][j] = min(f[u][j], f[u][j-k] + f[son][k] + (ll)w[i] * k * (m - k));
    }
}

int main()
{
    
    
    int T;
    scanf("%d", &T);
    int cas = 0;
    while(T --)
    {
    
    
        scanf("%d%d", &n,&m);
        for(int i = 1; i <= n; i ++) h[i] = -1, deg[i] = 0;
        idx = 0;
        for(int i = 0; i < n - 1; i ++)
        {
    
    
            int a,b,c;
            scanf("%d%d%d", &a,&b,&c);
            add(a,b,c), add(b,a,c);
            deg[a] ++, deg[b] ++;
        }
        for(int i = 1; i <= n; i ++)
        {
    
    
            s[i] = 0;
            for(int j = 1; j <= m; j ++) f[i][j] = inf;
            if(deg[i]==1) f[i][1] = 0, s[i] = 1;
        }
        dfs(1, -1);
        printf("Case #%d: %lld\n",++cas,f[1][m]);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_43634220/article/details/108404464
今日推荐