356 次小生成树(求解最近公共祖先优化)

1. 问题描述:

给定一张 N 个点 M 条边的无向图,求无向图的严格次小生成树。设最小生成树的边权之和为 sum,严格次小生成树就是指边权之和大于 sum 的生成树中最小的一个。

输入格式

第一行包含两个整数 N 和 M。接下来 M 行,每行包含三个整数 x,y,z,表示点 x 和点 y 之前存在一条边,边的权值为 z。

输出格式

包含一行,仅一个数,表示严格次小生成树的边权和。(数据保证必定存在严格次小生成树)

数据范围

N ≤ 10 ^ 5,M ≤ 3 × 10 ^ 5

输入样例:

5 6
1 2 1
1 3 2
2 4 3
3 5 4
3 4 3
4 5 6

输出样例:

11
来源:https://www.acwing.com/problem/content/description/358/

2. 思路分析:

次小生成树有一个定理:对于一个无向图,如果存在最小生成树与(严格)次小生成树,那么对于任意一棵最小生成树都存在一棵(严格)次小生成树使得这两棵树只有一条边不同。这道题目与1148题是一样的,但是这道题目的数据范围比较大,我们需要在优化一下1148题中严格次小生成树的求解方法,1148题求解严格次小生成树有三个步骤:

  • 使用kruskal算法求解最小生成树,并且在存储所有边的信息中标记是哪些边是树边,在求解最小生成树的过程中将最小生成树中构建出来,也即存储最小生成树中所有边的信息
  • 因为求解的是严格次小生成树,也即边权之和比最小生成树的要大但是最小的那个(边权之和第二小),所以需要维护最小生成树中任意两点之间的最大边权和次大边权
  • 枚举所有非树边,求解使用当前非树边替换最小生成树中边的最小值:min(sum + w - dis[a][b]),因为求解的是严格次小生成树,所以需要满足w - dis[a][b]的条件 

这里选择优化第二个求解的步骤:我们可以在使用倍增思想求解两个点的最近公共祖先的时候求解两点之间路径的最大边权和次大边权,其中在求解LCA的时候除了维护fa数组之外,还需要维护两个数组分别为d1,d2,d1(i,j)表示从节点i向上跳2 ^ j步路径上的最大边权,d2(i,j)表示从节点i往上跳2 ^ j步路径上的次大边权,维护这两个数组其实与维护最近最近公共祖先中fa数组的思想是一样的,从当前节点i往上跳2 ^ j步分为两步来跳,第一步是从当前节点i跳2 ^ (j - 1)步到达某个祖先节点的过程维护路径的最大边权和次大边权,第二步是从第一个跳到的祖先节点继续往上跳2 ^ (j - 1)步维护路径的最大边权和次大边权,我们在跳的时候其实是枚举可以跳的步数k,0 <= k <= 16,所以从节点i往上跳2 ^ k步的最大边权和次大边权在这四个值里面选即可。当我们需要求解当前节点a和b之间的最大边权和次大边权的时候需要求解LCA,先预处理出fa,d1,d2数组,然后使用一个方法求解LCA,首先我们需要将深度较深的节点跳到深度较浅的节点位于同一深度,也即两个节点位于同一层,在其中一个节点跳的过程中记录下路径上的最大边权和次大边权,然后两个节点往上跳直到fa(a, k) == fa(b, k)的过程中分别记录下两个节点往上跳的路径的最大边权和次大边权,将最大边权和次大边权记录在distance中,最后枚举一下distance,找到最大边权和次大边权,当非树边可以替换为最大边权的时候那么直接替换,否则替换为次大边权,也即满足w - dis[a][b] > 0。

3. 代码如下:

c++代码(y总):

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long LL;

const int N = 100010, M = 300010, INF = 0x3f3f3f3f;

int n, m;
struct Edge
{
    int a, b, w;
    bool used;
    bool operator< (const Edge &t) const
    {
        return w < t.w;
    }
}edge[M];
int p[N];
int h[N], e[M], w[M], ne[M], idx;
int depth[N], fa[N][17], d1[N][17], d2[N][17];
int q[N];

void add(int a, int b, int c)
{
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}

int find(int x)
{
    if (p[x] != x) p[x] = find(p[x]);
    return p[x];
}

LL kruskal()
{
    for (int i = 1; i <= n; i ++ ) p[i] = i;
    sort(edge, edge + m);
    LL res = 0;
    for (int i = 0; i < m; i ++ )
    {
        int a = find(edge[i].a), b = find(edge[i].b), w = edge[i].w;
        if (a != b)
        {
            p[a] = b;
            res += w;
            edge[i].used = true;
        }
    }

    return res;
}

