题目描述
现在给出了一个简单无向加权图。你不满足于求出这个图的最小生成树,而希望知道这个图中有多少个不同的最小生成树。(如果两颗最小生成树中至少有一条边不同,则这两个最小生成树就是不同的)。由于不同的最小生成树可能很多,所以你只需要输出方案数对31011的模就可以了。
输入输出格式
输入格式:
第一行包含两个数,n和m,其中1<=n<=100; 1<=m<=1000; 表示该无向图的节点数和边数。每个节点用1~n的整数编号。
接下来的m行,每行包含两个整数:a, b, c,表示节点a, b之间的边的权值为c,其中1<=c<=1,000,000,000。
数据保证不会出现自回边和重边。注意:具有相同权值的边不会超过10条。
输出格式:
输出不同的最小生成树有多少个。你只需要输出数量对31011的模就可以了。
输入输出样例
输入样例#1:
4 6
1 2 1
1 3 1
1 4 1
2 3 2
2 4 1
3 4 1
输出样例#1:
8
说明
说明 1<=n<=100; 1<=m<=1000; 。
分析:
一个图不同的最小生成树有两个性质。
第一是所有最小生成树中相同权值的边使用了相同多次。我们考虑我们已经建好了一棵最小生成树,对于一条不在这棵树上的边
,保证在树上
到
的路径上的权值都小于等于这条边的权值,而且只有替换相同权值的边,新树才会是最小生成树。
第二是同一种权值的边连完后,连通块完全一样。也可以理解为相同权值的边无论怎样先后顺序如何,连接后的连通块完全一样。这个是很显然的。
假如我们假如了权值小于等于
的边,形成若干连通块。此时加入边权为
的边,将会连接一些连通块,把这些连通块看做点,连接相当于形成一棵树,使用矩阵树就可以。还有一种特殊情况,就是同一种边权的边连接形成的不是一个连通图。举个例子,比如说第一条边连接了
和
,第二条边连接
和
,此时如果直接建图做矩阵树det就是
。那么我们可以强行把他建成连通图,把每个连通块连成一个链,因为链上的边一定会被选,所以相当于每个连通块的树的个数的乘积。
代码:
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define LL long long
const int maxn=107;
const int maxe=1007;
const LL mod=31011;
using namespace std;
int n,m,cnt;
LL a[maxn][maxn];
int b[maxn],p[maxn],f[maxn];
struct edge{
int x,y,w;
}g[maxe];
bool cmp(edge x,edge y)
{
return x.w<y.w;
}
int find(int x,int *p)
{
int y=x,root;
while (p[x]) x=p[x];
root=x;
x=y;
while (p[x])
{
y=p[x];
p[x]=root;
x=y;
}
return root;
}
void uni(int x,int y,int *p)
{
int u=find(x,p);
int v=find(y,p);
if (u==v) return;
p[u]=v;
}
LL det()
{
int n=cnt-1;
LL ans=1;
for (int i=1;i<=n;i++)
{
for (int j=i+1;j<=n;j++)
{
while (a[j][i])
{
LL rate=a[i][i]/a[j][i];
for (int k=i;k<=n;k++)
{
a[i][k]=(a[i][k]-rate*a[j][k]%mod+mod)%mod;
swap(a[i][k],a[j][k]);
}
ans=mod-ans;
}
}
ans=(ans*a[i][i])%mod;
}
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=m;i++) scanf("%d%d%d",&g[i].x,&g[i].y,&g[i].w);
sort(g+1,g+m+1,cmp);
LL ans=1,num=0;
for (int i=1,last;i<=m;i=last+1)
{
last=i;
memset(b,0,sizeof(b));
memset(a,0,sizeof(a));
memset(f,0,sizeof(f));
cnt=0;
while (g[i].w==g[last].w)
{
int x=g[last].x;
int y=g[last].y;
if (find(x,p)!=find(y,p))
{
int u=find(x,p),v=find(y,p);
if (b[u]==0) b[u]=++cnt;
if (b[v]==0) b[v]=++cnt;
a[b[u]][b[v]]=(a[b[u]][b[v]]-1+mod)%mod;
a[b[v]][b[u]]=(a[b[v]][b[u]]-1+mod)%mod;
a[b[u]][b[u]]++;
a[b[v]][b[v]]++;
uni(b[u],b[v],f);
}
last++;
}
last--;
for (int j=2;j<=cnt;j++)
{
if (find(j,f)!=find(j-1,f))
{
uni(j-1,j,f);
a[j-1][j]=(a[j-1][j]-1+mod)%mod;
a[j][j-1]=(a[j][j-1]-1+mod)%mod;
a[j-1][j-1]++;
a[j][j]++;
}
}
ans=(ans*det())%mod;
for (int j=i;j<=last;j++)
{
if (find(g[j].x,p)!=find(g[j].y,p))
{
num++;
uni(g[j].x,g[j].y,p);
}
}
}
if (num<n-1) printf("0");
else printf("%lld",ans);
}