FJUT3692-Hang的数学题

题目就是要求 ac == bd 的方案数,其中 a,b,c,d 1≤a,b,c,dn

题目链接:http://www.fjutacm.com/Problem.jsp?pid=3692

周赛的题,原题是牛客网的https://ac.nowcoder.com/acm/problem/21578,加大了点难度


先稍微分析一下数据范围

30数据,n比较小,所以直接暴力就行了,注意一下快速幂或者for循环求a^c次方会炸int,用4个for暴力轻松30

70数据,这个n是1e6,在算上t组,所以复杂度肯定是小于O(n)的,不然1e9就已经T了,其实对于70和90的数据,复杂度只要是sqrt(n)在乘点不是很大的常数,都是可以过的

100数据,这个t是2000,假设我留4e8来跑最后一组数据,除完省2e5,而n的范围是1e12,开个根号也要1e6,所以复杂度就不是sqrt的,比sqrt还要快的,那就是log和常数的,但是很明显,这题和素数有关,所以基本上不可能是O(1)的,所以理论上复杂度就是log或者log的几次方这样的

下面讲讲出题人做这道题的思路

我们先考虑比较简单的两种情况

  • 当a=b=1的时候,不论cd取什么值,等式都是成立的,这时候c有n种取法,d有n种,方案数就是n*n
  • 类似的,我们可以考虑a=b!=1的情况,这时候要使得等式成立,必须要c=d,所以有cd选择有n种,ab选择有n-1种(去掉1的情况)所以就是n*(n-1)
  • 而难点就是ab不等的情况,这时候要怎么考虑
  • 下面的情况都是以a<b来考虑的,a=b的情况最开始考虑过了,而a>b的情况,直接就交换ab和cd,就可以化成a<b的情况,所以就是在a<b的情况下,直接乘2就行了
  • 首先,如果a^c=b^d,可以明确(x^y1)^c=(x^y2)^d,且c*y1=d*y2
  • 就是说ab之间一定有不为1的公约数存在,且a^c=b^d成立时,a^(kc)=b^(kd)也一定是成立的(数学上就是两边同时k次方),当而这个k的最大值,因为a<b,所以c>d,所以kc会更早趋近于n,所以对答案的贡献(也就是k的最大值)就是n/k,接着我们就是要找到所有的ab的情况
  • 我们可以用一个简单的for循环嵌套来遍历ab

扫描二维码关注公众号,回复: 6143096 查看本文章
  • 这样a=i^k,b=i^q,并且保证了ab之间有公约数i
  • 接着我们要求一下上面讲的k的最大值,对于这种写法,我们就可以看成c=N*q,d=N*k(N为[1,n/q]的任意整数),这时候要求k的最大值(也就是对答案的贡献),根据上面的结论,kmax=n/q
  • 我们可以发现,对于上面的for,k<q,所以对答案的贡献都是n/q
  • 但是要考虑一种情况,比如当i=2,k=2,q=4,这时候的a=4,b=16,当你计算这个之后,你会在i=4,k=1,q=2的时候重复计算一遍,为了避免这种情况,我们要对kq求gcd,只有当他们互质的时候才计算答案,这样能保证,在i比较小的时候不计算,在i大的时候才计算,对于上面例子就是i=2,k=2,q=4不计算(因为gcd(k,q)=2),i=4,k=1,q=2计算(因为gcd(k,q)=1)
  • 这时候我们就得到了一份暴力的代码,那个*2是因为我们算了a<b,还要考虑a>b的情况,直接乘2
  • ps:不怎么写博客,不知道为啥上下有两行,等以后懂了再填坑
 
 

  ll sum=0;

sum=(ll)n*n%mod;
    sum=(sum+((ll)n-1)*n)%mod;
    x=sqrt(n);
    for(int i=2; i<=x; i++)
    {
        q=1;
        for(ll j=i*i; j<=n; j*=i)
        {
            q++;
            for(int k=0; k<q; k++)
            {
                int num=gcd(k,q);
                if(num==1)
                    sum=(sum+(n/q)*2)%mod;
            }
        }
    }
    printf("%lld\n",sum);
  • 稍微计算一下,复杂度是sqrt(n)*(一个小于60的常数)^2(因为第二个for是log级别的)
  • 明显不符合100分的要求,这个代码是可以过70~90的数据(为什么是范围呢,这要看你写的时候常数有多大)
  • 我们开始优化
  • 可以发现,q是一个1~60的值(因为log2(n)最大值才50左右,就是2^50是大于1e12,具体多少我没去算),对于gcd(k,q)我们可以预处理一下,用一个数组gc[i]表示1~i-1有多少个数和i互质
  • 转换成代码就是
