[ZJOI2015]幻想乡战略游戏 动态点分治

恶补了动态点分治。。

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include<bits/stdc++.h>
// Shorthand type names.
#define LL long long
#define LD long double
#define ull unsigned long long
// std::pair accessors and constructor.
#define fi first
#define se second
#define mk make_pair
// Frequently used pair types.
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
// Container helpers. The argument is parenthesised so expressions such as
// SZ(*p) or SZ(a.b) expand correctly.
#define SZ(x) ((int)(x).size())
#define ALL(x) (x).begin(), (x).end()
// Fast iostream setup. NOTE: expands to two statements, so it is not safe as the
// body of an unbraced if/else. This file does its I/O with scanf/printf.
#define fio ios::sync_with_stdio(false); cin.tie(0);

using namespace std;

// Global limits and numeric constants.
constexpr int N = 200007;                  // capacity of per-vertex arrays (2e5 + 7)
constexpr int inf = 0x3f3f3f3f;            // large int sentinel (used as mx[0] in main)
constexpr LL INF = 0x3f3f3f3f3f3f3f3fLL;   // large long long sentinel
constexpr int mod = 1000000007;            // 1e9 + 7, modulus for add()/sub()
constexpr double eps = 1e-8;
const double PI = acos(-1.0);

// Modular accumulate helpers. Each applies a single conditional correction, so the
// result lands in [0, mod) only when both operands were already in that range.
template<class T, class S> inline void add(T& acc, S delta) {
    acc += delta;
    if(acc >= mod) acc -= mod;
}
template<class T, class S> inline void sub(T& acc, S delta) {
    acc -= delta;
    if(acc < 0) acc += mod;
}
// Replace `cur` with `cand` when cand is strictly larger (chkmax) / strictly
// smaller (chkmin). Returns whether an assignment happened.
template<class T, class S> inline bool chkmax(T& cur, S cand) {
    if(!(cur < cand)) return false;
    cur = cand;
    return true;
}
template<class T, class S> inline bool chkmin(T& cur, S cand) {
    if(!(cur > cand)) return false;
    cur = cand;
    return true;
}

//mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Sparse-table height: rmq_cnt must stay below 1 << LOG.
constexpr int LOG = 20;

// Problem input.
int n, q;
int root;                 // root of the centroid tree (set in main)
vector<PII> G[N];         // adjacency lists of (edge weight, neighbour); divide() later overwrites .fi

// Euler tour and sparse table used for LCA.
int depth[N];             // weighted distance from vertex 1 (filled by dfs)
int dfn[N];               // first position of each vertex in the Euler tour
int rmq_cnt;              // current length of the Euler tour
int Log[N << 1];          // Log[i] = floor(log2(i)), filled by calcRmq
PII rmq[N << 1][LOG];     // rmq[i][j] = min (depth, vertex) over tour positions [i, i + 2^j)

// Euler tour from u (entered from `par`). Records each vertex's first tour index
// in dfn[] and its weighted depth in depth[]. Appends (depth, vertex) to level 0
// of the sparse table on entry and again after returning from every child.
void dfs(int u, int par) {
    rmq_cnt += 1;
    dfn[u] = rmq_cnt;
    rmq[rmq_cnt][0] = mk(depth[u], u);
    for(const auto &edge : G[u]) {
        const int child = edge.se;
        if(child != par) {
            depth[child] = depth[u] + edge.fi;   // .fi still holds the edge weight here
            dfs(child, u);
            rmq_cnt += 1;
            rmq[rmq_cnt][0] = mk(depth[u], u);
        }
    }
}

// Fills Log[] (floor of log2; Log[1] relies on zero-initialisation) and builds the
// min sparse table over Euler-tour positions 1..rmq_cnt.
void calcRmq() {
    for(int i = 2; i <= rmq_cnt; ++i) Log[i] = Log[i / 2] + 1;
    for(int lvl = 1; lvl <= Log[rmq_cnt]; ++lvl) {
        const int half = 1 << (lvl - 1);
        for(int i = 1; i + 2 * half - 1 <= rmq_cnt; ++i) {
            rmq[i][lvl] = min(rmq[i][lvl - 1], rmq[i + half][lvl - 1]);
        }
    }
}

int getLca(int u, int v) {
    if(dfn[u] > dfn[v]) swap(u, v);
    int k = Log[dfn[v] - dfn[u] + 1];
    PII ret = min(rmq[dfn[u]][k], rmq[dfn[v] - (1 << k) + 1][k]);
    return ret.se;
}

// Weighted tree distance between u and v using depth[] and the LCA.
// Each side's depth difference is taken first so that no intermediate value
// exceeds the final distance. The former depth[u] + depth[v] could overflow a
// signed int (undefined behaviour) even when the distance itself fits.
int getDis(int u, int v) {
    const int lca = getLca(u, v);
    return (depth[u] - depth[lca]) + (depth[v] - depth[lca]);
}


