[xsy2724]Tree

题意:给一棵树,找出$k$个点$A_{1\cdots k}$以最小化$\begin{align*}\sum\limits_{i=1}^{k-1}dis_{A_i,A_{i+1}}\end{align*}$

当$k=n$时,除了$A_1\rightarrow A_k$路径上的边,其他边都被经过两次,所以答案就是边权和的两倍减去直径长度

所以我们这样设计状态:$g_{i,j}$表示在$i$的子树内选了$j$个点(包括$i$)构成的树的(边权和的两倍)的最小值,$f0_{i,j}$表示在$i$的子树内选了$j$个点,并且这$j$个点构成的树的直径有一端是$i$,构成的树的(边权和的两倍减去直径长度)的最小值,$f1_{i,j}$表示在$i$的子树内选了$j$个点,并且这$j$个点构成的树的直径两端都不是$i$,构成的树的(边权和的两倍减去直径长度)的最小值,容易得到转移(假设$y$是$x$的儿子)

$g_{x,j+k}\leftarrow g_{x,j}+g_{y,k}+2w_{x,y}$

$f0_{x,j+k}\leftarrow f0_{x,j}+g_{y,k}+2w_{x,y}$(直径不变)

$f0_{x,j+k}\leftarrow g_{x,j}+f0_{y,k}+w_{x,y}$(直径变为从$y$过来)

$f1_{x,j+k}\leftarrow g_{x,j}+f1_{y,k}+2w_{x,y}$(直径变为从$y$来,不经过$x$)

$f1_{x,j+k}\leftarrow f0_{x,j}+f0_{y,k}+w_{x,y}$(直径变为从$y$来,经过$x$)

$f1_{x,j+k}\leftarrow f1_{x,j}+g_{y,k}+2w_{x,y}$(直径不变)

总的计算次数是$\begin{align*}\sum\limits_{fa_i=fa_j}size_isize_j\end{align*}$,注意到$\begin{align*}\sum\limits_{fa_i=fa_j=x}size_isize_j\end{align*}$统计的是$lca_{i,j}=x$的对数,对所有$x$求和就是统计点对,所以总时间复杂度是$O\left(n^2\right)$

#include<stdio.h>
const int inf=1000000000;
int h[3010],nex[6010],to[6010],w[6010],siz[3010],g[3010][3010],f0[3010][3010],f1[3010][3010],M;
void add(int a,int b,int c){
	M++;
	to[M]=b;
	w[M]=c;
	nex[M]=h[a];
	h[a]=M;
}
void min(int&a,int b){
	if(b<a)a=b;
}
void dfs(int fa,int x){
	int i,j,k,s;
	siz[x]=1;
	for(i=h[x];i;i=nex[i]){
		if(to[i]!=fa){
			dfs(x,to[i]);
			siz[x]+=siz[to[i]];
		}
	}
	g[x][1]=f0[x][1]=f1[x][1]=0;
	for(i=2;i<=siz[x];i++)g[x][i]=f0[x][i]=f1[x][i]=inf;
	s=1;
	for(i=h[x];i;i=nex[i]){
		if(to[i]!=fa){
			for(j=s;j>0;j--){
				for(k=1;k<=siz[to[i]];k++){
					min(g[x][j+k],g[x][j]+g[to[i]][k]+2*w[i]);
					min(f0[x][j+k],f0[x][j]+g[to[i]][k]+2*w[i]);
					min(f0[x][j+k],g[x][j]+f0[to[i]][k]+w[i]);
					min(f1[x][j+k],g[x][j]+f1[to[i]][k]+2*w[i]);
					min(f1[x][j+k],f0[x][j]+f0[to[i]][k]+w[i]);
					min(f1[x][j+k],f1[x][j]+g[to[i]][k]+2*w[i]);
				}
			}
			s+=siz[to[i]];
		}
	}
}
int main(){
	int n,k,i,x,y,z,ans;
	scanf("%d%d",&n,&k);
	for(i=1;i<n;i++){
		scanf("%d%d%d",&x,&y,&z);
		add(x,y,z);
		add(y,x,z);
	}
	dfs(0,1);
	ans=inf;
	for(i=1;i<=n;i++){
		if(siz[i]>=k){
			min(ans,f0[i][k]);
			min(ans,f1[i][k]);
		}
	}
	printf("%d",ans);
}

猜你喜欢

转载自www.cnblogs.com/jefflyy/p/8909531.html