1175 最大半连通子图(强连通分量)

1. 问题描述:

一个有向图 G=(V,E) 称为半连通的 (Semi-Connected),如果满足:∀u,v∈V,满足 u→v 或 v→u,即对于图中任意两点 u,v,存在一条 u 到 v 的有向路径或者从 v 到 u 的有向路径。若 G′=(V′,E′) 满足,E′ 是 E 中所有和 V′ 有关的边,则称 G′ 是 G 的一个导出子图。若 G′ 是 G 的导出子图,且 G′ 半连通,则称 G′ 为 G 的半连通子图。若 G′ 是 G 所有半连通子图中包含节点数最多的,则称 G′ 是 G 的最大半连通子图。给定一个有向图 G,请求出 G 的最大半连通子图拥有的节点数 K,以及不同的最大半连通子图的数目 C。由于 C 可能比较大,仅要求输出 C 对 X 的余数。

输入格式

第一行包含三个整数 N,M,X。N,M 分别表示图 G 的点数与边数,X 的意义如上文所述;接下来 M 行,每行两个正整数 a,b,表示一条有向边 (a,b)。图中的每个点将编号为 1 到 N,保证输入中同一个 (a,b) 不会出现两次。

输出格式

应包含两行。第一行包含一个整数 K,第二行包含整数 C mod X。

数据范围

1 ≤ N ≤ 10 ^ 5,
1 ≤ M ≤ 10 ^ 6,
1 ≤ X ≤ 10 ^ 8

输入样例:

6 6 20070603
1 2
2 1
1 3
2 4
5 6
6 4

输出样例:

3
3
来源:https://www.acwing.com/problem/content/description/1177/

2. 思路分析:

分析题目可以知道半连通分量指的是节点u可以走到v或者节点v可以走到u,也即至少有一个是成立的,我们需要求解最大半连通子图的节点个数以及对应的方案数目,如果我们在一个强连通分量中,可以发现选择强连通分量中的某些点不如选择整个强连通分量中的所有点,因为强连通分量中的所有点都是连通的,所以选择全部点更优,我们可以先把所有强连通分量找出来,求解强连通分量可以使用tarjan算法,也即在dfs遍历的过程中找出所有的强连通分量,将每一个强联通分量中点的编号标记在对应的强连通分量中编号中(使用idx来记录每一个点所在的强连通分量),并且记录每一个强连通分量中点的个数,求解完强连通分量之后接下来就是缩点,因为这道题目需要求解最大半连通子图中点的个数以及对应的方案数目,而我们使用tarjan算法求解强连通分量,缩点以及建图之后那么得到的是一个有向无环图,也即新建的图是满足拓扑序的,可以发现本质上求解的是拓扑图中的最长链,而且图是无环的所以我们可以使用递推的方式来求解,并且因为求解我们的是最长链的方案数目,所以类似于之前求解最优方案的思路,我们可以使用一个数组g来记录方案数目,在求解最长链的时候维护这个方案数目。

这里需要注意一些细节上的问题,因为在递推是按照边的方式进行递推的,也即我们枚举的是边,通过边来转移的,所以在缩点之后的建图需要注意不要重复建边(也即一个强连通分量中的一条边只能够向另外个强连通分量连一条边),因为如果有重边的话会多计算一些方案数目,题目中只有点不同的时候才属于不同的方案,如下图所示:

 当我们建好图之后那么递推求解最长链即可,因为需要求解最长链对应的方案数目,所以需要维护两个数组f和g,f[i]维护到强连通分量编号为i的最大节点个数,g[i]维护到强联通分量编号为i的最大节点个数对应的方案数目,枚举的时候按照强连通分量编号逆序枚举,为什么呢??这个其实与tarjan算法求解每一个强连通分量有关,如下图所示,可以发现在使用tarjan算法求解之后每一个强连通分量都是按照逆序的顺序排序的,最终最长链肯定是在f数组的第一个位置,所以需要使用逆序的顺序进行递推,这里求解最长链的思路与背包问题求解方案数目的思路是一样的,分为两种情况更新,看当前的点可以更新哪些邻接点:

  • f[next] > f[i] + size[k],f[next] = f[i] + size[k],g[next] = g[i] 
  • f[next] = f[i] + size[k],g[next] += g[i]

 

3. 代码如下:

c++(y总):

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <unordered_set>

using namespace std;

typedef long long LL;

const int N = 100010, M = 2000010;

int n, m, mod;
int h[N], hs[N], e[M], ne[M], idx;
int dfn[N], low[N], timestamp;
int stk[N], top;
bool in_stk[N];
int id[N], scc_cnt, scc_size[N];
int f[N], g[N];

void add(int h[], int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

void tarjan(int u)
{
    dfn[u] = low[u] = ++ timestamp;
    stk[ ++ top] = u, in_stk[u] = true;

    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if (!dfn[j])
        {
            tarjan(j);
            low[u] = min(low[u], low[j]);
        }
        else if (in_stk[j]) low[u] = min(low[u], dfn[j]);
    }

    if (dfn[u] == low[u])
    {
        ++ scc_cnt;
        int y;
        do {
            y = stk[top -- ];
            in_stk[y] = false;
            id[y] = scc_cnt;
            scc_size[scc_cnt] ++ ;
        } while (y != u);
    }
}

