Description
给出一个N个节点的无根树,每条边有非负边权,每个节点有三种颜色:黑,白,灰。
一个合法的无根树满足:树中不含有黑色结点或者含有至多一个白色节点。
现在希望你通过割掉几条树边,使得形成的若干树合法,并最小化割去树边权值的和。
Input
第一行一个正整数N,表示树的节点个数。
第二行N个整数Ai,表示i号节点的颜色,0 表示黑色,1表示白色,2表示灰色。
接下来N-1行每行三个整数Xi Yi Zi,表示一条连接Xi和Yi权为Zi的边。
Output
输出一个整数表示其最小代价。
Sample Input
5
0 1 1 1 0
1 2 5
1 3 3
5 2 5
2 4 16
Sample Output
10
样例解释:
花费10的代价删去边(1, 2)和边(2, 5)。
Data Constraint
20%的数据满足N≤10。
另外30%的数据满足N≤100,000,且保证树是一条链。
100%的数据满足N≤300,000,0≤Zi≤1,000,000,000,Ai∈{0,1,2}。
赛时
蒟蒟表示连20分暴力都没打出
正解
十分显然的树型dp,设f[i][0]表示以i为根的子树,除去已经割去的部分,满足没有黑色节点的最小代价; f[i][1]则表示没有白色
节点的最小代价;f[i][2]则表示有1或 0个白色节点的最小代价。
方程如上,自行理解,不多阐述。
之后有个恶心的东东是这道题用深搜会boom,boom,boom,系统栈会爆掉,所以要人工栈(自食其力,丰衣足食)
代码
#include<cstdio>
#include<cstring>
#define N 300007
using namespace std;
const long long INF=1e17;//ans会很大,所以“无穷大”也要很大
int n,cnt,a[N],head[N],d[N],fa[N];
long long f[N][3];
bool bz[N];
struct tree{
int to,nxt;
long long w;
}e[N<<2];
void add(int u,int v,long long w){//链式前向星
e[++cnt].to=v;
e[cnt].w=w;
e[cnt].nxt=head[u];
head[u]=cnt;
}
long long min(long long x,long long y){return x<y?x:y;}
long long max(long long x,long long y){return x>y?x:y;}
//系统的min,max感觉真的很乐色啊,果然干什么还是要靠自己
void stack(){//人工栈
memset(bz,0,sizeof(bz));
int t=1,h=0;
d[t]=1;
bz[1]=1;
while(h<t){
int x=d[++h];
for(int i=head[x];i;i=e[i].nxt){
int v=e[i].to;
if(bz[v]) continue;
d[++t]=v;
bz[v]=1;
fa[v]=x;
}
}//上面先把节点的顺序确定下来
for(int j=t;j>=1;j--){//儿子知道了才能知道爸爸,所以从后往前统计答案
int x=d[j];
if(a[x]==0) f[x][0]=INF;
if(a[x]==1) f[x][1]=INF;
long long mx=-INF;
for(int i=head[x];i;i=e[i].nxt){
int son=e[i].to;
if(son==fa[x]) continue;
if(f[x][0]!=INF)
f[x][0]+=min(f[son][0],e[i].w+min(f[son][1],f[son][2]));
if(f[x][1]!=INF)
f[x][1]+=min(f[son][1],e[i].w+min(f[son][0],f[son][2]));
if(a[x]==1)
f[x][2]+=min(f[son][1],e[i].w+min(f[son][0],f[son][2]));
else mx=max(mx,min(f[son][1],e[i].w+min(f[son][0],f[son][2]))-f[son][2]);
}
if(a[x]!=1)
f[x][2]=f[x][1]-mx;
//上面是方程部分,不讲
}
}
int main(){
freopen("tree2.in","r",stdin);
freopen("tree2.out","w",stdout);
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<n;i++){
int u,v;long long w;
scanf("%d%d%lld",&u,&v,&w);
add(u,v,w),add(v,u,w);
}
stack();
long long ans=min(f[1][0],min(f[1][1],f[1][2]));
printf("%lld",ans);
}