luogu2486 [SDOI2011]染色

link

题目大意:给定一个N个点的树,每个点有一个颜色
有 M 次操作,每次操作为以下两种之一:将树上某条路径上的所有点染成同一种颜色;或查询某条路径上的颜色段数(路径上相邻且同色的点算作同一段)。

解法:树链剖分,再用线段树维护可以合并的区间信息。

我的代码对每个区间记录三个量:左端点的颜色、右端点的颜色,以及除去左端点所在段和右端点所在段之后剩下的颜色段数(因此总段数为该值加 2;区间只有一种颜色时该值记为 -1,空区间记为 -2)。

合并时需要特殊处理空区间和单色区间;另外,路径查询时两侧链上得到的信息方向不同,拼接前要先把一侧翻转。详见代码。

#include <cstdio>
#include <vector>
using namespace std;

// Summary of a contiguous run of positions, used as the segment-tree value.
//   l   : color of the leftmost element
//   r   : color of the rightmost element
//   cnt : number of color segments strictly between the leftmost segment and
//         the rightmost segment (so total segments = cnt + 2), or a sentinel:
//           -1  the whole run has a single color (l == r),
//           -2  the run is empty.
// Members default to 0 so a default-constructed value is never indeterminate
// (this matches the zero-initialization global arrays of this type already get).
struct fuck
{
    int l = 0, r = 0, cnt = 0;
    fuck() {}
    // A single-colored run of color `col`.
    fuck(int col) : l(col), r(col), cnt(-1) {}
    fuck(int l, int r, int cnt) : l(l), r(r), cnt(cnt) {}
};

// Concatenates two run summaries: `l` immediately followed by `r`.
// If the color at the seam is the same on both sides, the two boundary
// segments fuse into one, so the seam adds no extra segment.
fuck operator*(const fuck &l, const fuck &r)
{
    // The empty summary (cnt == -2) acts as the identity element.
    if (l.cnt == -2) return r;
    if (r.cnt == -2) return l;

    const bool leftSolid = l.cnt == -1;
    const bool rightSolid = r.cnt == -1;

    if (leftSolid)
    {
        const bool fused = l.l == r.l;
        if (rightSolid)
            return fused ? fuck(l.l, l.l, -1) : fuck(l.l, r.l, 0);
        // A single-colored left part either extends r's first segment (result
        // is r itself) or becomes the new first segment, which pushes r's old
        // first segment into the interior.
        return fused ? r : fuck(l.l, r.r, r.cnt + 1);
    }

    const bool fused = l.r == r.l;
    if (rightSolid)
        return fused ? l : fuck(l.l, r.l, l.cnt + 1);
    // Both parts have separate first and last segments: the interiors add up,
    // l's last and r's first segments become interior too (+2), minus one if
    // those two fuse.
    return fuck(l.l, r.r, l.cnt + r.cnt + 2 - (fused ? 1 : 0));
}

// Adjacency lists of the tree (vertices are numbered 1..n); arrays are sized for n up to about 1e5.
vector<int> out[100010];
// n = number of vertices, m = number of operations, col[v] = initial color of vertex v.
int n, m, col[100010];
// Filled by dfs1: parent, depth, subtree size, and heavy child (-1 if none) of each vertex.
int fa[100010], depth[100010], weight[100010], wson[100010];
// Filled by dfs2: dfn[v] = segment-tree position of v, top[v] = head of v's heavy chain,
// id[pos] = vertex stored at position pos, tot = number of positions assigned so far.
int dfn[100010], top[100010], id[100010], tot;
// Segment tree over positions 1..n (root 1, children 2x and 2x+1, ~4x the vertex bound)
// and its pending-assignment tags (nonzero = an assignment not yet pushed to the children).
fuck tree[400010]; int lazy[400010];

// Returns the same run read in the opposite direction: the two end colors
// trade places while the interior count (and the -1 / -2 sentinels) stay as-is.
fuck rev(const fuck &x)
{
    const int newLeft = x.r;
    const int newRight = x.l;
    return fuck(newLeft, newRight, x.cnt);
}

// First heavy-light pass (recursive). For every vertex in x's subtree it records
// the parent (fa), depth, subtree size (weight) and heavy child (wson: the child
// with the largest subtree, the earliest one winning ties; -1 for a leaf).
// fa[x] and depth[x] must already be set by the caller (fa is 0 for the root,
// which is not a valid vertex id).
void dfs1(int x)
{
    weight[x] = 1;
    wson[x] = -1;
    for (const int child : out[x])
    {
        if (child == fa[x]) continue;  // skip the edge back to the parent
        fa[child] = x;
        depth[child] = depth[x] + 1;
        dfs1(child);
        weight[x] += weight[child];
        const bool heavier = wson[x] == -1 || weight[child] > weight[wson[x]];
        if (heavier) wson[x] = child;
    }
}

