HNOI2015 开店

给定一棵带点权和边权的树(原题保证每个点的度数不超过 $3$),多次询问:点权在 $[L,R]$ 内的所有点到 $u$ 的距离之和(强制在线,$L,R$ 需要用上一次的答案解码)。

$n,q \leq 100000$

sol:

1.点分治

建出分治树的结构,考虑计算距离的过程

我们知道 $dis(u,v) = dep_u + dep_v - 2 \times dep_{lca(u,v)}$,其中 $dep$ 指带权深度(到根的距离)。

因为点分树的树高是 $O(\log n)$ 的,所以可以从 $u$ 在点分树上暴力往上跳,枚举 $u$ 与其他点在点分树上的 lca(即同时包含它们的最深的分治中心)。

按点权排序后做前缀和,询问 $[L,R]$ 时用两个前缀相减(差分)。

对每个分治中心,把它所管辖连通块内的点按点权排序,维护 $3$ 个前缀和(代码里合并成一个结构体 vector):前 $i$ 个点到它的距离和、前 $i$ 个点到它在点分树上父亲的距离和、前 $i$ 个点的点数。

因为点分树的树高是 $O(\log n)$ 的,每个点只会出现在 $O(\log n)$ 个 vector 里,所以总空间是 $O(n\log n)$ 的。

询问时,对 $u$ 在点分树上的每个祖先分治中心,二分出 $[L,R]$ 在其 vector 中对应的区间,用前缀和算出距离贡献,再减去在父亲那一层会被重复计算的部分即可。

顺便吐槽:为什么我每次写点分治都懒得写 $O(1)$ LCA(这里用的是树链剖分求 LCA)。

#include <bits/stdc++.h>
#define LL long long
using namespace std;
/// Reads one (optionally signed) decimal integer from stdin.
/// Non-digit characters before the number are skipped, and every '-' among them
/// flips the sign. The character is kept in an `int` so that EOF stays distinct from
/// byte 0xFF and `isdigit` never receives a negative value (undefined behaviour).
/// If input ends, the value parsed so far (0 if none) is returned instead of looping forever.
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();
    for (; ch != EOF && !isdigit(ch); ch = getchar())
        if (ch == '-')
            f = -f;
    for (; ch != EOF && isdigit(ch); ch = getchar()) x = 10 * x + ch - '0';
    return x * f;
}
const int maxn = 200010;  // array capacity: vertex ids must stay below this
int n;                    // number of vertices
int q;                    // number of queries
int A;                    // modulus used to decode the online queries
int yr[maxn];             // point weight (age) of every vertex

// Adjacency list kept as linked edge lists; edge ids start at 1, so 0 ends a list.
int first[maxn];     // first[u]: most recently added edge leaving u
int to[maxn << 1];   // head vertex of each edge
int nx[maxn << 1];   // next edge leaving the same vertex
int val[maxn << 1];  // edge weight
int cnt;             // number of edges added so far

/// Prepends the directed edge u -> v with weight w to u's edge list.
inline void add(int u, int v, int w) {
    const int e = ++cnt;
    to[e] = v;
    val[e] = w;
    nx[e] = first[u];
    first[u] = e;
}
LL dis[maxn];  // dis[x]: weighted distance from vertex 1 to x (filled by LCA::dfs1)
int fa[maxn];  // fa[x]: parent of x with the tree rooted at vertex 1 (0 for the root)
/// Heavy-light decomposition of the tree rooted at vertex 1, used for LCA queries.
namespace LCA {
int dep[maxn];   // unweighted depth (main() seeds dep[1] before calling dfs1)
int bl[maxn];    // bl[x]: top vertex of the heavy chain that contains x
int size[maxn];  // subtree size

/// Visits the subtree of x, setting fa/dep/dis for every descendant and size[].
inline void dfs1(int x) {
    size[x] = 1;
    for (int e = first[x]; e != 0; e = nx[e]) {
        const int y = to[e];
        if (y == fa[x])
            continue;
        fa[y] = x;
        dep[y] = dep[x] + 1;
        dis[y] = dis[x] + val[e];
        dfs1(y);
        size[x] += size[y];
    }
}

/// Puts x on the chain headed by col; the largest child continues that chain,
/// and every other child starts a chain of its own.
inline void dfs2(int x, int col) {
    bl[x] = col;
    int heavy = 0;  // size[0] stays 0, so any child beats "no child yet"
    for (int e = first[x]; e != 0; e = nx[e]) {
        const int y = to[e];
        if (dep[y] > dep[x] && size[y] > size[heavy])
            heavy = y;
    }
    if (heavy == 0)
        return;  // x is a leaf
    dfs2(heavy, col);
    for (int e = first[x]; e != 0; e = nx[e]) {
        const int y = to[e];
        if (dep[y] > dep[x] && y != heavy)
            dfs2(y, y);
    }
}

/// Lowest common ancestor of x and y: repeatedly lift whichever vertex sits on
/// the chain with the deeper top until both are on the same chain.
inline int lca(int x, int y) {
    for (; bl[x] != bl[y]; x = fa[bl[x]])
        if (dep[bl[x]] < dep[bl[y]])
            swap(x, y);
    return dep[x] <= dep[y] ? x : y;
}
}  // namespace LCA
/// One entry of a centroid's list, sorted by point weight; build() later turns
/// the sum/sig/cnt fields into prefix sums over that order.
struct Node {
    LL col;  // point weight (age); the sentinel entry uses -1
    LL sum;  // distance to the owning centroid
    LL sig;  // distance to the owning centroid's parent in the centroid tree
    LL cnt;  // vertex count (1 per real vertex)
    inline bool operator<(const Node &b) const { return b.col > col; }
};
vector<Node> G[maxn];  // G[c]: entries for the component whose centroid is c