int main()
{
    memset(h, -1, sizeof h);
    memset(hs, -1, sizeof hs);

    scanf("%d%d%d", &n, &m, &mod);
    while (m -- )
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(h, a, b);
    }

    for (int i = 1; i <= n; i ++ )
        if (!dfn[i])
            tarjan(i);

    unordered_set<LL> S;    // (u, v) -> u * 1000000 + v
    for (int i = 1; i <= n; i ++ )
        for (int j = h[i]; ~j; j = ne[j])
        {
            int k = e[j];
            int a = id[i], b = id[k];
            LL hash = a * 1000000ll + b;
            if (a != b && !S.count(hash))
            {
                add(hs, a, b);
                S.insert(hash);
            }
        }

    for (int i = scc_cnt; i; i -- )
    {
        if (!f[i])
        {
            f[i] = scc_size[i];
            g[i] = 1;
        }
        for (int j = hs[i]; ~j; j = ne[j])
        {
            int k = e[j];
            if (f[k] < f[i] + scc_size[k])
            {
                f[k] = f[i] + scc_size[k];
                g[k] = g[i];
            }
            else if (f[k] == f[i] + scc_size[k])
                g[k] = (g[k] + g[i]) % mod;
        }
    }

    int maxf = 0, sum = 0;
    for (int i = 1; i <= scc_cnt; i ++ )
        if (f[i] > maxf)
        {
            maxf = f[i];
            sum = g[i];
        }
        else if (f[i] == maxf) sum = (sum + g[i]) % mod;

    printf("%d\n", maxf);
    printf("%d\n", sum);

    return 0;
}

python(由于数据规模太大了,所以递归的深度很大所以超时了,只过了8个数据):

from typing import List
import sys


class Solution:
    stk, in_stk, size, idx, timestamp, top, ssc_cnt = None, None, None, None, None, None, None

    # tarjan算法求解强连通分量
    def tarjan(self, u: int, dfn: List[int], low: List[int], g: List[List[int]]):
        dfn[u] = low[u] = self.timestamp + 1
        self.timestamp += 1
        self.stk[self.top + 1] = u
        self.top += 1
        self.in_stk[u] = 1
        for next in g[u]:
            if dfn[next] == 0:
                self.tarjan(next, dfn, low, g)
                low[u] = min(low[u], low[next])
            elif self.in_stk[next] == 1:
                low[u] = min(low[u], dfn[next])
        if dfn[u] == low[u]:
            self.ssc_cnt += 1
            while True:
                t = self.stk[self.top]
                self.top -= 1
                self.in_stk[t] = 0
                self.idx[t] = self.ssc_cnt
                self.size[self.ssc_cnt] += 1
                if t == u: break

    def process(self):
        n, m, mod = map(int, input().split())
        # g1表示原图
        g1 = [list() for i in range(n + 10)]
        for i in range(m):
            a, b = map(int, input().split())
            g1[a].append(b)
        dfn, low = [0] * (n + 10), [0] * (n + 10)
        self.stk, self.in_stk, self.idx, self.size = [0] * (n + 10), [0] * (n + 10), [0] * (n + 10), [0] * (n + 10)
        self.timestamp = self.ssc_cnt = self.top = 0
        for i in range(1, n + 1):
            if dfn[i] == 0:
                self.tarjan(i, dfn, low, g1)
        # 缩点, 建图
        dic = dict()
        # g2为新建的图
        g2 = [list() for i in range(self.ssc_cnt + 1)]
        for i in range(1, n + 1):
            for next in g1[i]:
                a, b = self.idx[i], self.idx[next]
                hash = a * 10000000 + b
                # 如果当前连的边之后是没有连过那么则建一条边
                if self.idx[i] != self.idx[next] and hash not in dic:
                    g2[a].append(b)
                    dic[hash] = 1
        f, g = [0] * (self.ssc_cnt + 1), [0] * (self.ssc_cnt + 1)
        # 按照拓扑序逆推
        for i in range(self.ssc_cnt, 0, -1):
            if f[i] == 0:
                f[i] = self.size[i]
                g[i] = 1
            for next in g2[i]:
                if f[next] < f[i] + self.size[next]:
                    f[next] = f[i] + self.size[next]
                    g[next] = g[i]
                elif f[next] == f[i] + self.size[next]:
                    f[next] = f[i] + self.size[next]
                    g[next] = (g[next] + g[i]) % mod
        maxf = s = 0
        # 枚举答案以及方案数目
        for i in range(1, self.ssc_cnt + 1):
            if f[i] > maxf:
                maxf = f[i]
                s = g[i]
            elif f[i] == maxf:
                s = (s + g[i]) % mod
        print(maxf)
        print(s)


if __name__ == "__main__":
    sys.setrecursionlimit(50100)
    Solution().process()

Guess you like

Origin blog.csdn.net/qq_39445165/article/details/121304893