【洛谷3345_BZOJ3924】[ZJOI2015]幻想乡战略游戏(点分树)

大概有整整一个月没更博客了 ……

4 月为省选爆肝了一个月,最后压线进 B 队,也算给 NOIP2018 翻车到 316 分压线省一这个折磨了五个月的 debuff 画上了一个不算太差的句号。结果省选后技能点全点到红警上了,OI 迅速变菜,GG 。

题目:

洛谷 3345

分析:

为什么我觉得这题网上大部分题解都讲的很麻烦,看了一上午还没看懂,有一种被拐到沟里的感觉 …… 我这个思路自认为比较好理解。

先考虑一个比较弱的问题:把原问题的「寻找最优补给站」改为「每次询问钦定一个补给站,求此时的花费」。

这很明显是个点分树裸题,建议先充分理解 BZOJ3730 震波(此处应有本人博客链接,无限咕咕中)。非常类似于震波的做法,对每个结点 \(u\) 维护 \(sum_u\)\(sumd_u\)\(sumf_u\)\(sumdf_u\) ,分别表示点分树上结点 \(u\) 「管辖范围」(可以理解为点分治时 get_path 函数遍历的那棵子树)中的点权之和、「点权乘深度(到根节点的边权之和,下同)」之和、对点分树上父节点贡献的点权之和、对点分树上父节点贡献的「点权乘深度」之和。

P.S. 写博客的时候突然发现 \(sum_u\)\(sumf_u\) 其实是同一个东西,但我还是要为了「对称美」以及套震波的板子把它们分开了(滑稽)。

代码如下,如果充分理解了 BZOJ3730 的做法应该很好懂。代码中用 \(sum_{u+n}\)\(sumd_{u+n}\) 表示 \(sumf_u\)\(sumdf_u\) 。用 ST 表求 LCA ,单次修改和查询都是 \(O\left(\log n\right)\) 的。

void modify(const int u, const int x)
{
    using LCA::get_dis;
    int tmp = u;
    wtot += x;
    d[u] += x;
    sum[u] += x;
    while (fa[tmp])
    {
        int d = get_dis(u, fa[tmp]);
        sum[fa[tmp]] += x;
        sum[tmp + n] += x;
        sumd[fa[tmp]] += (ll)x * d;
        sumd[tmp + n] += (ll)x * d;
        tmp = fa[tmp];
    }
}
ll query(const int u)
{
    using LCA::get_dis;
    int tmp = u;
    ll ans = sumd[u];
    while (fa[tmp])
    {
        int d = get_dis(u, fa[tmp]);
        ans += sumd[fa[tmp]] - sumd[tmp + n] + (ll)d * (sum[fa[tmp]] - sum[tmp + n]);
        tmp = fa[tmp];
    }
    return ans;
}

暂时忘掉点分树(现在它的作用只是在 \(O(\log n)\) 的时间内求把补给站设在某个点的答案),只想原树,考虑一种贪心的做法:先随便站在一个点上,称为 \(u\)\(v\) 是一个与 \(u\) 的一个直接相连的结点(以下称为「儿子」)。如果 \(v\) 的答案比 \(u\) 更优,则从 \(u\) 走到 \(v\) 。如此往复,最终到一个无法再走的点,那么这个点就是最优点。口胡的不严谨证明如下:

\(sum_v\) 表示以 \(u\) 为根时 \(v\) 子树中点的点权之和, \(tot\) 表示全部 \(n\) 个点的点权之和,\(w\) 是边 \((u,v)\) 的权。那么,从 \(u\) 走到 \(v\)\(v\) 子树中的所有点 \(p\) 的贡献减少 \(w\cdot d_p\)\(v\) 子树外的所有点 \(q\) 的贡献增加 \(w\cdot d_q\) ,那么答案变化量 \(\Delta=w\cdot (tot-sum_v)-w\cdot sum_v=w\cdot (tot-2sum_v)\) 。也就是说,只有当 \(2sum_v>tot\) ,答案才会减少,即 \(v\) 的答案比 \(u\) 更优。很明显,\(u\) 最多只能有一个儿子 \(v\) 满足 \(2sum_v>tot\) ,所以如果能移动,一定只有唯一的一种移动方案。并且,如果 \(2sum_v>tot\) ,则 \(2(tot-sum_v)\) (即以 \(v\) 为根时的 \(sum_u\) 的两倍)一定不大于 \(tot\) ,所以不可能往回移动。综上,这样一定能找到最优解。