void build()
{
    memset(h, -1, sizeof h);
    for (int i = 0; i < m; i ++ )
        if (edge[i].used)
        {
            int a = edge[i].a, b = edge[i].b, w = edge[i].w;
            add(a, b, w), add(b, a, w);
        }
}

void bfs()
{
    memset(depth, 0x3f, sizeof depth);
    depth[0] = 0, depth[1] = 1;
    q[0] = 1;
    int hh = 0, tt = 0;
    while (hh <= tt)
    {
        int t = q[hh ++ ];
        for (int i = h[t]; ~i; i = ne[i])
        {
            int j = e[i];
            if (depth[j] > depth[t] + 1)
            {
                depth[j] = depth[t] + 1;
                q[ ++ tt] = j;
                fa[j][0] = t;
                d1[j][0] = w[i], d2[j][0] = -INF;
                for (int k = 1; k <= 16; k ++ )
                {
                    int anc = fa[j][k - 1];
                    fa[j][k] = fa[anc][k - 1];
                    int distance[4] = {d1[j][k - 1], d2[j][k - 1], d1[anc][k - 1], d2[anc][k - 1]};
                    d1[j][k] = d2[j][k] = -INF;
                    for (int u = 0; u < 4; u ++ )
                    {
                        int d = distance[u];
                        if (d > d1[j][k]) d2[j][k] = d1[j][k], d1[j][k] = d;
                        else if (d != d1[j][k] && d > d2[j][k]) d2[j][k] = d;
                    }
                }
            }
        }
    }
}

int lca(int a, int b, int w)
{
    static int distance[N * 2];
    int cnt = 0;
    if (depth[a] < depth[b]) swap(a, b);
    for (int k = 16; k >= 0; k -- )
        if (depth[fa[a][k]] >= depth[b])
        {
            distance[cnt ++ ] = d1[a][k];
            distance[cnt ++ ] = d2[a][k];
            a = fa[a][k];
        }
    if (a != b)
    {
        for (int k = 16; k >= 0; k -- )
            if (fa[a][k] != fa[b][k])
            {
                distance[cnt ++ ] = d1[a][k];
                distance[cnt ++ ] = d2[a][k];
                distance[cnt ++ ] = d1[b][k];
                distance[cnt ++ ] = d2[b][k];
                a = fa[a][k], b = fa[b][k];
            }
        distance[cnt ++ ] = d1[a][0];
        distance[cnt ++ ] = d1[b][0];
    }

    int dist1 = -INF, dist2 = -INF;
    for (int i = 0; i < cnt; i ++ )
    {
        int d = distance[i];
        if (d > dist1) dist2 = dist1, dist1 = d;
        else if (d != dist1 && d > dist2) dist2 = d;
    }

    if (w > dist1) return w - dist1;
    if (w > dist2) return w - dist2;
    return INF;
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i ++ )
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        edge[i] = {a, b, c};
    }

    LL sum = kruskal();
    build();
    bfs();

    LL res = 1e18;
    for (int i = 0; i < m; i ++ )
        if (!edge[i].used)
        {
            int a = edge[i].a, b = edge[i].b, w = edge[i].w;
            res = min(res, sum + lca(a, b, w));
        }
    printf("%lld\n", res);

    return 0;
}

python代码(超时了呜呜呜),后面优化了一下但是最大的那个数据还是没过,只过了10个数据。python读入数据和初始化比较大的列表的时候耗时都比较大。

import collections
from typing import List


