百度之星 初赛三 最短路 2 Dijkstra

Code: 

#include <bits/stdc++.h>
// Typed compile-time constants instead of object-like macros: they obey scope,
// have a real type, and show up in the debugger.
using ll = long long;
constexpr ll inf = 100000000000000LL;  // "infinite" distance; assumed larger than any real shortest-path length
constexpr int mod = 998244353;         // modulus applied to the printed answer
constexpr int N = 1003;                // capacity of vertex-indexed arrays (vertex ids 0..N-1)
// setIO must stay a macro: it relies on string-literal concatenation (s".in").
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;   
int n;  // vertex count of the current test case (vertices are 1..n)
int m;  // undirected edge count of the current test case
int s;  // source vertex of the Dijkstra run in progress
struct Node
{
    int u;
    ll dis;
    Node(int u=0,ll dis=0):u(u),dis(dis){}  
    bool operator<(Node b)const
    {
        return b.dis<dis;  
    }
}; 
priority_queue<Node> q;  // Dijkstra frontier; Node's reversed operator< makes it pop the smallest dis first
int hd[N];               // hd[u]: id of the first edge leaving u, 0 when u has none
int nex[N << 2];         // nex[e]: id of the next edge in the same adjacency list, 0 ends the list
int to[N << 2];          // to[e]: head vertex of edge e
ll val[N << 2];          // val[e]: weight of edge e
int edges;               // number of edge ids handed out so far
int done[N];             // done[u]: 1 once Dijkstra has finalized u
int pre[N << 2];         // pre[v]: per-source answer for target v computed by Dijkstra
int now[N << 2];         // now[u]: value a finalized u passes on to its successors
ll d[N];                 // d[v]: shortest distance from the current source to v
ll f[N][N];              // not read anywhere in this file (kept for compatibility)
inline void addedge(int u,int v,ll c)
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v,val[edges]=c; 
}
inline void Dijkstra()
{
    int i,u,v;
    for(i=1;i<=n;++i) done[i]=0;   
    for(i=0;i<=n;++i) d[i]=inf,pre[i]=0,now[i]=i;   
    d[s]=0, q.push(Node(s,0)), pre[s]=0,now[s]=0;  
    while(!q.empty())
    {
        Node e=q.top(); u=e.u,q.pop();
        if(done[u]) continue; 
        done[u]=1;        
        if(u!=s) 
        {
            now[u]=max(u, pre[u]);   
        }
        for(i=hd[u];i;i=nex[i]){
            if(d[to[i]]>=d[u]+val[i]){
                if(d[to[i]]>d[u]+val[i]) 
                {
                    d[to[i]]=d[u]+val[i]; 
                    pre[to[i]]=now[u];
                    q.push(Node(to[i],d[to[i]]));   
                }
                else {
                    if(now[u]<pre[to[i]]) pre[to[i]]=now[u];    
                  }
            }
        }
    }               
}   
int main() {
    setIO("input");
    using namespace IO;  
    int T; 
    scanf("%d",&T); 
    while(T--) { 
        int i,j; 
        scanf("%d%d",&n,&m); 
        edges=0; 
        for(i=1;i<=n;++i) hd[i]=0;   
        for(i=1;i<=n;++i) for(j=1;j<=n;++j) f[i][j]=inf;  
        for(i=1;i<=n;++i) f[i][i]=0; 
        for(i=1;i<=m;++i) {
            int a,b,c; 
            scanf("%d%d%d",&a,&b,&c);   
            addedge(a,b,(ll)c), addedge(b,a,(ll)c), f[a][b]=f[b][a]=min(f[a][b],(ll)c);   
        }  
        int ans=0; 
        for(i=1;i<=n;++i) { 
            s=i, Dijkstra(); 
            for(j=1;j<=n;++j) 
                 ans=(long long) (ans+pre[j])%mod;               
        } 
        printf("%d\n",ans);   

    }
    return 0; 
}

  

猜你喜欢

转载自：https://www.cnblogs.com/guangheli/p/11405249.html
今日推荐