题意
给定一颗点带权无根树,请你选定一个根并对这棵树进行深度优先遍历,得到一个点的经过顺序(即\(dfs\)序):\(v_1,v_2...v_n\),记点\(u\)的点权为\(A_u\)
请最小化下面式子的值
\[ \sum_{i=1}^n i\times A_{v_i} \]
解法
大佬们都说这题是煞笔题。。我还是太菜了
设当前点为\(u\),我们已经求得了\(u\)所有儿子的最优答案
现在我们要对\(u\)的所有子树排序(即确定遍历顺序),使得\(u\)的答案最优
对于这类安排顺序求最优解的问题,我们可以想到邻项交换排序
考虑\(u\)的两个儿子\(v,w\),它们是两个相邻的子树,我们现在考虑交换它们的遍历顺序能否使答案更优
首先我们把当前问题看做是一个重叠子问题,即\(u\)即是当前的根(如果当前根非\(u\),我们可以用\(u\)的答案加上\(sum_u \times (size_{pre_u} +1)\)得到新答案,这就成为一个重复问题了)
设之前子树的答案为\(ans_{pre}\),这显然是一个定值(在当前情况下)
若\(u\)在\(v\)之前
\[ ans=A[u]+ans_{pre}+(ans_u+sum_u\times (size_{pre}+1))+(ans_v + sum_v\times (size_{pre}+size_u+1)) \]
虽然这个式子看上去很复杂,但实际上有很多定值
我们把它进行化简
\[ ans=A[u]+ans_{pre}+ans_u+ans_v+sum_u+sum_v+(sum_u+sum_v)\times size_{pre}+sum_v\times size_u \]
这样如果我们把\(u\)和\(v\)进行交换,影响答案的改变只有\(sum_v\times size_u\)
那么我们排序的依据就成了最小化\(sum_v\times size_u\)
依照这个排序方法,我们有了一个\(O(N^2\log N)\)的枚举根的算法
但是我们发现如果枚举根重新\(DP\)出所有信息会有许多重复信息没有得到利用
此时我们可以采用换根\(DP\)的方式:每次换根后将其原来父亲当做其儿子进行处理
具体的排序可以用\(multiset\)来实现
代码
#include <set>
#include <set>
#include <cstdio>
#include <climits>
using namespace std;
#define int long long
int read();
const int N = 2e5 + 10;
struct node {
int id, sz, su;
bool operator < (const node& rhs) const { return sz * rhs.su < su * rhs.sz; }
};
int n;
int val[N], sum[N], siz[N], fa[N];
int cap;
int head[N], to[N << 1], nxt[N << 1];
long long ans = LLONG_MAX;
long long f[N];
multiset<node> st[N];
typedef multiset<node>::iterator iter;
inline void add(int x, int y) {
to[++cap] = y, nxt[cap] = head[x], head[x] = cap;
to[++cap] = x, nxt[cap] = head[y], head[y] = cap;
}
void dfs_1(int x) {
siz[x] = 1, sum[x] = val[x];
for (int i = head[x]; i; i = nxt[i])
if (to[i] != fa[x]) fa[to[i]] = x, dfs_1(to[i]), sum[x] += sum[to[i]], siz[x] += siz[to[i]];
}
void dfs_2(int x) {
f[x] = val[x];
for (int i = head[x]; i; i = nxt[i])
if (to[i] != fa[x]) st[x].insert((node){to[i], siz[to[i]], sum[to[i]]});
int presz = 1;
for (iter it = st[x].begin(); it != st[x].end(); ++it) {
dfs_2(it -> id);
f[x] += f[it -> id] + 1LL * presz * sum[it -> id];
presz += siz[it -> id];
}
}
void dfs_3(int x) {
int presz = 1, presum = val[x];
if (x != 1) {
f[x] = val[x];
st[x].insert((node){fa[x], n - siz[x], sum[1] - sum[x]});
for (iter it = st[x].begin(); it != st[x].end(); ++it) {
f[x] += f[it -> id] + 1LL * presz * (it -> su);
presz += (it -> sz);
}
}
ans = min(ans, f[x]);
presz = 1;
long long tmp = f[x];
for (iter it = st[x].begin(); it != st[x].end(); ++it) {
presum += it -> su;
if ((it -> id) != fa[x]) {
f[x] = tmp - f[it -> id] - sum[it -> id] * presz - 1LL * siz[it -> id] * (sum[1] - presum);
dfs_3(it -> id);
}
presz += it -> sz;
}
}
signed main() {
n = read();
for (int i = 1; i < n; ++i) add(read(), read());
for (int i = 1; i <= n; ++i) val[i] = read();
dfs_1(1);
dfs_2(1);
dfs_3(1);
printf("%lld\n", ans);
return 0;
}
int read() {
int x = 0, c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
return x;
}