某模拟赛C题 树上路径统计 (点分治)

题意

给定一棵有n个节点的无根树,树上的每个点有一个非负整数点权。定义一条路径的价值为路径上的点权和-路径上的点权最大值。 给定参数P,我!=们想知道,有多少不同的树上简单路径,满足它的价值恰好是P的倍数。 注意:单点算作一条路径;u!=v时,(u,v)和(v,u)只算一次。

题解

树上路径统计,解法是点分治。点分的时候求出根到每个点路径最大值和权值和。排一序,然后开个桶,就能计算了。去重就套路的减去没棵子树里面的答案。

CODE

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
typedef long long LL;
LL ans;
int n, mod, fir[MAXN], nxt[MAXN<<1], to[MAXN<<1], cnt, val[MAXN];
inline void link(int x, int y) {
    to[++cnt] = y; nxt[cnt] = fir[x]; fir[x] = cnt;
    to[++cnt] = x; nxt[cnt] = fir[y]; fir[y] = cnt;
}
bool ban[MAXN];
int getsz(int u, int ff) {
    int re = 1;
    for(int v, i = fir[u]; i; i = nxt[i])
        if((v=to[i]) != ff && !ban[v])
            re += getsz(v, u);
    return re;
}
int getrt(int u, int ff, int &rt, int Size) {
    int re = 1; bool can = 1;
    for(int v, tmp, i = fir[u]; i; i = nxt[i])
        if((v=to[i]) != ff && !ban[v]) {
            re += (tmp = getrt(v, u, rt, Size));
            if((tmp<<1) > Size) can = 0;
        }
    if(((Size-re)<<1) > Size) can = 0;
    if(can) rt = u;
    return re;
}
struct node {
    int mx, v;
    inline bool operator <(const node &o)const {
        return mx < o.mx;
    }
}seq[MAXN], vv[MAXN];
int tot;
void dfs(int u, int ff, int mx, int vs) {
    vs = (vs + val[u]) % mod;
    mx = max(mx, val[u]);
    vv[u] = (node){ mx, vs };
    for(int v, i = fir[u]; i; i = nxt[i])
        if((v=to[i]) != ff && !ban[v])
            dfs(v, u, mx, vs);
}
void push(int u, int ff) {
    seq[++tot] = vv[u];
    for(int v, i = fir[u]; i; i = nxt[i])
        if((v=to[i]) != ff && !ban[v])
            push(v, u);
}
int bin[10000005];
LL calc(int rt, int o) {
    tot = 0; push(rt, 0);
    sort(seq + 1, seq + tot + 1);
    LL re = 0;
    for(int i = 1; i <= tot; ++i) {
        re += bin[((seq[i].mx+o-seq[i].v)%mod+mod)%mod];
        ++bin[seq[i].v%mod];
    }
    for(int i = 1; i <= tot; ++i) --bin[seq[i].v%mod];
    return re;
}
void solve(int x) {
    dfs(x, 0, 0, 0);
    ans += calc(x, val[x]);
    ban[x] = 1;
    for(int v, i = fir[x]; i; i = nxt[i])
        if(!ban[v=to[i]]) ans -= calc(v, val[x]);
}
void TDC(int x) {
    int Size = getsz(x, 0);
    getrt(x, 0, x, Size);
    solve(x);
    for(int v, i = fir[x]; i; i = nxt[i])
        if(!ban[v=to[i]]) TDC(v);
}
int main () {
    scanf("%d%d", &n, &mod);
    for(int i = 1, x, y; i < n; ++i)
        scanf("%d%d", &x, &y), link(x, y);
    for(int i = 1; i <= n; ++i) scanf("%d", &val[i]);
    TDC(1);
    printf("%lld\n", ans+n);
}

猜你喜欢

转载自www.cnblogs.com/Orz-IE/p/12076356.html