ZOJ-3949 Edge to the Root(树形dp)

题目链接:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3949

题目大意:有一棵以结点1为根节点且边权值为1的树,现在你可以从结点1向树中的某一个点x连一条边。现在要使得树上除根节点1以外的点到根节点1的距离和最小,问结点1应该和哪个结点连边。

题目思路:通过画图,我们可以知道结点1和结点x连边之后,只会对结点1到结点x的链上的点及其子树产生影响,同时只会影响到深度为dep[x]/2 ~ dep[x]的结点及其子树。

那么连边之后,树上的结点要到达结点1的最短路就会分为两类,一类是直接前往结点1,另一类是先走到结点x再走到结点1。我们可以通过两遍dfs处理出以下的一些信息:

sz[u]:结点u的子树大小;

dep[u]:结点u的深度;

fa[u][i]:结点u的祖先;

d1[u]:结点u的子树中所有结点到达结点u的距离之和;

d2[u]:树上任意一个结点到达结点u的距离之和;

前面4个都是很容易求的,现在讲一下d2[u]该如何求。显然当u=1时,d2[u]=d1[u];当u != 1时,我们就可以借助u的父亲结点来推出d2[u]的值,d2[u] =(d2[fa]-sz[u]*1)+(sz[1]-sz[u])*1,我们已经知道d2[fa]表示树上任意一点到达结点fa的距离之和,由于现在是要求d2[u],那么结点u的子树中的点就不必再走到结点fa了,就减少了sz[u]*1的距离;但是除结点u的子树以外的点需要走到fa之后,再走到结点u,所以就得再增加(sz[1]-sz[u])*1的距离。这样就能求出d2[u]的值了。

现在预处理完这些值,接下来就能对答案进行求解了。上面说了,树上的结点要到达结点1的最短路就会分为两类,一类是直接前往结点1,第二类是先走到结点x再走到结点1。且只有深度在dep[x]/2 ~ dep[x]的结点及其子树会是第二类情况,现在假设向结点x连边之后会影响到的最上面的点为par。

那么第一类的点对答案的贡献就是res1=d2[1]-d1[par]-sz[par]*dep[par],表示除了par及其子树外的结点都直接前往结点1的距离之和。

第二类点对答案的贡献就是res2=d2[x] - (d2[par]-d1[par]+(n-sz[par])*dis)+sz[par],dis表示结点par到结点x的距离。这个式子表示par的子树内所有的点到达结点x的距离之和再加上通过结点1到结点x的边到达结点1的距离之和。

那么结点1向结点x连边之后,所有结点到达结点1的最小距离之和为res1+res2

par我们可以通过类似求lca的方法倍增求出来,剩下的部分在前面就预处理好了,只需要O(1)即可求出,所以我们就可以枚举x求出最终的答案了。

具体实现看代码:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define IOS ios::sync_with_stdio(false)
#define fuck(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>pll;
typedef pair<int, int>pii;
typedef vector<int> VI;
const int inf = 0x3f3f3f3f;
const double eps = 1e-6;
const int MX = 2e5 + 5;

int n, _;
struct edge {
	int v, nxt;
} E[MX << 1];
int head[MX], tot;
int sz[MX], dep[MX], fa[MX][22];
ll d1[MX], d2[MX];
void init() {
	clr(d1); clr(d2);
	memset(head, -1, sizeof(head));
	tot = 0;
}
void add_edge(int u, int v) {
	E[tot].v = v; E[tot].nxt = head[u];
	head[u] = tot++;
}
void dfs1(int u, int pre, int d) {
	dep[u] = d;
	sz[u] = 1; fa[u][0] = pre;
	for (int i = 1; i <= 20; i++)
		fa[u][i] = fa[fa[u][i - 1]][i - 1];
	for (int i = head[u]; ~i; i = E[i].nxt) {
		int v = E[i].v;
		if (v == pre) continue;
		dfs1(v, u, d + 1);
		d1[u] += d1[v] + sz[v];
		sz[u] += sz[v];
	}
}
void dfs2(int u, int pre) {
	if (u == 1) d2[u] = d1[u];
	else d2[u] = d2[pre] - sz[u] + (sz[1] - sz[u]);
	for (int i = head[u]; ~i; i = E[i].nxt) {
		int v = E[i].v;
		if (v == pre) continue;
		dfs2(v, u);
	}
}
int Find(int x, int dis) {
	for (int i = 20; i >= 0; i--) {
		if ((dis >> i) & 1) x = fa[x][i];
	}
	return x;
}

int main() {
	//FIN;
	for (scanf("%d", &_); _; _--) {
		scanf("%d", &n);
		init();
		for (int i = 1; i < n; i++) {
			int u, v;
			scanf("%d%d", &u, &v);
			add_edge(u, v); add_edge(v, u);
		}
		dfs1(1, 0, 0); dfs2(1, 0);
		ll ans = d1[1];
		for (int i = 2; i <= n; i++) {
			int dis = (dep[i] - dep[i] / 2 - 1);
			int par = Find(i, dis);
			ll res = d2[par] - d1[par] + 1ll * (n - sz[par]) * dis;
			ll sub_dis = d2[i] - res + sz[par];
			ll all_dis = d2[1] - d1[par] - 1ll * sz[par] * dep[par];
			ans = min(ans, sub_dis + all_dis);
		}
		cout << ans << endl;
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Lee_w_j__/article/details/82146830