luogu3703 [SDOI2017]树点涂色(线段树+树链剖分+动态树)

link

你谷的第一篇题解没用写LCT,然后没观察懂,但是自己YY了一种不用LCT的做法

我们考虑对于每个点,维护一个fa,代表以1为根时候这个点的父亲
再维护一个bel,由于一个颜色相同的段一定是一个深度递增的链,这个代表颜色段的链顶
再维护一个ans,就是ans
那么第二个操作就是ans单点查询,第三个就是ans区间最大值
第一个操作,我们树剖把它变成了log条重链,在重链上的点的bel和ans可以区间修改
但是重链上每个点的儿子就说gg了
这些儿子的子树减去的是他们的各自bel的fa的点权,然后在加上一个1
所以说每个点还要维护一个preferred child,显然很难写(但是应该是可以实现的),还是写动态树吧

动态树其实是动态树的阉割版,我们用一个splay维护树中颜色相同的点,显然这种点一定是一段深度递增的链,于是你会发现只需要钦定1为根,操作1就是access

然后我们可以维护每个点操作3的ans,再搞个dfn,操作3就可以用线段树维护,然而每个点的答案就相当于动态树上这个点到根节点上虚边的个数+1,我们可以搞个线段树维护,在动态树access虚实变化时候搞个区间加法即可

然后你会发现对于操作2的x,y答案就是ans[x] + ans[y] - 2 * ans[lc] + 1,所以说我们还需要支持一个lca操作,写个树剖或者倍增就行了,我为了凑齐全家桶就写了个树剖

注意树剖的fa和lct的fa一开始是一个fa,lct的fa是动态维护的,树剖的fa是静态的,要开两个数组并且别弄混了,调了好长时间...

access过程中由于变换的是深度恰好为depth[x]+1的点,所以就要对x寻找后继,就是x的右儿子不停往左跳。。。

时间复杂度Nlog^2N

#include <cstdio>
#include <vector>
using namespace std;

//区间加法,区间取max
int fa[100010], dfn[100010], depth[100010], weight[100010], wson[100010], top[100010];
int tree[400010], lazy[400010], fuck[100010], ch[100010][2], fat[100010], tot, n, m;
vector<int> out[100010];

void init(int x, int cl, int cr)
{
    if (cl == cr)
    {
        tree[x] = fuck[cl];
    }
    else
    {
        int mid = (cl + cr) / 2;
        init(x * 2, cl, mid), init(x * 2 + 1, mid + 1, cr);
        tree[x] = max(tree[x * 2], tree[x * 2 + 1]);
    }
}

void pushdown(int x)
{
    tree[x * 2] += lazy[x], tree[x * 2 + 1] += lazy[x];
    lazy[x * 2] += lazy[x], lazy[x * 2 + 1] += lazy[x];
    lazy[x] = 0;
}

void chenge(int x, int cl, int cr, int L, int R, int val)
{
    if (cr < L || R < cl) return;
    if (L <= cl && cr <= R) { tree[x] += val, lazy[x] += val; return; }
    pushdown(x);
    int mid = (cl + cr) / 2;
    chenge(x * 2, cl, mid, L, R, val);
    chenge(x * 2 + 1, mid + 1, cr, L, R, val);
    tree[x] = max(tree[x * 2], tree[x * 2 + 1]);
}

int query(int x, int cl, int cr, int L, int R)
{
    if (cr < L || R < cl) return 0;
    if (L <= cl && cr <= R) return tree[x];
    pushdown(x);
    int mid = (cl + cr) / 2;
    return max(query(x * 2, cl, mid, L, R), query(x * 2 + 1, mid + 1, cr, L, R));
}

void dfs1(int x)
{
    weight[x] = 1, wson[x] = -1;
    for (int i : out[x]) if (fa[x] != i)
    {
        fa[i] = fat[i] = x, depth[i] = depth[x] + 1;
        dfs1(i), weight[x] += weight[i];
        if (wson[x] == -1 || weight[wson[x]] < weight[i]) wson[x] = i;
    }
}

void dfs2(int x, int topf)
{
    dfn[x] = ++tot, top[x] = topf, fuck[dfn[x]] = depth[x] + 1;
    if (wson[x] == -1) return;
    dfs2(wson[x], topf);
    for (int i : out[x]) if (fa[x] != i && wson[x] != i) dfs2(i, i);
}

int lca(int x, int y)
{
    while (top[x] != top[y])
    {
        if (depth[top[x]] < depth[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    if (depth[x] > depth[y]) swap(x, y);
    return x;
}

bool nroot(int x) { return ch[fat[x]][0] == x || ch[fat[x]][1] == x; }

void rotate(int x)
{
    int y = fat[x], z = fat[y], k = ch[y][1] == x, w = ch[x][k ^ 1];
    if (nroot(y)) { ch[z][ch[z][1] == y] = x; } ch[x][k ^ 1] = y, ch[y][k] = w;
    if (w) { fat[w] = y; } fat[y] = x; fat[x] = z;
}

void splay(int x)
{
    while (nroot(x))
    {
        int y = fat[x], z = fat[y];
        if (nroot(y)) rotate((ch[y][1] == x) ^ (ch[z][1] == y) ? x : y);
        rotate(x);
    }
}

int findrt(int x)
{
    while (ch[x][0]) x = ch[x][0];
    return x;
}

void access(int x)
{
    for (int y = 0; x > 0; x = fat[y = x])
    {
        splay(x);
        if (ch[x][1])
        {
            int p = findrt(ch[x][1]);
            chenge(1, 1, n, dfn[p], dfn[p] + weight[p] - 1, 1);
        }
        ch[x][1] = y;
        if (y)
        {
            int p = findrt(y);
            chenge(1, 1, n, dfn[p], dfn[p] + weight[p] - 1, -1);
        }
    }
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int x, y, i = 1; i < n; i++) scanf("%d%d", &x, &y), out[x].push_back(y), out[y].push_back(x);
    dfs1(1), dfs2(1, 1); init(1, 1, n);
    for (int opd, x, y, i = 1; i <= m; i++)
    {
        scanf("%d%d", &opd, &x);
        if (opd == 1) access(x);
        if (opd == 2)
        {
            scanf("%d", &y);
            int lc = lca(x, y);
            printf("%d\n", query(1, 1, n, dfn[x], dfn[x]) + query(1, 1, n, dfn[y], dfn[y]) - 2 * query(1, 1, n, dfn[lc], dfn[lc]) + 1);
        }
        if (opd == 3) printf("%d\n", query(1, 1, n, dfn[x], dfn[x] + weight[x] - 1));
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/oier/p/10466519.html