class Solution:
    def find(self, x: int, p: List[int]):
        if x != p[x]:
            p[x] = self.find(p[x], p)
        return p[x]

    def kruskal(self, p: List[int], w: List[List[int]]):
        # e存储最小生成树的边的信息
        e = collections.defaultdict(list)
        # 按照边权排序
        w.sort(key=lambda x: x[2])
        s = 0
        for i in range(len(w)):
            a, b, c = self.find(w[i][0], p), self.find(w[i][1], p), w[i][2]
            if a != b:
                p[a] = b
                # 标记为最小生成树的树边
                w[i][3] = 1
                # 计算当前累加上当前边之后最小生成树的权值之和
                s += c
                # 将边的信息存储到最小生成树e中
                e[w[i][0]].append([w[i][1], c])
                e[w[i][1]].append([w[i][0], c])
        return s, e

    # 预处理fa, depth, d1, d2
    def bfs(self, fa: List[List[int]], depth: List[int], d1: List[List[int]], d2: List[List[int]],
            g: collections.defaultdict):
        # 任选一个点作为根节点, 这里将1号点作为根节点
        q = collections.deque([1])
        depth[0], depth[1] = 0, 1
        INF = 10 ** 18
        while q:
            p = q.popleft()
            for next in g[p]:
                if depth[next[0]] > depth[p] + 1:
                    depth[next[0]] = depth[p] + 1
                    q.append(next[0])
                    # 预处理fa和d1和d2
                    j = next[0]
                    fa[j][0], d1[j][0], d2[j][0] = p, next[1], -INF
                    for k in range(1, 17):
                        anc = fa[j][k - 1]
                        fa[j][k] = fa[anc][k - 1]
                        # 分为两段求解最大值与次大值(跳两步), 从节点i往上跳2 ^ k步的最大值与次大值肯定是这个四个值中求解, 在这四段中求解最大边权和次大边权
                        d1[j][k], d2[j][k] = -INF, -INF
                        distance = [d1[j][k - 1], d2[j][k - 1], d1[anc][k - 1], d2[anc][k - 1]]
                        for i in range(4):
                            d = distance[i]
                            if d > d1[j][k]:
                                d2[j][k] = d1[j][k]
                                d1[j][k] = d
                            # 注意不能够等于d1[j][k]
                            elif d != d1[j][k] and d > d2[j][k]:
                                d2[j][k] = d

    # 在求解LCA的过程中维护两点之间的最大边权和次大边权, 遍历的是最小生成树的节点
    def lca(self, a: int, b: int, c: int, fa: List[List[int]], depth: List[int], d1: List[List[int]], d2: List[List[int]]):
        # 确保a的深度比b的深度大
        if depth[a] < depth[b]: a, b = b, a
        distance = list()
        # a节点跳到与b节点同一深度, 从高到低开始枚举
        for k in range(16, -1, -1):
            if depth[fa[a][k]] >= depth[b]:
                # 记录a往上跳的过程的最大边权与次大边权
                distance.append(d1[a][k])
                distance.append(d2[a][k])
                a = fa[a][k]
        # 当两个节点不同的时候才执行操作
        if a != b:
            for k in range(16, -1, -1):
                if fa[a][k] != fa[b][k]:
                    # 记录两个节点往上跳的最大边权和次大边权
                    distance.append(d1[a][k])
                    distance.append(d2[a][k])
                    distance.append(d1[b][k])
                    distance.append(d2[b][k])
                    a, b = fa[a][k], fa[b][k]
            # 记录最后跳一步的最大边权(只有一条边了所以是最大边权)
            distance.append(d1[a][0])
            distance.append(d1[b][0])
        INF = 10 ** 10
        dis1, dis2 = -INF, -INF
        # 枚举最大边权和次大边权
        for x in distance:
            if x > dis1:
                dis2, dis1 = dis1, x
            elif x != dis1 and x > dis2:
                dis2 = x
        # 可以替换那么直接替换为最大边权否则尝试替换为次大边权, 都不能替换返回INF
        if c > dis1: return c - dis1
        if c > dis2: return c - dis2
        return INF

    def process(self):
        # n个点m条边
        n, m = map(int, input().split())
        w = list()
        for i in range(m):
            a, b, c = map(int, input().split())
            # 第四个参数为是否是非树边的标记
            w.append([a, b, c, 0])
        p = [i for i in range(n + 10)]
        # 标记最小生成树的边并且创建最小生成树的边, e表示最小生成树中的信息, kruskal返回两个参数: 第一个是最小生成树的边权之和, 第二个是最小生成树的边权信息
        s, e = self.kruskal(p, w)
        # 接下来是预处理depth和fa列表
        INF = 10 ** 18
        fa, depth = [[0] * 17 for i in range(n + 10)], [INF] * (n + 10)
        d1, d2 = [[0] * 17 for i in range(n + 10)], [[0] * 17 for i in range(n + 10)]
        self.bfs(fa, depth, d1, d2, e)
        # 枚举非树边然后将其中的最小生成树的边换成是非树边
        res = 10 ** 18
        for i in range(m):
            a, b, c = w[i][0], w[i][1], w[i][2]
            if w[i][3] == 0:
                res = min(res, s + self.lca(a, b, c, fa, depth, d1, d2))
        return res


if __name__ == "__main__":
    print(Solution().process())

おすすめ

転載: blog.csdn.net/qq_39445165/article/details/121233831