for(i=2; i<64; i++)
        if(gc[i]==i)
            for(j=i; j<64; j+=i)
                gc[j]=gc[j]/i*(i-1);

这样下面求的公式就化简成了

x=sqrt(n);
        for(int i=2; i<=x; i++)
        {
            q=1;
            for(ll j=i*i; j<=n; j*=i)
            {
                q++;
                sum=(sum+(n/q)*2*gc[q])%mod;
            }
        }
        printf("%lld\n",sum);
  • 求gc的函数可以用欧拉函数来求(求1~i-1有多少个数和i互质)
  • 这是欧拉函数预处理的写法
for(i=2; i<63; i++)
        gc[i]=i;
    for(i=2; i<63; i++)
    {
        if(gc[i]==i)
            for(int j=i; j<63; j+=i)
                gc[j]=gc[j]/i*(i-1);
        gc[i]<<=1;
    }
  • 经过这个化简,我们就少了一个log复杂度,但是总的还是sqrt级别的,所以还是会t,但是这个代码好像能轻松过90?
  • 再想想化简方式,可以发现第二个for,每次计算都是q=2开始,到log那啥,n/q和gc都是一样的,我们同样可以处理个前缀和,我没有每次同余是因为对于90的数据,这个x是在mod范围内,*个sqrt,也还在ll范围内,所以就没有取模,就最后取模一次就行了
len=sqrt(n);
        for(int i=2;i<34;i++)
            x[i]=(x[i-1]+(n/i)*2*gc[i])%mod;
        for(register int i=2; i<=len; i++)
        {
            k=1;
            for(ll j=i*i;j<=n;j*=i)
                k++;
            sum+=x[k];
        }
        printf("%lld\n",sum%mod);
  • 但是还是治标不治本,怎么把sqrt化成log呢?
  • 由上面这个代码,你可以发现,每次对答案有影响的值主要取决于logi(n)的值,假设这个值为k,就是求i^k<=n的最大k值,而这个k的范围只有1到log2(n)
  • 很明显,k是一个单调不减的值,且其中有大量的重复值
  • 我假设Sqrt(n)3的含义是3次根号n,我懒得打数学符号了(其实是不会打)
  • 观察下面for的q,可以算出有多少个q都是满足条件的
  • 先明确一下q的含义,q表示i^q
  • i^q<=n  ->  i<=sqrt(n)q,  也就是说,有sqrt(n)q 个i满足 i^k<=n
我们要去掉q=1的情况,在求得sqrt(n)q的情况下,只需要遍历q的所有可能值,也就是log2(n),每次计算sum+=1LL*(sqrt(n)q)*((n/i)%mod)*gc[i]
这样得到的就是我们要的答案
求sqrt(n)q可以使用pow函数,因为pow会丢精度,所以补个0.5就OK了(不用怀疑数据的准确性,我用上面不补精度的代码尝试过了)
所以最后的代码就是
len=1;
        for(i=2;i<=n;i<<=1)
            len++;
        for(i=2;i<=len;i++)
        {
            temp=pow(n+0.5,1.0/i)-1;
            sum=(sum+temp*(n/i)*gc[i])%mod;
        }
        printf("%lld\n",sum);
因为n范围的关系,所以取模操作就得小心
附上完整核心代码
int t,i;
    for(i=2; i<63; i++)
        gc[i]=i;
    for(i=2; i<63; i++)
    {
        if(gc[i]==i)
            for(int j=i; j<63; j+=i)
                gc[j]=gc[j]/i*(i-1);
        gc[i]<<=1;
    }
    scanf("%d",&t);
    long long sum,n,j,temp;
    while(t--)
    {
        scanf("%lld",&n);
        sum=((2LL*n%mod*(n%mod)-n)%mod+mod)%mod;
        i=1;
        for(j=2;j<=n;j<<=1)
        {
            i++;
            temp=pow(n+0.5,1.0/i)-1;
            sum+=1LL*temp*((n/i)%mod)*gc[i];
            sum%=mod;
        }
        printf("%lld\n",sum%mod);
    }

猜你喜欢

转载自www.cnblogs.com/rainH/p/10822746.html