如果专门造数据卡,上述贪心每次最多能走 \(n-1\) 步(考虑一条长链,中间全是 \(0\) ,两端轮流在 \(1\)\(0\) 之间切换,最优解轮流出现在两个端点上),单次修改最坏 \(O(n\log n)\) ,会 TLE 。然而,这个贪心给我们一个重要的启示:对于相邻两点 \(u\)\(v\) ,如果 \(u\)\(v\) 优,那么答案 一定在 \(v\) 的子树中 (以 \(u\) 为树根)。换句话说,就是如果断掉边 \((u,v)\) ,则 答案一定在 \(v\) 所在的连通块中 。也就是说,我即使现在不走到 \(v\) ,只要走到(更形象地说,「跳到」) \(v\) 的子树中任意一点,也都能保证最终找到最优解。

那么我们每次不是走到 \(v\) ,而是走到 \(v\) 这棵子树的「重心」。从点分树的角度来说,记 \(near_v\) 表示从 \(v\) 的点分树父亲 \(u\)\(v\) 的原树路径上除了 \(u\) 以外的第一个点。一开始站在根上,从 \(u\) 走到 \(v\) 的条件是 \(near_v\)\(u\) 更优,最终无路可走了就是答案。由于点分树深度是 \(O(\log n)\) ,每次最多走深度步,每走一步要 \(O(\log n)\) 查询若干个点的答案,所以单次查询时间复杂度为 \(O(\log^2 n)\) (由于要遍历所有儿子,所以要乘上最大度数 \(20\) 的常数)。

代码:

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
#include <vector>
using namespace std;

