[HAOI2012]道路,洛谷P2505,最短路图

版权声明:因为我是蒟蒻,所以请大佬和神犇们不要转载(有坑)的文章,并指出问题,谢谢 https://blog.csdn.net/Deep_Kevin/article/details/83867711

正题

      这题还是挺好想的。

      把每个点作为起点的最短路图建出来。

      做一次拓扑排序,求起点到该点有多少条最短路图。

      然后做一次反拓扑序,求出该点可以到达其他点的路径种数。

      最后对于边(u,v),它的价值就是u的第一个价值乘上v的第二个价值。

      相当于算的是以i为起点的最短路有多少条经过这条边。答案全部加起来就可以了。

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<queue>
using namespace std;

int n,m;
struct edge{
	int x,y,c,next;
}p[5010],s[5010];
int fir[3010],first[3010];//fir is for SPFA , the other is for TP
int len;
long long ans[5010];
queue<int> f;
int dis[3010],in[3010];
long long cnt1[3010],cnt2[3010];
int tp[3010];
bool tf[3010],insert[5010];
long long mod=1e9+7;

void ins(int x,int y,int c){
	len++;
	p[len]=(edge){x,y,c,fir[x]};fir[x]=len;
}

void inss(int x,int y,int c){
	len++;
	s[len]=(edge){x,y,c,first[x]};first[x]=len;
}

void SPFA(int x){
	f.push(x);	
	tf[x]=true;
	memset(dis,63,sizeof(dis));dis[x]=0;
	while(!f.empty()){
		int now=f.front();
		tf[now]=false;f.pop();
		for(int i=fir[now];i!=0;i=p[i].next){
			int y=p[i].y;
			if(dis[y]>dis[now]+p[i].c){
				dis[y]=dis[now]+p[i].c;
				if(!tf[y]){
					f.push(y);
					tf[y]=true;
				}
			}
		}
	}
}

void Tp(int op){
	memset(first,0,sizeof(first));len=0;
	memset(in,0,sizeof(in));
	memset(cnt1,0,sizeof(cnt1));cnt1[op]=1;
	memset(insert,false,sizeof(insert));
	for(int i=1;i<=m;i++) 
		if(dis[p[i].x]+p[i].c==dis[p[i].y]) inss(p[i].x,p[i].y,p[i].c),insert[i]=true,in[p[i].y]++;
	tp[0]=0;
	for(int i=1;i<=n;i++) if(!in[i]) f.push(i);
	while(!f.empty()){
		int x=f.front();f.pop();
		tp[++tp[0]]=x;
		for(int i=first[x];i!=0;i=s[i].next){
			int y=s[i].y;
			in[y]--;(cnt1[y]+=cnt1[x])%=mod;
			if(!in[y]) f.push(y);
		}
	}
	for(int i=1;i<=n;i++) cnt2[i]=1;
	for(int i=tp[0];i>=1;i--)
		for(int j=first[tp[i]];j!=0;j=s[j].next)
			(cnt2[tp[i]]+=cnt2[s[j].y])%=mod;
}

void solve(int x){
	SPFA(x);Tp(x);
	for(int i=1;i<=m;i++)
		if(insert[i]) (ans[i]+=cnt1[p[i].x]*cnt2[p[i].y]%mod)%=mod;
}

int main(){
	scanf("%d %d",&n,&m);
	int x,y,c;
	for(int i=1;i<=m;i++){
		scanf("%d %d %d",&x,&y,&c);
		ins(x,y,c);
	}
	for(int i=1;i<=n;i++) solve(i);
	for(int i=1;i<=m;i++) printf("%lld\n",ans[i]);
}

猜你喜欢

转载自blog.csdn.net/Deep_Kevin/article/details/83867711