// Second heavy-light pass. Assigns consecutive positions in DFS order, always
// visiting the heavy child first so every heavy chain occupies a contiguous
// range of positions. topf is the head of the chain that x belongs to.
void dfs2(int x, int topf)
{
    const int pos = ++tot;
    dfn[x] = pos;
    id[pos] = x;
    top[x] = topf;
    if (wson[x] == -1) return;  // leaf: nothing below x
    dfs2(wson[x], topf);        // the heavy child continues x's chain
    for (const int child : out[x])
    {
        const bool light = child != fa[x] && child != wson[x];
        if (light) dfs2(child, child);  // each light child heads a new chain
    }
}

// Builds segment-tree node x covering positions [cl, cr] from the initial
// colors; position p holds vertex id[p].
void build(int x, int cl, int cr)
{
    if (cl != cr)
    {
        const int mid = (cl + cr) / 2;
        const int lc = 2 * x;
        const int rc = 2 * x + 1;
        build(lc, cl, mid);
        build(rc, mid + 1, cr);
        tree[x] = tree[lc] * tree[rc];
        return;
    }
    tree[x] = fuck(col[id[cl]]);
}

void pushdown(int x)
{
    if (lazy[x])
    {
        tree[x * 2] = tree[x * 2 + 1] = fuck(lazy[x]);
        lazy[x * 2] = lazy[x * 2 + 1] = lazy[x];
        lazy[x] = 0;
    }
}

void chenge(int x, int cl, int cr, int L, int R, int col)
{
    if (R < cl || cr < L) return;
    if (L <= cl && cr <= R)
    {
        tree[x] = fuck(col);
        lazy[x] = col;
        return;
    }
    pushdown(x);
    int mid = (cl + cr) / 2;
    chenge(x * 2, cl, mid, L, R, col);
    chenge(x * 2 + 1, mid + 1, cr, L, R, col);
    tree[x] = tree[x * 2] * tree[x * 2 + 1];
}

// Returns the summary of positions [L, R] intersected with node x's range
// [cl, cr], in left-to-right order; an empty overlap yields the empty summary.
fuck query(int x, int cl, int cr, int L, int R)
{
    const bool disjoint = (R < cl) || (cr < L);
    if (disjoint) return fuck(0, 0, -2);
    const bool covered = (L <= cl) && (cr <= R);
    if (covered) return tree[x];
    pushdown(x);
    const int mid = (cl + cr) / 2;
    const fuck leftPart = query(x * 2, cl, mid, L, R);
    const fuck rightPart = query(x * 2 + 1, mid + 1, cr, L, R);
    return leftPart * rightPart;
}

int main()
{
    freopen("paint.in", "r", stdin);
    freopen("paint.out", "w", stdout);
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", &col[i]);
    for (int x, y, i = 1; i < n; i++) scanf("%d%d", &x, &y), out[x].push_back(y), out[y].push_back(x);
    dfs1(1), dfs2(1, 1);
    build(1, 1, n);
    char ch;
    for (int x, y, z, i = 1; i <= m; i++)
    {
        scanf(" %c%d%d", &ch, &x, &y);
        if (ch == 'C')
        {
            scanf("%d", &z);
            while (top[x] != top[y])
            {
                if (depth[top[x]] < depth[top[y]]) swap(x, y);
                chenge(1, 1, n, dfn[top[x]], dfn[x], z);
                x = fa[top[x]];
            }
            if (depth[x] > depth[y]) swap(x, y);
            chenge(1, 1, n, dfn[x], dfn[y], z);
        }
        else
        {
            fuck ans1(0, 0, -2), ans2(0, 0, -2);
            while (top[x] != top[y])
            {
                if (depth[top[x]] < depth[top[y]]) swap(x, y), swap(ans1, ans2);
                ans1 = query(1, 1, n, dfn[top[x]], dfn[x]) * ans1;
                x = fa[top[x]];
            }
            if (depth[x] > depth[y]) swap(x, y), swap(ans1, ans2);
            fuck ans = rev(ans1) * query(1, 1, n, dfn[x], dfn[y]) * ans2;
            printf("%d\n", ans.cnt + 2);
        }
    }
    fclose(stdin);
    fclose(stdout);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/oier/p/10351725.html
今日推荐