牛客练习赛55 E 树

题目链接:

题意:给出n个点,n-1条边求任意两个点的距离平方的和

解法:

f[i]表示这个点的高度

sz[i]表示这个子树的大小

szz[i]表示这个这个子树大小的平方

sum[i]表示这个子树所有点高度的和

两个点i, j的距离dis = f[i] + f[j] - 2 * f[lca(i, j)]

dis的平方 =  f[i] * f[i] + f[j] * f[j] + 2 * f[i] * f[j] * 4 * f[lca(i, j)] * f[lca(i, j)]  - 4 *  (f[i] + f[j]) * f[lca(i, j)]

前面三项直接计算即可

计算第四项就要计算点u为lca的(i, j)对数

以点u为lca的(i, j)对数为sz[u] * sz[u] - sz[v] * sz[v](v为u的所有儿子)

计算第五项就要计算以点u为lca的f[i]和f[j]的和

以点u为lca的f[i]和f[j]的和为2 * sum[v] * (sz[u] - sz[v]) (v为u的所有儿子)(乘2是因为(i, j),(j, i)都要计算)

最后再加上点u自身的贡献:2 * f[u] * sz[u]

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 998244353;
const int M = 1e6 + 10;
ll ans;
int cnt;
int head[M];
ll sz[M], szz[M], f[M], sum[M];
struct Edge{
    int next, to;
}edge[M * 2];
void add_edge(int u, int v) {
    edge[++cnt].next = head[u];
    edge[cnt].to = v;
    head[u] = cnt;
}
void dfs(int u, int fa) {
    sz[u] = 1;
    f[u] = f[fa] + 1;
    sum[u] = f[u];
    for(int i = head[u]; i; i = edge[i].next) {
        int v = edge[i].to;
        if(v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
        sum[u] += sum[v];
    }
    szz[u] = (sz[u] % mod) * (sz[u] % mod) % mod;
}
void dfs1(int u, int fa) {
    ll ans1 = szz[u], ans2 = (2 * f[u] % mod) * (sz[u] % mod) % mod;
    for(int i = head[u]; i; i = edge[i].next) {
        int v = edge[i].to;
        if(v == fa) continue;
        dfs1(v, u);
        ans1 = (ans1 - szz[v] + mod) % mod;
        ans2 = (ans2 % mod + (2 * sum[v] % mod) * ((sz[u] - sz[v]) % mod) % mod) % mod;
    }
    ans = (ans % mod + ((4 * f[u] % mod) * (f[u] % mod) % mod) * (ans1 % mod) % mod) % mod;
    ans = (ans % mod - (4 * ans2) % mod * (f[u] % mod) % mod + mod ) % mod;
}
int main(){
    int n;
    while(~scanf("%d", &n)) {
        cnt = 0;
        ans = 0;
        memset(head, 0, sizeof(head));
        for(int i = 1; i <= n - 1; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            add_edge(u, v);
            add_edge(v, u);
        }
        dfs(1, 0);
        ll summ = 0;
        for(int i = 1; i <= n; i++) {
            summ = (summ % mod + f[i] % mod) % mod;
            ans = (ans % mod + ((f[i] % mod) * (f[i] % mod) % mod ) * ((2 * n) % mod )% mod) % mod;
        }
        ans = (ans % mod + (2 * summ % mod) * (summ % mod) % mod) % mod;
        dfs1(1, 0);
        printf("%lld\n", ans);
    }
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/linglinga/p/12077045.html