P3172 [CQOI2015]选数 容斥+记忆化搜索

P3172 [CQOI2015]选数

标签

  • 容斥
  • 记忆化搜索

前言

  • 很好的题~可以反演后杜教筛,也可以推式子然后dp!!

简明题意

  • 给定\(n,k,L,R\),需要你求出,求从区间\([L,R]\)中选出\(n\)个数且他们的\(gcd=k\)的方案数。(可以重复选数)

思路

  • 我们假设一组样例,\(n=2,k=3,L=2,R=10\)。于是,我们需要从\([2,10]\)中选出两个数,使得他们的\(gcd=2\)。一个很明显的东西,就是无论我们怎么选,选的这些数一定是\(k\)的倍数。
  • 我们发现,我们假设的样例中,k=3的倍数有3,6,9,假设个数为\(x\),这里\(x=3\)。于是我们列出所有的选择方案,一共有\(x^n=3^2=9\)种方案,分别是:(3,3),(3,6),(3,9),(6,3),(6,6),(6,9),(9,3),(9,6),(9,9)。显然,有一些的\(gcd\)显然是3,有一些的是6,有一些甚至是9。
  • 现在设\(f(i)\)是对于给定\(n,k,L,R\),从中选出\(n\)个数且\(gcd=i\)的方案数。还是根据上面的样例,\(f(2)=3^2-f(6)-f(9)\)(琢磨一下这个式子怎么来的)。所以令\(x=[L,R]\)\(i\)的倍数的个数,根据上面的样例,我们很容易知道
    \[f(i)=x^n-f(2*i)-f(3*i)-...-f(m*i)(m*i<=R)\]

    这里讲一下x的求法。\(x=\frac Ri-\frac{L-1}i\).这个式子怎么来的可以自己举例就知道了

  • 这是一个递归式,显然,当\(k=1,R=1e9\)时,递推一定T。得优化。优化的突破口是这句话:\(H-L<=10^5\)。然后我们又有一个重要的性质:\([l,r]\)中选若干个不相同的数,他们的\(gcd\)不超过\(R-L\)。知道了这个性质,我们的\(f(x)\)就可以不用从\(f(2*i)\)一直考虑到\(f(m*i)(m*i<=R)\),因为\(m*i\)根本不可能是\(gcd\),从而只用从\(f(2*i)\)考虑到\(f(m*i)(m*i<= R-L)\),但是需要稍微改变一下。
  • 回忆\(f(i)\)的定义,它包含了选择相同数的情况。那么我们只计算选择不同数的情况,这样就能用上面说的\(gcd\)的性质,只需要最后减去选择相同数的方案数。但是很麻烦,设相同的方案数是p,当n=1,p显然是0,当n>=2,我们需要考虑[L,R]区间中k的倍数,然后组合数学计算。很麻烦。这里有一种简化方法:
  • 就是直接把L,R除以k,然后就变成在新的区间里找gcd==1的方案数了。这样之后,重复元素的方案数就很好算了,因为k=1,与1互质的数只有1,所以重复元素只能是1,只要新的L=1且要选2个数以上,那么方案数就+1,就这么简单。

注意事项

  • 减法操作直接取模会出问题,应该先加上模数再取模

总结

  • \([l,r]\)中选若干个不相同的数,他们的\(gcd\)不超过\(R-L\)

AC代码

#include <cstdio>
#include <unordered_map>
using namespace std;

const int mod = 1000000007;

int ksm(int a, int b) {
    int base = a, ans = 1;
    while (b) {
        if (b & 1)
            ans = 1ll * ans * base % mod;
        base = 1ll * base * base % mod;
        b >>= 1;
    }
    return ans;
}

int n, k, l, r;
unordered_map<int, int> rec;

int f(int k) {
    if (rec[k] != 0)
        return rec[k];
    int x = r / k - (l - 1) / k;
    int ans = ksm(x, n) - x;

    for (int i = 2; i * k <= r - l; i++) ans = (ans - f(i * k) + mod) % mod;

    return rec[k] = ans;
}

void solve() {
    scanf("%d%d%d%d", &n, &k, &l, &r);
    l = (l % k == 0 ? l / k : l / k + 1);
    r /= k;
    k = 1;

    printf("%d", f(k) + (l == 1 && n >= 2));
}

int main() {
    solve();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/danzh/p/11299393.html