[BJOI 2018] Summation

Description

Problem link

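Given a tree with \(n\) nodes rooted at node \(1\) (the root has depth \(0\)), answer \(m\) queries: for each query \(u, v, k\), output the sum of \(\mathrm{depth}(w)^k\) over all nodes \(w\) on the path between \(u\) and \(v\), modulo \(998244353\).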

\(1 \leq n \leq 300000,\ 1 \leq k \leq 50\)

Solution

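After a heavy-light decomposition of the tree, maintain \(50\) prefix-sum arrays — one per exponent \(k\), indexed by position in the decomposition order. A path query then splits into \(O(\log n)\) contiguous chain segments, each of which is summed in \(O(1)\) from the prefix sums.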

These days I can only keep myself going by grinding through easy problems like this one...

Code

#include <bits/stdc++.h>
using namespace std;
// Array capacity: the largest n (300000) plus a small safety margin.
constexpr int N = 300005;
// Modulus applied to every reported answer.
constexpr int yzh = 998244353;
int gi() {
    int x = 0; char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x<<1)+(x<<3)+ch-48, ch = getchar();
    return x;
}

int n, m, u, v, k, sum[51][N];
struct tt {int to, next; }edge[N<<1];
int path[N], TP;
int size[N], top[N], fa[N], dep[N], id[N], son[N], idx;

void dfs1(int u, int depth, int father) {
    dep[u] = depth, size[u] = 1, fa[u] = father;
    for (int i = path[u]; i; i = edge[i].next)
        if (edge[i].to != father) {
            dfs1(edge[i].to, depth+1, u);
            size[u] += size[edge[i].to];
            if (size[edge[i].to] > size[son[u]]) son[u] = edge[i].to;
        }
}
// Second heavy-light pass, run after dfs1: gives every node in u's subtree a
// position id[] (continuing from the global counter idx) and its chain head
// top[] (u's chain head is tp). A heavy child always receives the id right
// after its parent, so every heavy chain is one contiguous range of ids.
// Uses an explicit stack instead of recursion so a long heavy chain (e.g. a
// path of ~3e5 nodes) cannot overflow the call stack.
void dfs2(int u, int tp) {
    std::vector<std::pair<int, int> > stk;  // pending (node, head of its chain)
    stk.push_back(std::make_pair(u, tp));
    while (!stk.empty()) {
        const int x = stk.back().first, head = stk.back().second;
        stk.pop_back();
        top[x] = head, id[x] = ++idx;
        // Light children each start a new chain. They are pushed first so they
        // are numbered only after x's heavy subtree (in reverse adjacency order,
        // which keeps chains contiguous just the same).
        for (int i = path[x]; i; i = edge[i].next) {
            const int y = edge[i].to;
            if (y != fa[x] && y != son[x]) stk.push_back(std::make_pair(y, y));
        }
        // Pushed last, hence popped next: the heavy child extends x's chain.
        if (son[x]) stk.push_back(std::make_pair(son[x], head));
    }
}
// Prepends the directed edge u -> v to u's adjacency list. Edge ids start at
// 1, so an id of 0 marks the end of a list.
void add(int u, int v) {
    ++TP;
    edge[TP].to = v;
    edge[TP].next = path[u];
    path[u] = TP;
}
// Answers one query: the sum over all nodes w on the tree path u..v of
// dep[w]^k, modulo yzh. Each round lifts the endpoint whose chain head is
// deeper to just above that head, adding the head..endpoint id range; once
// both endpoints share a chain, the range between them completes the sum.
int cal(int u, int v, int k) {
    const int *pre = sum[k];  // prefix sums of dep^k in HLD order
    // dep^k summed over HLD positions lo..hi, normalised into [0, yzh).
    auto rangeSum = [pre](int lo, int hi) {
        return (pre[hi] - pre[lo - 1] + yzh) % yzh;
    };
    int total = 0;
    while (top[u] != top[v]) {
        if (dep[top[v]] > dep[top[u]]) std::swap(u, v);
        total = (total + rangeSum(id[top[u]], id[u])) % yzh;
        u = fa[top[u]];
    }
    if (dep[v] > dep[u]) std::swap(u, v);
    return (total + rangeSum(id[v], id[u])) % yzh;
}
void work() {
    n = gi();
    for (int i = 1; i < n; i++) {
        u = gi(), v = gi(); add(u, v), add(v, u);
    }
    dfs1(1, 0, 0); dfs2(1, 1);
    for (int i = 1; i <= n; i++)
        for (int j = 1, val = dep[i]; j <= 50; j++, val = 1ll*val*dep[i]%yzh)
            sum[j][id[i]] = val;
    for (int i = 1; i <= 50; i++)
        for (int j = 1; j <= n; j++)
            (sum[i][j] += sum[i][j-1]) %= yzh;
    m = gi();
    while (m--) {
        u = gi(), v = gi(), k = gi();
        printf("%d\n", cal(u, v, k));
    }
}
// Entry point: all reading, solving and printing happens in work().
int main() {
    work();
    return 0;
}

Guess you like

Origin http://43.154.161.224:23101/article/api/json?id=324980614&siteId=291194637