SNOI省选模拟赛Round4 T3 回家home 矩阵+容斥

题目大意:n个点m条边,学校在1号店,家在2号点,有k个点(不是家和学校)必须要到达,求从学校到家经过路径数在[l,r]之间的方案数。

n<=30,m<=300,k<=4,l,r<=1e9。


题解:我今天才知道一个图的邻接矩阵的k次方是从i到j走k步的方案数...

知道了这个,先考虑没有k个限制的情况,我们可以利用前缀和来求出[1,l-1]和[1,r]的方案总和,相减即可。

前缀和的处理:在做快速幂时,可以先预处理出2的i次方的转移矩阵,也就是要作k次方的矩阵的1次方,2次方,4次方,8次方……然后再开一个新矩阵,转移与原矩阵一样,只是每次转移是要加上上一次的结果,这样前缀和就处理出来了。

接着考虑有k个限制条件的情况,鉴于k很小,我们可以枚举哪些点去哪些点不去,不去的点将邻接矩阵内改为0,做2^4次方次矩阵快速幂即可。

统计答案可以用容斥原理,减去删去奇数个点的加上删去偶数个点的就是选了所有点的答案了。

代码:

#include<bits/stdc++.h>
#define mod 1000000009
using namespace std;
typedef long long LL;
int read()
{
	char c;int sum=0,f=1;c=getchar();
	while(c<'0' || c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0' && c<='9'){sum=sum*10+c-'0';c=getchar();}
	return sum*f;
}
int n,m,k,l,r,ans,can[35];
int p[5];
struct matrix{
	LL s[35][35];
	matrix(){memset(s,0,sizeof(s));}
	void setone()
	{
		for(int i=1;i<=n;i++)
		s[i][i]=1; 
	}
}A;
matrix operator*(matrix x,matrix y)
{
	matrix ans;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=n;j++)
	for(int k=1;k<=n;k++)
	ans.s[i][j]=(ans.s[i][j]+x.s[i][k]*y.s[k][j]%mod)%mod;
	return ans;
}
matrix operator*=(matrix &x,matrix y)
{
	x=x*y;
	return x;
}
matrix operator+=(matrix &x,matrix y)
{
	for(int i=1;i<=n;i++)
	for(int j=1;j<=n;j++)
	x.s[i][j]=(x.s[i][j]+y.s[i][j])%mod;
	return x;
}
matrix Pow[35],tpow[35];
matrix ksm(matrix A,int k)
{
	Pow[0]=A;
	for(int i=1;i<=30;i++)
	{
		Pow[i]=Pow[i-1];
		Pow[i]*=Pow[i];
	}
	tpow[0]=A;
	for(int i=1;i<=30;i++)
	{
		tpow[i]=tpow[i-1];
		tpow[i]*=Pow[i];
		tpow[i]+=tpow[i-1];
	}
	matrix ret,one;
	one.setone();
	for(int i=30;i>=0;i--)
	{
		if((k>>i)&1)
		{
			tpow[i]*=one;
			ret+=tpow[i];
			one*=Pow[i]; 
		}
	}
	return ret;
}
int cal(int x)
{
	int ans=0;
	matrix tmp;
	for(int state=0;state<(1<<k);state++)
	{
		tmp=A;
		for(int i=1;i<=n;i++)
		{
			can[i]=1;
			for(int j=0;j<k;j++)
			if(state&(1<<j))
			if(i==p[j]) can[j]=0;
		}
		for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
		if((!can[i]) || (!can[j]))
		tmp.s[i][j]=0;
		tmp=ksm(tmp,x);
		if(__builtin_popcount(state)&1)
			ans=(ans-tmp.s[1][2])%mod;
		else
			ans=(ans+tmp.s[1][2])%mod;
	}
	return ans;
}
int main()
{
	n=read();m=read();k=read();l=read();r=read();
	for(int i=1;i<=m;i++)
	{
		int u=read(),v=read();
		A.s[u][v]=1;
		A.s[v][u]=1;
	}
	for(int i=0;i<k;i++)
	p[i]=read();
	ans=cal(r);
	if(l>=1)
	ans=((ans-cal(l-1))%mod+mod)%mod;
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_39791208/article/details/79101031