/// Weighted tree distance between x and y; 0 when either id is 0.
inline LL caldis(int x, int y) {
    if (x == 0 || y == 0)
        return 0;
    const int z = LCA::lca(x, y);
    return dis[x] + dis[y] - 2 * dis[z];
}
int f[maxn], size[maxn], vis[maxn], par[maxn], sig, root;
void findroot(int x, int fa) {
    f[x] = 0, size[x] = 1;
    for (int i = first[x]; i; i = nx[i]) {
        if (to[i] == fa || vis[to[i]])
            continue;
        findroot(to[i], x);
        size[x] += size[to[i]];
        f[x] = max(f[x], size[to[i]]);
    }
    f[x] = max(f[x], sig - size[x]);
    if (f[x] < f[root])
        root = x;
}
/// Appends one Node per vertex of the unvisited component containing x (entered
/// from `fa`) to G[rt]. Each Node holds the vertex's point weight, its distance to
/// centroid rt, its distance to rt's centroid-tree parent (0 if rt has none), and a count of 1.
void add_node(int x, int fa, int rt) {
    // Node{...} rather than the GNU compound literal `(Node){...}`, which is not ISO C++.
    G[rt].push_back(Node{ yr[x], caldis(x, rt), (par[rt] ? caldis(x, par[rt]) : 0), 1 });

    for (int i = first[x]; i; i = nx[i]) {
        if (to[i] == fa || vis[to[i]])
            continue;
        add_node(to[i], x, rt);
    }
}
void build(int x) {
    vis[x] = 1;
    add_node(x, 0, x);
    G[x].push_back((Node){ -1, 0, 0, 0 });
    sort(G[x].begin(), G[x].end());
    for (int i = 1; i < G[x].size(); i++) {
        G[x][i].sum += G[x][i - 1].sum;
        G[x][i].sig += G[x][i - 1].sig;
        G[x][i].cnt += G[x][i - 1].cnt;
    }
    for (int i = first[x]; i; i = nx[i]) {
        if (vis[to[i]])
            continue;
        root = 0;
        sig = size[to[i]];
        findroot(to[i], 0);
        par[root] = x;
        build(root);
    }
}
/// Sum of tree distances from x to every vertex whose point weight lies in [ql, qr].
/// Walks x's ancestors in the centroid tree. For each centroid c it adds the
/// in-range distances through c, then subtracts the part already counted again
/// at c's parent.
LL query(int x, int ql, int qr) {
    // Index of the last entry whose col is <= v (0, the sentinel, if none); entries are sorted by col.
    auto last_le = [](const vector<Node> &entries, LL v) {
        auto it = upper_bound(entries.begin(), entries.end(), v,
                              [](LL key, const Node &e) { return key < e.col; });
        return static_cast<int>(it - entries.begin()) - 1;
    };
    LL ans = 0;
    for (int c = x; c != 0; c = par[c]) {
        const vector<Node> &g = G[c];
        const int hi = last_le(g, qr);
        const int lo = last_le(g, ql - 1);
        const LL inRange = g[hi].cnt - g[lo].cnt;
        ans += g[hi].sum - g[lo].sum;
        if (c != x)
            ans += inRange * caldis(c, x);
        if (par[c])
            ans -= (g[hi].sig - g[lo].sig) + inRange * caldis(x, par[c]);
    }
    return ans;
}
int main() {
    n = read(), q = read(), A = read();
    for (int i = 1; i <= n; i++) yr[i] = read();
    for (int i = 2; i <= n; i++) {
        int u = read(), v = read(), w = read();
        add(u, v, w);
        add(v, u, w);
    }
    LCA::dep[1] = 1;
    LCA::dfs1(1);
    LCA::dfs2(1, 1);
    sig = n;
    size[0] = f[0] = 2147483233;
    findroot(1, 0);
    build(root);
    LL lastans = 0;
    while (q--) {
        int x = read(), a = read(), b = read();
        int l = min((a + lastans) % A, (b + lastans) % A);
        int r = max((a + lastans) % A, (b + lastans) % A);
        printf("%lld\n", lastans = query(x, l, r));
    }
}
点分治

