368 银河(强连通分量)

1. 问题描述:

银河中的恒星浩如烟海,但是我们只关注那些最亮的恒星。我们用一个正整数来表示恒星的亮度,数值越大则恒星就越亮,恒星的亮度最暗是 1。现在对于 N 颗我们关注的恒星,有 M 对亮度之间的相对关系已经判明。你的任务就是求出这 N 颗恒星的亮度值总和至少有多大。

输入格式

第一行给出两个整数 N 和 M。之后 M 行,每行三个整数 T,A,B,表示一对恒星 (A,B) 之间的亮度关系。恒星的编号从 1 开始。
如果 T=1,说明 A 和 B 亮度相等。 
如果 T=2,说明 A 的亮度小于 B 的亮度。 
如果 T=3,说明 A 的亮度不小于 B 的亮度。 
如果 T=4,说明 A 的亮度大于 B 的亮度。 
如果 T=5,说明 A 的亮度不大于 B 的亮度。

输出格式

输出一个整数表示结果。若无解,则输出 −1。

数据范围

N ≤ 100000,M ≤ 100000

输入样例:

5 7 
1 1 2 
2 3 2 
4 4 1 
3 4 5 
5 4 5 
2 3 5 
4 5 1 

输出样例:

11
来源:https://www.acwing.com/problem/content/description/370/

2. 思路分析:

分析题目可以知道这道题目与1169题是一模一样的,题目中五个不同的条件限制可以看成是五个不等式的限制,所以可以使用差分约束来解决这个问题(并且需要一个点可以到达所有边和要有一个绝对的限制),因为求解的是最小值,所以需要使用单源最长路径来解决,那么需要将所有的不等式约束转换为大于等于号的形式,也即将其转换为xi >= xj + ck ,转换成这样的形式之后在建图的时候不等号右边节点xj向左边节点xi连一条权重为ck的边,可以根据下面得到的不等式建图即可:

  • T = 1,A = B === > A >= B ,B >= A
  • T = 2,A < B === > B >= A + 1
  • T = 3,A >= B  === > A >= B
  • T = 4,A > B === > A >= B + 1
  • T = 5,A <= B === > B >= A

使用差分约束思路求解的时候先判断图中是否存在正环,如果存在正环说明无解,输出-1,当不存在正环的时候那么求解从源点到各个点的最长路径;这道题目除了使用差分约束的思路求解之外由于这道题目的图比较特殊,所有边权都是大于等于0的,所以可以使用强连通分量求解,但是并不是所有的差分约束问题都可以使用强连通分量来求解;我们可以使用tarjan算法求解出所有的强连通分量,然后缩点,如果发现属于同一个强连通分量中存在大于0的边,因为强连通分量中任意两点可以相互到达,而且所有边权都是大于等于0的所以肯定存在正环,直接输出-1即可,否则一个强连通分量向另外一个强连通分量连一条权重为ck的边,由于tarjan算法求解所有强连通分量的时候每一个强连通分量都是按照逆序的顺序排序的,所以逆序枚举所有的强连通分量,递推求解最长路即可,将每一个强连通分量看成是一个点,最后枚举一下dis数组求解答案即可。

3. 代码如下:

c++代码(y总):

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long LL;

const int N = 100010, M = 600010;

int n, m;
int h[N], hs[N], e[M], ne[M], w[M], idx;
int dfn[N], low[N], timestamp;
int stk[N], top;
bool in_stk[N];
int id[N], scc_cnt, sz[N];
int dist[N];

void add(int h[], int a, int b, int c)
{
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}

void tarjan(int u)
{
    dfn[u] = low[u] = ++ timestamp;
    stk[ ++ top] = u, in_stk[u] = true;

    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if (!dfn[j])
        {
            tarjan(j);
            low[u] = min(low[u], low[j]);
        }
        else if (in_stk[j]) low[u] = min(low[u], dfn[j]);
    }

    if (dfn[u] == low[u])
    {
        ++ scc_cnt;
        int y;
        do {
            y = stk[top -- ];
            in_stk[y] = false;
            id[y] = scc_cnt;
            sz[scc_cnt] ++ ;
        } while (y != u);
    }
}

