51Nod 1322 - 关于树的函数(树DP)

【题目描述】
在这里插入图片描述
【思路】
看了大佬的题解才想明白的,f_zyj大佬的题解
两棵树,对第一棵树暴力枚举所有边,拆掉这条边后的两个子树对应两个集合 A 1 , B 1 A1,B1 ,用 d f s dfs 枚举,然后在枚举出某一个 A 1 , A 2 A1,A2 时,所有在 A 1 A1 中的节点 u u u s e d [ u ] = t r u e used[u]=true ,现在对第二棵树枚举, d f s 2 dfs2 枚举的时候和刚才 d f s 1 dfs1 不同,这回是把节点 u u u u 的所有子孙看成集合 B 1 B1 ,树上的其它节点看成是集合 B 2 B2 ,这样一来,可以递推的计算集合中元素的个数已经 A 1 , B 1 A1,B1 交集的大小,设第二棵树上节点 u u 对应的集合大小为 n u m [ u ] num[u] ,和 A 1 A1 的交集大小为 d p [ u ] dp[u] ,如果 u u 的所有儿子节点所在集合为 S S ,那么就有 n u m [ u ] = 1 + v S n u m [ v ] num[u]=1+\sum_{v \in S}num[v] d p [ u ] = { 1 + v S d p [ v ]     ( u s e d [ u ] = t r u e )         v S d p [ v ]     ( u s e d [ u ] = f a l s e ) dp[u]=\begin{cases} 1+\sum_{v \in S}dp[v] \ \ \ (used[u]=true) \\ \ \ \ \ \ \ \ \sum_{v \in S}dp[v] \ \ \ (used[u]=false) \end{cases}
而且只要知道 A 1 , B 1 A1,B1 的交集大小,并且 A 1 A 2 A1,A2 交集为空, B 1 , B 2 B1,B2 交集为空,因此其余三对集合的交集大小也能推算出来,取一下最大值,不过题解里的“树归”是个啥?树上递归吗?不是很懂…

#include<bits/stdc++.h>
#define max(a,b)(a>b?a:b)
using namespace std;
const int maxn=4005;

struct Edge{
	int from,to;
	Edge(int f=0,int t=0):from(f),to(t){}
};

int n,a1;
long long ans;
bool used[maxn];
int num[maxn],dp[maxn];
vector<int> g1[maxn],g2[maxn];
Edge edges[maxn];

void dfs1(int u,int fa){
	used[u]=true;
	++a1;
	for(int i=0;i<g1[u].size();++i){
		int v=g1[u][i];
		if(v!=fa && !used[v]) dfs1(v,u);
	}
}

void dfs2(int u,int fa){
	num[u]=1;
	dp[u]=used[u]?1:0;
	for(int i=0;i<g2[u].size();++i){
		int v=g2[u][i];
		if(v!=fa){
			dfs2(v,u);
			num[u]+=num[v];
			dp[u]+=dp[v];
		}
	}
	if(u!=0){//u=0时树没有被分成两部分所以不算 
		int b1=num[u];
		//集合a1和b1的交集大小为dp[u]
		int maxv=0;
		maxv=max(maxv,dp[u]);
		maxv=max(maxv,a1-dp[u]);
		maxv=max(maxv,b1-dp[u]);
		maxv=max(maxv,n-a1-b1+dp[u]);
		ans+=(long long)maxv*maxv;
	}
}

int main(){
	scanf("%d",&n);
	for(int i=0;i<n-1;++i){
		int u,v;
		scanf("%d%d",&u,&v);
		g1[u].push_back(v);
		g1[v].push_back(u);
		edges[i]=Edge(u,v);
	}
	for(int i=0;i<n-1;++i){
		int u,v;
		scanf("%d%d",&u,&v);
		g2[u].push_back(v);
		g2[v].push_back(u);
	}
	for(int i=0;i<n-1;++i){
		a1=0;
		memset(used,0,sizeof(used));
		dfs1(edges[i].from,edges[i].to);
		dfs2(0,-1);
	}
	printf("%lld\n",ans);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/xiao_k666/article/details/83867051