图论训练之七

https://nanti.jisuanke.com/t/A1108

本题又叫缺点最短路,数据卡的很好,

一N×N×N×N恰好过不了

N×N×N×logN才行

如果一的话就可以再在floyed的基础上多枚举一维

这一维表示不经过该点

floyed的本质是一个增量算法,最外一维枚举的是k,但这个顺序并不影响最后的结果

如果可以处理处对于每个点Y只剩Y没在floyed的转移矩阵里

这个矩阵的值就是不经过 y 点的全源最短路

考虑分治

为什么要分治呢

因为一算法的不好在于每次排除一个点都要所有都枚举一次,很多状态都是重复计算的,

比如总共有10个点,先排除点1,再排除点2,明显两次其他八个点都只用算一次就行了,而一算法,就会多次计算

分治,可以保证每个只算一次,要用的时候就把其他八个拼起来(因为顺序不重要

每一次把点集拆成两半,

每一次把点集拆成两半,

先用前一半的点在 Floyd 算法中滚,再递归后一半点。

然后回溯,用后一半的点在 Floyd 算法里滚,递归前一半的点。

这样每个只有一个点的状态得到的就是只有这个点没有在 Floyd 算法里滚的矩阵

如果实在理解不到的话,代码里有个注释打开,

4
0 1 -1 -1
-1 0 1 -1
-1 -1 0 1
1 -1 -1 0

输出是

l=1 r=4
l=3 r=4
l=4 r=4
l=3 r=3
l=1 r=2
l=2 r=2
l=1 r=1
4

code by std:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <queue>
#include <vector> 
using namespace std;
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))

inline void read(int &x)
{
    x = 0;char ch = getchar(), c = ch;
    while(ch < '0' || ch > '9')c = ch, ch = getchar();
    while(ch <= '9' && ch >= '0')x = x * 10 + ch - '0', ch = getchar();
    if(c == '-')x = -x;
}

const int INF = 0x3f3f3f3f;
const int MAXN = 300 + 10;

int g[MAXN][MAXN],n;
long long ans;

void solve(int l, int r)
{
    //cout<<"l="<<l<<" r="<<r<<endl;
    if(l == r)
    {
        for(register int i = 1;i <= n;++ i)
        {
            if(l == i) continue;
            for(register int j = 1;j <= n;++ j)
            {
                if(r == j) continue;
                if(g[i][j] != INF)
                    ans += g[i][j];
                else
                    -- ans;
            }
        }
        return;
    }
    int tmp[MAXN][MAXN];
    for(register int i = 1;i <= n;++ i)
        for(register int j = 1;j <= n;++ j)
            tmp[i][j] = g[i][j];
    int mid = (l + r) >> 1;
    for(register int k = l;k <= mid;++ k)
        for(register int i = 1;i <= n;++ i)
            for(register int j = 1;j <= n;++ j)
                if(g[i][j] > g[i][k] + g[k][j])
                    g[i][j] = g[i][k] + g[k][j];
    solve(mid + 1, r);
    for(register int i = 1;i <= n;++ i)
        for(register int j = 1;j <= n;++ j)
            g[i][j] = tmp[i][j];
    for(register int k = mid + 1;k <= r;++ k)
        for(register int i = 1;i <= n;++ i)
            for(register int j = 1;j <= n;++ j)
                if(g[i][j] > g[i][k] + g[k][j])
                    g[i][j] = g[i][k] + g[k][j];
    solve(l, mid);
    return;
}

int main()
{
    read(n);
    for(register int i = 1;i <= n;++ i)
        for(register int j = 1;j <= n;++ j)
        {
            read(g[i][j]);
            if(g[i][j] == -1)g[i][j] = INF;
        }
    solve(1, n);
    printf("%lld", ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/wzxbeliever/p/11639295.html