2.主席树

以点权(排序后的排名)为版本开主席树。还是考虑距离公式:$dep_u$ 和 $\sum dep_v$ 都可以直接查(后者用前缀和),而 $\sum dep_{lca(u,v)}$ 不好直接查。

可以把每个 $v$ 到根的路径上的边都 $+1$(每个点记录它到父亲的边长)。询问时从 $u$ 走到根,把路径上每条边的「被覆盖次数 $\times$ 边长」加起来,就得到 $\sum_v dep_{lca(u,v)}$。树剖之后,路径修改和路径查询都只是 $O(\log n)$ 段区间操作。

主席树和上一种做法一样,按点权的前缀建版本:第 $i$ 个版本包含排名前 $i$ 的点。查询 $[L,R]$ 时,用第 $r$ 个版本减去第 $l-1$ 个版本。

#include <bits/stdc++.h>
#define LL long long
using namespace std;
/// Reads one (optionally signed) decimal integer from stdin.
/// Non-digit characters before the number are skipped, and every '-' among them
/// flips the sign. The character is kept in an `int` so that EOF stays distinct from
/// byte 0xFF and `isdigit` never receives a negative value (undefined behaviour).
/// If input ends, the value parsed so far (0 if none) is returned instead of looping forever.
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();
    for (; ch != EOF && !isdigit(ch); ch = getchar())
        if (ch == '-')
            f = -f;
    for (; ch != EOF && isdigit(ch); ch = getchar()) x = 10 * x + ch - '0';
    return x * f;
}
const int maxn = 200010;  // array capacity for vertex ids and ranks
int n;                    // number of vertices
int q;                    // number of queries
int A;                    // modulus used to decode the online queries
/// A vertex keyed by (point weight, id). Sorting ns[] by this key assigns every
/// vertex a rank; ties on weight are broken by id.
struct Node {
    int yr, id;
    bool operator<(const Node &b) const {
        if (yr != b.yr)
            return yr < b.yr;
        return id < b.id;
    }
} ns[maxn];
/// Node of the persistent segment tree over DFS positions; tags are never pushed down.
struct TrNode {
    int ls;   // left child (index into t[])
    int rs;   // right child (index into t[])
    int tms;  // number of inserted intervals that fully cover this node's range
    LL val;   // edge-weight sum added here by inserts that only partially cover the range
};
// NOTE(review): 256 * maxn nodes of (on typical ABIs) 24 bytes is roughly 1.2 GB of
// static storage. Usually only touched pages are committed, but a judge that caps
// address space may reject the binary — confirm the real node bound if that matters.
TrNode t[maxn << 8];
// Adjacency list kept as linked edge lists; edge ids start at 1, so 0 ends a list.
int first[maxn];     // first[u]: most recently added edge leaving u
int to[maxn << 1];   // head vertex of each edge
int nx[maxn << 1];   // next edge leaving the same vertex
int val[maxn << 1];  // edge weight
int cnt;             // number of edges added so far
/// Prepends the directed edge u -> v with weight w to u's edge list.
inline void add(int u, int v, int w) {
    const int e = ++cnt;
    to[e] = v;
    val[e] = w;
    nx[e] = first[u];
    first[u] = e;
}
int ToT, root[maxn];
LL sum[maxn], dis[maxn], dsum[maxn];
int fa[maxn], pos[maxn], dfn;
int dep[maxn], bl[maxn], size[maxn];
inline void dfs1(int x) {
    size[x] = 1;
    for (int i = first[x]; i; i = nx[i]) {
        if (to[i] == fa[x])
            continue;
        fa[to[i]] = x;
        dep[to[i]] = dep[x] + 1;
        dis[to[i]] = dis[x] + val[i];
        dfs1(to[i]);
        size[x] += size[to[i]];
    }
}
inline void dfs2(int x, int col) {
    int k = 0;
    bl[x] = col;
    pos[x] = ++dfn;
    sum[dfn] = dis[x] - dis[fa[x]];
    // cout<<sum[dfn] << endl;
    for (int i = first[x]; i; i = nx[i])
        if (dep[to[i]] > dep[x] && size[to[i]] > size[k])
            k = to[i];
    if (!k)
        return;
    dfs2(k, col);
    for (int i = first[x]; i; i = nx[i])
        if (dep[to[i]] > dep[x] && to[i] != k)
            dfs2(to[i], to[i]);
}
/// Allocates a complete empty segment tree over [l, r] and stores its root in x.
inline void build(int &x, int l, int r) {
    x = ++ToT;
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(t[x].ls, l, mid);
        build(t[x].rs, mid + 1, r);
    }
}
inline void Insert(int &x, int l, int r, int L, int R) {
    t[++ToT] = t[x];
    if (L <= l && r <= R) {
        t[x = ToT].tms++;
        return;
    }
    t[x = ToT].val += sum[min(R, r)] - sum[max(l - 1, L - 1)];
    int mid = (l + r) >> 1; /*
     if (R <= mid) Insert(t[x].ls,l,mid,L,R);
     else if (L > mid) Insert(t[x].rs,mid + 1,r,L,R);
     else Insert(t[x].ls,l,mid,L,mid),Insert(t[x].rs,mid + 1,r,mid + 1,R);*/
    if (L <= mid)
        Insert(t[x].ls, l, mid, L, R);
    if (R > mid)
        Insert(t[x].rs, mid + 1, r, L, R);
}
/// Sum over DFS positions p in [L, R] of (coverage of p) * (edge weight at p),
/// read from the subtree of node x that spans [l, r].
inline LL query(int x, int l, int r, int L, int R) {
    const int lo = max(l, L), hi = min(r, R);
    // Coverage recorded at this node applies to the whole overlap.
    LL res = (sum[hi] - sum[lo - 1]) * t[x].tms;
    if (L <= l && r <= R)
        return res + t[x].val;
    const int mid = (l + r) >> 1;
    if (L <= mid)
        res += query(t[x].ls, l, mid, L, R);
    if (mid < R)
        res += query(t[x].rs, mid + 1, r, L, R);
    return res;
}
/// Sums, in version v, the coverage-weighted edge lengths along u's path to the root.
/// It climbs u's heavy chains and queries one position range per chain.
inline LL ask(int u, int v) {
    LL res = 0;
    for (; bl[u] != 1; u = fa[bl[u]])
        res += query(root[v], 1, n, pos[bl[u]], pos[u]);
    return res + query(root[v], 1, n, 1, pos[u]);
}
/// Adds +1 coverage in version v to every edge on u's path to the root,
/// one heavy-chain segment at a time.
inline void add(int u, int v) {
    for (; bl[u] != 1; u = fa[bl[u]])
        Insert(root[v], 1, n, pos[bl[u]], pos[u]);
    Insert(root[v], 1, n, 1, pos[u]);
}
int main() {
    n = read(), q = read(), A = read();
    for (int i = 1; i <= n; i++) ns[i].yr = read(), ns[i].id = i;
    sort(ns + 1, ns + n + 1);
    for (int i = 2; i <= n; i++) {
        int u = read(), v = read(), w = read();
        add(u, v, w);
        add(v, u, w);
    }
    dfs1(1);
    dfs2(1, 1);
    for (int i = 1; i <= n; i++) sum[i] += sum[i - 1], dsum[i] = dsum[i - 1] + dis[ns[i].id];
    build(root[0], 1, n);
    for (int i = 1; i <= n; i++) {
        int u = ns[i].id;
        root[i] = root[i - 1];
        while (bl[u] != 1) {
            Insert(root[i], 1, n, pos[bl[u]], pos[u]);
            u = fa[bl[u]];
        }
        Insert(root[i], 1, n, 1, pos[u]);
    }
    LL lastans = 0;
    while (q--) {
        int u = read(), a = read(), b = read();
        int l = min((a + lastans) % A, (b + lastans) % A);
        int r = max((a + lastans) % A, (b + lastans) % A);
        l = lower_bound(ns + 1, ns + n + 1, (Node){ l, 0 }) - ns;
        r = upper_bound(ns + 1, ns + n + 1, (Node){ r, n }) - ns - 1;
        // cout<<l<<" "<<r<<endl;
        printf("%lld\n", lastans = 1LL * (r - l + 1) * dis[u] + dsum[r] - dsum[l - 1] -
                                   2 * (ask(u, r) - ask(u, l - 1)));
    }
}
主席树

猜你喜欢

转载自www.cnblogs.com/Kong-Ruo/p/10214762.html
今日推荐