2019.01.26 codeforces 1096G. Lucky Tickets(生成函数)

版权声明:随意转载哦......但还是请注明出处吧: https://blog.csdn.net/dreaming__ldx/article/details/86655781

传送门
题意简述:现在有一些号码由 0 0 ~ 9 9 中的某些数字组成(会给出),号码总长度为 n n ,问有多少个号码满足前 n 2 \frac n2 个数码的和等于后 n 2 \frac n2 个数码的和(保证 n n 是偶数),答案对 998244353 998244353 取模。


思路:
一道挺显然的生成函数+快速幂。
考虑到前 n 2 \frac n2 个数码和的生成函数和后 n 2 \frac n2 个数码和的生成函数是相同的,因此直接求出前 n 2 \frac n2 个数码和的生成函数,然后对于每一项的系数平方加起来即可。
代码:

#include<bits/stdc++.h>
#define ri register int
using namespace std;
const int mod=998244353,N=2e5+5;
typedef long long ll;
int n,k,tim,lim;
vector<int>A,B,pos;
inline void init(const int&up){
	lim=1,tim=0;
	while(lim<=up)lim<<=1,++tim;
	pos.resize(lim),A.resize(lim),B.resize(lim);
	for(ri i=0;i<lim;++i)pos[i]=(pos[i>>1]>>1)|((i&1)<<(tim-1));
}
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
inline int dec(int a,int b){return a>=b?a-b:a-b+mod;}
inline int mul(int a,int b){return (ll)a*b%mod;}
inline int ksm(int a,int p){int ret=1;for(;p;p>>=1,a=mul(a,a))if(p&1)ret=mul(ret,a);return ret;}
inline void ntt(vector<int>&a,const int&type){
	for(ri i=0;i<lim;++i)if(i<pos[i])swap(a[i],a[pos[i]]);
	for(ri w,wn,typ=type==1?3:(mod+1)/3,mult=(mod-1)/2,mid=1;mid<lim;mid<<=1,mult>>=1){
		wn=ksm(typ,mult);
		for(ri j=0,len=mid<<1;j<lim;j+=len)for(ri w=1,a0,a1,k=0;k<mid;++k,w=mul(w,wn)){
			a0=a[j+k],a1=mul(w,a[j+k+mid]);
			a[j+k]=add(a0,a1),a[j+k+mid]=dec(a0,a1);
		}
	}
	if(type==-1)for(ri i=0,inv=ksm(lim,mod-2);i<lim;++i)a[i]=mul(a[i],inv);
}
struct poly{
	vector<int>a;
	poly(int k=0,int x=0){a.resize(k+1),a[k]=x;}
	inline int&operator[](const int&k){return a[k];}
	inline const int&operator[](const int&k)const{return a[k];}
	inline int deg()const{return a.size()-1;}
	inline poly extend(int k){poly ret=*this;return ret.a.resize(k+1),ret;}
	friend inline poly operator^(const poly&a,const int&k){
		init(a.deg()*k);
		for(ri i=0;i<=a.deg();++i)A[i]=B[i]=a[i];
		for(ri i=a.deg()+1;i<lim;++i)A[i]=B[i]=0;
		int p=k-1;
		ntt(A,1),ntt(B,1);
		while(p){
			if(p&1)for(ri i=0;i<lim;++i)B[i]=mul(B[i],A[i]);
			for(ri i=0;i<lim;++i)A[i]=mul(A[i],A[i]);
			p>>=1;
		}
		poly ret;
		return ntt(B,-1),ret.a=B,ret.extend(a.deg()*k);
	}
};
int main(){
	poly a(9);
	int mx=0;
	scanf("%d%d",&n,&k),n>>=1;
	for(ri i=1,v;i<=k;++i)scanf("%d",&v),a[v]=1,mx=max(mx,v);
	a=a.extend(mx);
	a=(a^n);
	int ans=0;
	for(ri i=0;i<=mx*n;++i)ans=add(ans,mul(a[i],a[i]));
	cout<<ans;
	return 0;
}

猜你喜欢

转载自blog.csdn.net/dreaming__ldx/article/details/86655781