namespace zyt
{
    template<typename T>
    inline bool read(T &x)
    {
        char c;
        bool f = false;
        x = 0;
        do
            c = getchar();
        while (c != EOF && c != '-' && !isdigit(c));
        if (c == EOF)
            return false;
        if (c == '-')
            f = true, c = getchar();
        do
            x = x * 10 + c - '0', c = getchar();
        while (isdigit(c));
        if (f)
            x = -x;
        return true;
    }
    template<typename T>
    inline void write(T x)
    {
        static char buf[20];
        char *pos = buf;
        if (x < 0)
            putchar(' '), x = -x;
        do
            *pos++ = x % 10 + '0';
        while (x /= 10);
        while (pos > buf)
            putchar(*--pos);
    }
    typedef long long ll;
    const int N = 1e5 + 10, B = 20, INF = 0x3f3f3f3f;
    struct edge
    {
        int to, w, next;
    }e[N << 1];
    int n, head[N], ecnt, d[N];
    void add(const int a, const int b, const int c)
    {
        e[ecnt] = (edge){b, c, head[a]}, head[a] = ecnt++;
    }
    namespace LCA
    {
        int dis[N], dfn[N], euler[N << 1], dfncnt;
        namespace ST
        {
            int lg2[N << 1], st[B][N << 1];
            const int *w;
            int min(const int a, const int b)
            {
                return w[a] < w[b] ? a : b;
            }
            void build(const int *_w, const int n)
            {
                w = _w;
                int tmp = 0;
                for (int i = 1; i <= n; i++)
                {
                    lg2[i] = tmp;
                    if (i == (1 << (tmp + 1)))
                        ++tmp;
                }
                for (int i = 1; i <= n; i++)
                    st[0][i] = i;
                for (int i = 1; i < B; i++)
                    for (int j = 1; j + (1 << i) - 1 <= n; j++)
                        st[i][j] = min(st[i - 1][j], st[i - 1][j + (1 << (i - 1))]);
            }
            int query(const int l, const int r)
            {
                int len = lg2[r - l + 1];
                return min(st[len][l], st[len][r - (1 << len) + 1]);
            }
        }
        void dfs(const int u, const int f)
        {
            dfn[u] = ++dfncnt;
            euler[dfncnt] = u;
            for (int i = head[u]; ~i; i = e[i].next)
            {
                int v = e[i].to;
                if (v == f)
                    continue;
                dis[v] = dis[u] + e[i].w;
                dfs(v, u);
                euler[++dfncnt] = u;
            }
        }
        void init()
        {
            dfncnt = 0;
            dfs(1, 0);
            ST::build(euler, dfncnt);
        }
        int lca(const int a, const int b)
        {
            return euler[ST::query(min(dfn[a], dfn[b]), max(dfn[a], dfn[b]))];
        }
        int get_dis(const int a, const int b)
        {
            return dis[a] + dis[b] - (dis[lca(a, b)] << 1);
        }
    }
    namespace Point_Divide_Tree
    {
        int f[N], near[N], rot, size[N], tot, fa[N];
        ll sum[N << 1], sumd[N << 1], wtot;
        bool vis[N];
        vector<int> g[N];
        void find_rot(const int u, const int fa)
        {
            size[u] = 1, f[u] = 0;
            for (int i = head[u]; ~i; i = e[i].next)
            {
                int v = e[i].to;
                if (vis[v] || v == fa)
                    continue;
                find_rot(v, u);
                size[u] += size[v];
                f[u] = max(f[u], size[v]);
            }
            f[u] = max(f[u], tot - size[u]);
            if (f[u] < f[rot])
                rot = u;
        }
        int get_size(const int u, const int f)
        {
            int ans = 1;
            for (int i = head[u]; ~i; i = e[i].next)
            {
                int v = e[i].to;
                if (v == f || vis[v])
                    continue;
                ans += get_size(v, u);
            }
            return ans;
        }
        void solve(const int u)
        {
            vis[u] = true;
            for (int i = head[u]; ~i; i = e[i].next)
            {
                int v = e[i].to;
                if (vis[v])
                    continue;
                tot = get_size(v, u);
                f[0] = INF, rot = 0;
                find_rot(v, u);
                fa[rot] = u, g[u].push_back(rot);
                near[rot] = v;
                solve(rot);
            }
        }
        void modify(const int u, const int x)
        {
            using LCA::get_dis;
            int tmp = u;
            wtot += x;
            d[u] += x;
            sum[u] += x;
            while (fa[tmp])
            {
                int d = get_dis(u, fa[tmp]);
                sum[fa[tmp]] += x;
                sum[tmp + n] += x;
                sumd[fa[tmp]] += (ll)x * d;
                sumd[tmp + n] += (ll)x * d;
                tmp = fa[tmp];
            }
        }
        ll query(const int u)
        {
            using LCA::get_dis;
            int tmp = u;
            ll ans = sumd[u];
            while (fa[tmp])
            {
                int d = get_dis(u, fa[tmp]);
                ans += sumd[fa[tmp]] - sumd[tmp + n] + (ll)d * (sum[fa[tmp]] - sum[tmp + n]);
                tmp = fa[tmp];
            }
            return ans;
        }
        ll find(const int u)
        {
            ll now = query(u);
            for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it++)
            {
                int v = *it;
                if (query(near[v]) < now)
                    return find(v);
            }
            return now;
        }
    }
    int work()
    {
        using namespace Point_Divide_Tree;
        int q;
        read(n), read(q);
        memset(head, -1, sizeof(int[n + 1]));
        for (int i = 1; i < n; i++)
        {
            int a, b, c;
            read(a), read(b), read(c);
            add(a, b, c), add(b, a, c);
        }
        LCA::init();
        tot = n;
        f[rot = 0] = INF;
        find_rot(1, 0);
        int root = rot;
        solve(rot);
        while (q--)
        {
            int u, e;
            read(u), read(e);
            modify(u, e);
            write(find(root)), putchar('\n');
        }
        return 0;
    }
}
int main()
{
    return zyt::work();
}

猜你喜欢

转载自www.cnblogs.com/zyt1253679098/p/10835943.html