int main()
{
    scanf("%d%d", &n, &m);
    memset(h, -1, sizeof h);
    memset(hs, -1, sizeof hs);

    for (int i = 1; i <= n; i ++ ) add(h, 0, i, 1);

    while (m -- )
    {
        int t, a, b;
        scanf("%d%d%d", &t, &a, &b);
        if (t == 1) add(h, b, a, 0), add(h, a, b, 0);
        else if (t == 2) add(h, a, b, 1);
        else if (t == 3) add(h, b, a, 0);
        else if (t == 4) add(h, b, a, 1);
        else add(h, a, b, 0);
    }

    tarjan(0);

    bool success = true;
    for (int i = 0; i <= n; i ++ )
    {
        for (int j = h[i]; ~j; j = ne[j])
        {
            int k = e[j];
            int a = id[i], b = id[k];
            if (a == b)
            {
                if (w[j] > 0)
                {
                    success = false;
                    break;
                }
            }
            else add(hs, a, b, w[j]);
        }
        if (!success) break;
    }

    if (!success) puts("-1");
    else
    {
        for (int i = scc_cnt; i; i -- )
        {
            for (int j = hs[i]; ~j; j = ne[j])
            {
                int k = e[j];
                dist[k] = max(dist[k], dist[i] + w[j]);
            }
        }

        LL res = 0;
        for (int i = 1; i <= scc_cnt; i ++ ) res += (LL)dist[i] * sz[i];

        printf("%lld\n", res);
    }

    return 0;
}

python代码:python使用tarjan算法求解强连通分量的时候由于数据规模比较大所以递归的时候会出问题,只过了7个数据:

from typing import List
import sys

class Solution:
    # 定义tarjan算法中需要使用到的各个全局变量
    stk, in_stk, idx, timestamp, size, top, ssc_cnt = None, None, None, None, None, None, None

    def tarjan(self, u: int, dfn: List[int], low: List[int], g: List[List[int]]):
        dfn[u] = low[u] = self.timestamp + 1
        self.timestamp += 1
        self.stk[self.top + 1] = u
        self.top += 1
        self.in_stk[u] = 1
        for next in g[u]:
            if dfn[next[0]] == 0:
                self.tarjan(next[0], dfn, low, g)
                low[u] = min(low[u], low[next[0]])
            elif self.in_stk[next[0]] == 1:
                low[u] = min(low[u], dfn[next[0]])
        if dfn[u] == low[u]:
            self.ssc_cnt += 1
            while True:
                t = self.stk[self.top]
                self.top -= 1
                self.in_stk[t] = 0
                self.idx[t] = self.ssc_cnt
                self.size[self.ssc_cnt] += 1
                if t == u: break

    def process(self):
        n, m = map(int, input().split())
        g1 = [list() for i in range(n + 10)]
        for i in range(m):
            # 建图过程
            t, a, b = map(int, input().split())
            if t == 1:
                g1[b].append((a, 0))
                g1[a].append((b, 0))
            elif t == 2:
                g1[a].append((b, 1))
            elif t == 3:
                g1[b].append((a, 0))
            elif t == 4:
                g1[b].append((a, 1))
            else:
                g1[a].append((b, 0))
        # 0号点作为超级源点向其余点的点连一条权重为1的边
        for i in range(1, n + 1):
            g1[0].append((i, 1))
        # tarjan算法求解所有强连通分量
        dfn, low = [0] * (n + 10), [0] * (n + 10)
        self.stk, self.in_stk, self.size, self.idx = [0] * (n + 10), [0] * (n + 10), [0] * (n + 10), [0] * (n + 10)
        self.top = self.timestamp = self.ssc_cnt = 0
        # 0号点作为起点
        self.tarjan(0, dfn, low, g1)
        # 缩点, 建图的过程, 新图存储到g2中
        g2 = [list() for i in range(self.ssc_cnt + 10)]
        success = 1
        for i in range(n + 1):
            for next in g1[i]:
                a, b = self.idx[i], self.idx[next[0]]
                # 两个点属于同一个强连通分量并且里面存在边大于0所以肯定是正环
                if a == b:
                    if next[1] > 0:
                        success = 0
                        break
                else:
                    g2[a].append((b, next[1]))
            if success == 0: break
        if success == 0: print("-1")
        else:
            dis = [0] * (self.ssc_cnt + 10)
            for i in range(self.ssc_cnt, 0, -1):
                for next in g2[i]:
                    # 更新最长路
                    if dis[next[0]] < dis[i] + next[1]:
                        dis[next[0]] = dis[i] + next[1]
            res = 0
            # 计算每一个强连通分量的距离与点的乘积之和
            for i in range(1, self.ssc_cnt + 1):
                res += dis[i] * self.size[i]
            print(res)


if __name__ == "__main__":
    # 设置最大递归调用次数
    sys.setrecursionlimit(50000)
    Solution().process()

Guess you like

Origin blog.csdn.net/qq_39445165/article/details/121320115