// Centroid-decomposition state.
int sz[N];      // size of each vertex's subtree within the current component (getSize)
int mx[N];      // largest piece left if the vertex were removed (findCenter)
int fa[N];      // parent centroid in the centroid tree; 0 for the root
int center;     // best centroid candidate found by the running findCenter pass
int now_tot;    // size of the component currently being split
bool ban[N];    // true once a vertex has been taken as a centroid

// Computes sz[] for the subtree rooted at u (entered from `par`), counting only
// vertices that have not yet been banned as centroids.
void getSize(int u, int par) {
    int total = 1;
    for(const auto &edge : G[u]) {
        const int nxt = edge.se;
        if(nxt != par && !ban[nxt]) {
            getSize(nxt, u);
            total += sz[nxt];
        }
    }
    sz[u] = total;
}

void findCenter(int u, int fa) {
    mx[u] = 0;
    for(auto &e : G[u]) {
        int v = e.se;
        if(v == fa || ban[v]) continue;
        findCenter(v, u);
        chkmax(mx[u], sz[v]);
    }
    chkmax(mx[u], now_tot - sz[u]);
    if(mx[center] > mx[u]) {
        center = u;
    }
}

LL val[N];   // accumulated weight per vertex; read only by the brute-force solve()
LL sum[N];   // total weight added to vertices in each centroid's component
LL d1[N];    // sum of weight * dist(centroid, x) over that component
LL d2[N];    // same weights, measured to fa[centroid] instead

// Centroid decomposition below centroid u. For each neighbour v not yet banned:
// - finds the centroid of v's component;
// - records u as that centroid's parent in fa[];
// - recurses.
//
// NOTE: the matching entry of G[u] is relabelled in place (e.fi = child centroid),
// so the edge weight in that slot is lost. dfs() must therefore have run first.
// Entries pointing at already-banned neighbours (centroid ancestors) keep their
// weight. go() reads the relabelled slots to jump to a neighbour's centroid.
void divide(int u) {
    ban[u] = true;
    for(auto &e : G[u]) {
        const int v = e.se;
        if(ban[v]) continue;

        getSize(v, 0);
        center = 0; now_tot = sz[v];
        findCenter(v, 0);
        e.fi = center;          // weight -> centroid of v's component
        fa[center] = u;

        divide(center);
    }
}

// Adds weight w at vertex x. For every centroid ancestor cur of x (x included) it
// updates:
// - sum[cur] (total weight);
// - d1[cur] (weight * distance to cur);
// - d2[cur] (weight * distance to fa[cur]), when cur has a parent centroid.
void modify(int x, int w) {
    int cur = x;
    while(cur != 0) {
        const int up = fa[cur];
        sum[cur] += w;
        d1[cur] += 1LL * w * getDis(cur, x);
        if(up != 0) d2[cur] += 1LL * w * getDis(up, x);
        cur = up;
    }
}

// Weighted distance sum from x to all added weight, built by climbing x's
// centroid ancestors. At each step from `child` to its parent centroid `par`:
// - add par's d1;
// - subtract the part already counted through child (d2[child]);
// - charge dist(x, par) for the weight outside child's component.
inline LL query(int x) {
    LL total = d1[x];
    for(int child = x, par = fa[x]; par != 0; child = par, par = fa[par]) {
        total += d1[par];
        total -= d2[child];
        total += (sum[par] - sum[child]) * getDis(x, par);
    }
    return total;
}

LL ans;

void go(int u) {
    ans = query(u);
    for(auto &e : G[u]) {
        int v = e.se, center = e.fi;
        if(query(v) < query(u)) {
            go(center);
            break;
        }
    }
}

// Brute-force reference: sum over all vertices i of dist(i, u) * val[i].
// Not called anywhere in this file; kept for cross-checking query().
LL solve(int u) {
    LL total = 0;
    int i = 1;
    while(i <= n) {
        total += 1LL * getDis(i, u) * val[i];
        ++i;
    }
    return total;
}

// Input, as read below: "n q", then n-1 edges "u v w", then q updates "u e".
// After each update it prints ans from go(root). Returns 1 if the input ends
// early or is malformed, instead of continuing with uninitialised values.
int main() {
    if(scanf("%d%d", &n, &q) != 2) return 1;
    for(int i = 1; i < n; i++) {
        int u, v, w;
        if(scanf("%d%d%d", &u, &v, &w) != 3) return 1;
        G[u].push_back(mk(w, v));   // stored as (weight, neighbour)
        G[v].push_back(mk(w, u));
    }

    // Euler tour + sparse table: O(1) LCA and distance queries afterwards.
    dfs(1, 0);
    calcRmq();

    // Sentinel: vertex 0 starts as "center" and must lose every comparison.
    mx[0] = inf;
    getSize(1, 0);
    center = 0; now_tot = sz[1];
    findCenter(1, 0);
    root = center;

    divide(center);

    while(q--) {
        int u, e;
        if(scanf("%d%d", &u, &e) != 2) return 1;
        modify(u, e);
        val[u] += e;   // only consumed by the brute-force solve()
        go(root);
        printf("%lld\n", ans);
    }
    return 0;
}

/*
*/

猜你喜欢

转载自：https://www.cnblogs.com/CJLHY/p/11536585.html