51nod 1690 区间求和2

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/baidu_36797646/article/details/83757855

题解:

一开始考虑的是对于每个 a i a_i 有哪些 a j a_j 与它相乘,但是这样做不了。
正解是考虑每对 ( a i , a j ) (a_i,a_j) 的贡献,然后用FFT优化。
首先直接把长度为 2 2 的给算了,然后剩下的都是奇质数长度。
预处理 s i s_i 表示 1 1 - i i 有多少个奇质数。
取模问题,直接long double,最后取模。
分两种情况:
1、 i + j < = n + 1 i+j<=n+1
考虑 a i × a j a_i\times a_j 有多少个,那也就是有多少质数长度的区间中心在 i , j i,j 的中点,稍微写一写式子就发现贡献是 a i × a j × ( s i + j 1 s i j ) a_i\times a_j\times(s_{i+j-1}-s_{|i-j|}) ,显然可以FFT。
2、 i + j > n i+j>n
写一写式子,发现贡献是 a i × a j × ( s 2 × n ( i + j ) + 1 s i j ) a_i\times a_j\times(s_{2\times n-(i+j)+1}-s{|i-j|})
然后愉快的发现只用 2 2 次FFT,于是本题就解决了。

代码:

#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define LD long double
#define pa pair<int,int>
const double pi=acos(-1.0);
const int Maxn=100010;
const int inf=2147483647;
const LL mod=1000000007LL;
int read()
{
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return x*f;
}
struct C
{
	LD x,y;
	C(LD _x=0.0,LD _y=0.0){x=_x,y=_y;}
}a[Maxn<<3],b[Maxn<<3];
C operator + (C a,C b){return C(a.x+b.x,a.y+b.y);}
C operator - (C a,C b){return C(a.x-b.x,a.y-b.y);}
C operator * (C a,C b){return C(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
int bin[Maxn<<3];
void fft(C *a,int n,int o)
{
	for(int i=0;i<n;i++)if(i<bin[i])swap(a[i],a[bin[i]]);
	for(int i=1;i<n;i<<=1)
	{
		C wn=C(cos(pi/i),sin(pi/i)*o);
		for(int j=0;j<n;j+=(i<<1))
		{
			C w=C(1,0);
			for(int k=0;k<i;k++)
			{
				C t=a[i+j+k]*w;w=w*wn;
				a[i+j+k]=a[j+k]-t;
				a[j+k]=a[j+k]+t;
			}
		}
	}
}
int n,A[Maxn];LL ans=0;
int prime[10000],len=0,s[Maxn];bool mark[Maxn];
void pre()
{
	memset(mark,false,sizeof(mark));
	for(int i=2;i<=n;i++)
	{
		if(!mark[i])prime[++len]=i,s[i]=1;
		for(int j=1;j<=len&&prime[j]*i<=n;j++)
		{
			mark[prime[j]*i]=true;
			if(i%prime[j]==0)break;	
		}
	}
	s[2]=0;
	for(int i=3;i<=n;i++)s[i]+=s[i-1];
}
int main()
{
	n=read();
	pre();
	for(int i=1;i<=n;i++)A[i]=read();
	for(int i=2;i<=n;i++)ans=(ans+(LL)A[i-1]*A[i]%mod*2LL%mod)%mod;
	a[0]=b[0]=C(0,0);
	for(int i=1;i<=n;i++)a[i]=b[i]=C(A[i],0);
	int t=1;while(t<=n)t<<=1;t<<=1;
	bin[0]=0;
	for(int i=1;i<=t;i++)bin[i]=((bin[i>>1]>>1)|((i&1)*(t>>1)));
	fft(a,t,1),fft(b,t,1);
	for(int i=0;i<=t;i++)a[i]=a[i]*b[i];
	fft(a,t,-1);
	for(int i=1;i<=(n<<1);i++)
	if(!(i&1))
	{
		LL tmp=(LL)(a[i].x/(double)t+0.5);tmp%=mod;
		if(i<=n+1)ans=(ans+(LL)tmp*s[i-1]%mod)%mod;
		else ans=(ans+(LL)tmp*s[(n<<1)-i+1]%mod)%mod;
	}
	memset(a,0,sizeof(a));memset(b,0,sizeof(b));
	a[0]=b[0]=C(0,0);
	for(int i=1;i<=n;i++)a[i]=C(A[i],0),b[i]=C(A[n-i+1],0);
	fft(a,t,1),fft(b,t,1);
	for(int i=0;i<=t;i++)a[i]=a[i]*b[i];
	fft(a,t,-1);
	for(int i=1;i<=(n<<1);i++)
	if(!(abs(i-n-1)&1))
	{
		LL tmp=(LL)(a[i].x/(double)t+0.5);tmp%=mod;
		int tt=abs(i-n-1);
		ans=(ans-(LL)tmp*s[tt]%mod+mod)%mod;
	}
	printf("%lld",ans);
}

猜你喜欢

转载自blog.csdn.net/baidu_36797646/article/details/83757855