2018icpc沈阳网络预选赛 G题(容斥)

题目:https://nanti.jisuanke.com/t/31448

题意:打表发现a[i]=i*(i+1),求\sum _{1}^{n}a[i]\, \, \, \, \, \, \, \, \, \, gcd(i,m)=1

思路:正着求不好求,我们可以求 总和减去与m不互质的数的贡献 

sum=\sum_{1}^{n} (i^2+i)=\sum_{1}^{n}i^2+\sum_{1}^{n}i=\frac{n*(n+1)*(2*n+1)}{6}+\frac{n*(n+1)}{2}

与m不互质的数我们可以分解m的质因子,对于当前的因子k,我们会有k,2k,3k,4k,5k,6k...\frac{n}{k}*k这些数与m不互质

对于因子k的贡献就为k*(k+1)+2k*(2k+1)+....+t*k+(t*k+1)\, \, \, \, \, t=\frac{n}{k}

该式就等于k^2*\frac{t*(t+1)*(2t+1)}{6}+k*\frac{t*(t+1)}{2}

然后因子个数奇数个加偶数个减容斥一下就可以了。

#include <iostream>
#include <string>
#include <algorithm>
#include <map>
#include <stdio.h>
#define LL long long
#define ll LL
#define mem(arry,value) memset(arry,value,sizeof(arry))
using namespace std;
const int mod=1e9+7;
int fac[20],tot;
void deal(int m)///分解质因数
{
    int tmp=m;
    for(int i=2;i*i<=m;i++)
    {
        if(tmp%i==0)
        {
            fac[tot++]=i;
            while(tmp%i==0)
                tmp/=i;
        }
    }
    if(tmp>1)
        fac[tot++]=tmp;
    return ;
}
LL fastpow(LL a, LL b, LL mod)
{
    LL ret = 1;
    while(b)
    {
        if(b&1)
            ret=(ret * a)%mod;
        a=(a*a)%mod;
        b>>= 1;
    }
    return ret;
}

int main()
{
    int n,m;
    while(~scanf("%lld%lld",&n,&m))
    {
        tot=0;
        deal(m);
        LL sum=0;
        for(int i=1;i<(1<<tot);i++)///容斥
        {
            LL cnt=0,tmp=1;
            for(int j=0;j<tot;j++)
            {
                if(i&(1<<j))
                {
                    tmp*=fac[j];
                    cnt++;
                }
            }
            if(cnt%2==0)
            {
                LL t=n/tmp;
                sum-=tmp*tmp%mod*t%mod*(t+1)%mod*(2*t+1)%mod*fastpow(6,mod-2,mod);
                sum-=tmp*t%mod*(t+1)%mod*fastpow(2,mod-2,mod);
            }
            else
            {
                LL t=n/tmp;
                sum+=tmp*tmp%mod*t%mod*(t+1)%mod*(2*t+1)%mod*fastpow(6,mod-2,mod);
                sum+=tmp*t%mod*(t+1)%mod*fastpow(2,mod-2,mod);
            }
        }
        LL ans=n*(n+1)%mod*(2*n+1)%mod*fastpow(6,mod-2,mod);
        ans+=n*(n+1)%mod*fastpow(2,mod-2,mod);
        printf("%lld\n",((ans-sum)%mod+mod)%mod);
    }
}

猜你喜欢

转载自blog.csdn.net/prometheus_97/article/details/82532836