「BJOI2018」求和 - 树上前缀和

题意

给定一棵以 \(1\) 为根的树,每个节点的权值为其深度(根的深度为 \(0\))。每次询问给出 \(x, y, k\)(\(k \le 50\)),求 \(x\) 到 \(y\) 路径上所有节点权值的 \(k\) 次方之和,对 \(998244353\) 取模。

思路

考虑到 \(k\) 最多只有 \(50\) 种取值,可以将询问离线并按 \(k\) 排序,对每一种 \(k\) 统一回答其所有询问。

对每个 \(k\) 先进行一次 dfs,求出根到每个节点的链上前缀和 \(sum_x = \sum_{v \in \mathrm{path}(root, x)} dep_v^{\,k}\),则答案为 \(ans = sum_x + sum_y - sum_{\mathrm{lca}(x, y)} - sum_{fa(\mathrm{lca}(x, y))}\)。

lca 的求法利用倍增。复杂度为 \(O(Kn\log k + m\log m + m\log n)\),其中 \(K \le 50\) 为不同 \(k\) 值的个数,\(\log k\) 来自快速幂,\(m\log m\) 来自对询问排序。

#include <cstdio>
#include <algorithm>

// Array capacity for nodes / queries (3e5 plus slack) and the answer modulus.
const int maxp = 3e5 + 10, mod = 998244353;
// Each undirected tree edge is stored twice (x->y and y->x), hence double capacity.
const int maxe = maxp << 1;

#define ll long long

// dep[v]: 1-based depth of v (dfs(1, 0) gives the root dep 1, node 0 is a virtual parent with dep 0).
// fa[v][i]: 2^i-th ancestor of v; jumps past the root land on 0.
int dep[maxp], fa[maxp][31];
// Forward-star adjacency list: head[x] is x's first edge id, Next[e] the following edge,
// ver[e] the edge's endpoint; cnt is the number of edges added (ids start at 1, 0 = "none").
int ver[maxe], Next[maxe], head[maxp], cnt;
// n nodes, m queries; r[] holds query indices, later sorted by exponent.
int n, m, r[maxp];

// val[v]: sum of (dep - 1)^k over the root->v path for the current k, modulo mod.
// res[i]: answer for the i-th query in input order.
ll val[maxp], res[maxp];

// Reads the next non-negative decimal integer from stdin into x.
// Any non-digit characters before the number are skipped. If EOF is hit before a digit
// appears, x is left as 0. Without the EOF check the skip loop would spin forever,
// because getchar() keeps returning EOF (-1), which compares below '0'.
// Note: a leading '-' is skipped like any other non-digit, so negatives are not supported.
inline void read(int &x) {
  int ch = getchar();
  x = 0;
  while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
  for (; ch >= '0' && ch <= '9'; ch = getchar())
    x = x * 10 + (ch - '0');
}

// Computes x^y modulo `mod` by binary exponentiation.
// The base is first normalised into [0, mod). Afterwards every product is below mod^2 < 2^63,
// so `x * x` cannot overflow long long, and negative bases yield a non-negative residue.
// y is expected to be non-negative; y <= 0 returns 1 (the original looped forever on y < 0,
// since an arithmetic right shift of a negative value never reaches 0).
ll qpow(ll x, int y) {
  x %= mod;
  if (x < 0) x += mod;
  ll ans = 1;
  for (; y > 0; y >>= 1, x = x * x % mod)
    if (y & 1) ans = ans * x % mod;
  return ans;
}

// One offline query: the path endpoints and the exponent applied to every depth on it.
// The global array `a` stores queries by their 1-based input index.
struct Query {
  int x;  // first path endpoint
  int y;  // second path endpoint
  int k;  // exponent
} a[maxp];

// Strict-weak ordering on query indices: ascending exponent, so queries that share
// a k value end up adjacent after sorting.
bool cmp(int x, int y) {
  const int lhs_k = a[x].k;
  const int rhs_k = a[y].k;
  return lhs_k < rhs_k;
}

// Visits node o, whose parent is fno. It records o's parent and depth (dep[fno] + 1), fills
// the binary-lifting row fa[o][1..30] from the already-filled ancestor rows, then recurses
// into every neighbour except the parent.
inline void dfs(int o, int fno) {
  fa[o][0] = fno;
  dep[o] = dep[fno] + 1;
  for (int lvl = 1; lvl < 31; ++lvl) {
    const int mid = fa[o][lvl - 1];  // 2^(lvl-1)-th ancestor
    fa[o][lvl] = fa[mid][lvl - 1];
  }
  for (int e = head[o]; e != 0; e = Next[e]) {
    const int child = ver[e];
    if (child != fno) dfs(child, o);
  }
}

int lca(int x, int y) {
  if (dep[x] > dep[y]) std::swap(x, y);
  int tmp = dep[y] - dep[x];
  for (int j = 0; tmp; ++ j, tmp >>= 1)
    if (tmp & 1) y = fa[y][j];
  if (x == y) return x;
  for (int j = 30; ~j && x != y; -- j) {
    if (fa[x][j] != fa[y][j]) {
      x = fa[x][j];
      y = fa[y][j];
    }
  }
  return fa[x][0];
}

inline void dfs_base(int o, int fno, int k) {
  val[o] = (val[fno] + qpow(dep[o] - 1, k)) % mod;
  for (int i = head[o], v; i; i = Next[i]) {
    if ((v = ver[i]) == fno) continue;
    dfs_base(v, o, k);
  }
}

// Prepends a directed edge x -> y to x's forward-star adjacency list.
// Edge ids start at 1 so that 0 can terminate the list.
inline void addline(const int &x, const int &y) {
  const int id = ++cnt;
  ver[id] = y;
  Next[id] = head[x];
  head[x] = id;
}

int main() {
  read(n);
  for (int i = 1, x, y; i < n; ++ i) {
    read(x), read(y); addline(x, y), addline(y, x);
  }
  dfs(1, 0); 
  read(m);
  for (int i = 1; i <= m; ++ i)
    read(a[i].x), read(a[i].y), read(a[i].k), r[i] = i;
  std::sort(r + 1, r + m + 1, cmp);
  for (int i = 1; i <= m; ++ i) {
    if (a[r[i]].k != a[r[i - 1]].k) {
      dfs_base(1, 0, a[r[i]].k);
    }
    res[r[i]] = (val[a[r[i]].x] + val[a[r[i]].y]) % mod;
    int LCA = lca(a[r[i]].x, a[r[i]].y);
    res[r[i]] = (mod + res[r[i]] - val[LCA]) % mod;
    res[r[i]] = (mod + res[r[i]] - val[fa[LCA][0]]) % mod;
  }
  for (int i = 1; i <= m; ++ i) printf("%lld\n", res[i]);
}

猜你喜欢

转载自www.cnblogs.com/alessandrochen/p/11482793.html