虚树——学习笔记

虚树

虚树的题一般都是多次询问的树形 d p dp 问题,每次选定树上 K i K_i 个关键点,求出答案
一般来说, K i \sum K_i 不会太大, 1 0 5 1 0 6 10^5-10^6 左右,这样一般的树形 d p dp 总体复杂度就是 O ( K i ) O(\sum K_i)
下面就以绿色的点为关键点,建立一个虚树:
在这里插入图片描述

建立的虚树长这样:
在这里插入图片描述

可以看到,有几个点已经消失了
在虚树上,我们留下的是关键点,以及关键点之间的最近公共祖先 l c a lca
这样做很好的保留了原树的结构和信息,并且
节点少了,我们遍历整棵树所需要的时间就少了
(可以证明,我们建立的虚树中节点个数不超过 2 K 2K )
虚树的建树过程是先将关键点按 d f s dfs 序排序,然后一个一个插入虚树
我个人认为虚树建树的过程模拟起来还是有难度的,这里我就不模拟了(可以看看这里的模拟
),可以直接用这个模板,虚树的题关键在于树形 d p dp

模板

struct edge {
    int nxt, to;
} e[MAX << 1];
int head[MAX], tot;
void add(int u, int v) { e[++tot] = edge{ head[u], v }, head[u] = tot;}

int dep[MAX], fa[MAX], topfa[MAX], siz[MAX], son[MAX], dfn[MAX], cnt;
void dfs(int u, int par) {
    dep[u] = dep[fa[u] = par] + (siz[u] = 1);
    int max_son = -1;
    for (int i = head[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != par) {
            dfs(v, u);
            siz[u] += siz[v];
            if (max_son < siz[v]) son[u] = v, max_son = siz[v];
        }
}
void dfs2(int u, int topf) {
    topfa[u] = topf, dfn[u] = ++cnt;
    if (!son[u]) return;
    dfs2(son[u], topf);
    for (int i = head[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != fa[u] && v != son[u]) dfs2(v, v);
}
int LCA(int x, int y) {
    while (topfa[x] != topfa[y]) {
        if (dep[topfa[x]] < dep[topfa[y]]) swap(x, y);
        x = fa[topfa[x]];
    }
    return dep[x] < dep[y] ? x : y;
}
//这里边长都是1, 所以就用了dep
int getDis(int x, int y) { return dep[x] + dep[y] - 2 * dep[LCA(x, y)]; }


//虚树
int tag[MAX];//tag[u] = 1 <=> 关键点
vector<int> g[MAX];//虚树边
void add_edge(int u, int v) { g[u].push_back(v); }
int st[MAX], top, rt;//rt为虚树根
void insert(int u) {//插入点
    if (top == 1) {
        st[++top] = u;
        return;
    }
    int lca = LCA(u, st[top]);
    if (lca != st[top]) {
        while (top > 1 && dfn[st[top - 1]] >= dfn[lca])
            add_edge(st[top - 1], st[top]), top--;
        if (lca != st[top]) add_edge(lca, st[top]), st[top] = lca;
    }
    st[++top] = u;
}
bool cmp(const int &x, const int &y) { return dfn[x] < dfn[y]; }
void build(vector<int> &v) {//建立虚树
    st[top = 1] = rt;//根节点一定会有
    sort(v.begin(), v.end(), cmp);
    for (auto &i: v) {
        tag[i] = 1;
        if (i != rt) insert(i);
    }
    while (top > 1) add_edge(st[top - 1], st[top]), top--;
}


void dp(int u) {
    //树形dp...
}
void clear(int u) {//清空虚树边和标记, 也可以和dp合并
    for (auto &v: g[u]) clear(v);
    g[u].clear(); tag[u] = 0;
}
void solve() {
    //...
    dp(rt); clear(rt);
    //...
}

int main() {

	//上面读取边...

	//先得到lca, dfs序等等
	dep[0] = -1, rt = 1;//root有时候不一定是1
    dfs(rt, 0); dfs2(rt, rt);

    int Q; scanf("%d", &Q);
    while (Q--) {
        int K; scanf("%d", &K);//读取关键点
        for (int i = 1; i <= K; i++) ....
        //构建虚树
        build(a);
        solve();
    }


	return 0;
}

P2495 [SDOI2011]消耗战

题目链接

题意

N N 个点的树, Q Q 次询问,每次询问最小的代价使得炸毁一些边使得给定的 K i K_i 个点与点 1 1 不连通

题解

先不考虑 Q Q 次询问
显然我们是可以用树形 d p dp 来解决这个问题的
f u f_u 为点 1 1 u u u u 的子树中所有的关键点不连通的最小代价
c o s t u cost_u u u 1 1 之间的最小边权
那么就有: f u = m i n ( c o s t u , f s o n ) f_u = min(cost_u, \sum f_{son})
u u 这个点上面就已经断开,要么 u u 与他的子树 s o n son 断开

这样应该也是好写的
然后我们考虑一下虚树,其实就是建立一颗保留关键点和原树结构的树
所以这个不影响我们 d p dp

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAX = 3e5 + 10;

int N, K, Q;
int a[MAX];

struct edge {
    int nxt, to, w;
} e[MAX << 1];
int head[MAX], tot;
void add(int u, int v, int w) { e[++tot] = edge{ head[u], v, w}, head[u] = tot;}

int dep[MAX], fa[MAX], topfa[MAX], siz[MAX], son[MAX], dfn[MAX], cnt;
ll cost[MAX];

void dfs(int u, int par) {
    dep[u] = dep[fa[u] = par] + (siz[u] = 1);
    int max_son = -1;
    for (int i = head[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != par) {
            cost[v] = min(cost[u], (ll)e[i].w);
            dfs(v, u);
            siz[u] += siz[v];
            if (max_son < siz[v]) son[u] = v, max_son = siz[v];
        }
}

void dfs2(int u, int topf) {
    topfa[u] = topf, dfn[u] = ++cnt;
    if (!son[u]) return;
    dfs2(son[u], topf);
    for (int i = head[u], v; i; i = e[i].nxt)
        if ((v = e[i].to) != fa[u] && v != son[u]) dfs2(v, v);
}

int LCA(int x, int y) {
    while (topfa[x] != topfa[y]) {
        if (dep[topfa[x]] < dep[topfa[y]]) swap(x, y);
        x = fa[topfa[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

vector<int> g[MAX];
void add_edge(int u, int v) { g[u].push_back(v); }
int st[MAX], top, rt;
void insert(int u) {
    if (top == 1) {
        st[++top] = u;
        return;
    }
    int lca = LCA(u, st[top]);
    if (lca != st[top]) {
        while (top > 1 && dfn[st[top - 1]] >= dfn[lca])
            add_edge(st[top - 1], st[top]), top--;
        if (lca != st[top]) add_edge(lca, st[top]), st[top] = lca;
    }
    st[++top] = u;
}

bool cmp(const int &x, const int &y) { return dfn[x] < dfn[y]; }

void build() {
    st[top = 1] = rt;
    sort(a + 1, a + 1 + K, cmp);
    for (int i = 1; i <= K; i++) insert(a[i]);
    while (top > 1) add_edge(st[top - 1], st[top]), top--;
}

//dp过程
ll dp(int u) {
    if (g[u].empty()) return cost[u];
    ll sum = 0;
    for (auto &v: g[u]) sum += dp(v);
    g[u].clear();//直接在dp的过程中清空也行
    return min(cost[u], sum);
}

int main() {

    scanf("%d", &N);
    for (int i = 1; i < N; i++) {
        int u, v, w; scanf("%d%d%d", &u, &v, &w);
        add(u, v, w); add(v, u, w);
    }
    cost[rt = 1] = 9e18;//点1的cost为INF
    dfs(rt, 0); dfs2(rt, rt);
    
    scanf("%d", &Q);
    while (Q--) {
        scanf("%d", &K);
        for (int i = 1; i <= K; i++) scanf("%d", &a[i]);
        build();
        printf("%lld\n", dp(rt));
    }


    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_44282912/article/details/105625593