NEUOJ 1343 Eat walnuts (容斥原理 + 逆元 + 唯一分解定理)

经典问题：求 $\sum_{1 \le i \le n,\ \gcd(i, m) = 1} (i + i^2) \bmod (10^9+7)$，即 [1, n] 中与 m 互素的所有数之和，加上这些数的平方和（输出为二者之和）。
#include <iostream>
#include <vector>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
#include <map>
#include <queue>

using namespace std;
typedef long long ll;

// All primes in [2, maxn], in increasing order; filled by init().
vector<int> primes;
// Distinct prime factors of the current m, filled by get_factors().
// Stored as ll because get_factors() pushes its leftover (ll) cofactor, which
// can exceed INT_MAX when m does; vector<int> would silently truncate it.
vector<ll> factors;
const int maxn = 10000;      // sieve limit used by init()
const int mod = 1000000007;  // prime modulus for every answer
ll n, m;                     // current test case, read in main()
// For 2 <= i <= maxn: vis[i] is true iff i is composite (set by init()).
bool vis[maxn + 5];

// Sieve of Eratosthenes over [2, maxn]: marks every composite in vis[] and
// appends every prime to `primes` in increasing order.
void init() {
    primes.clear();  // calling init() twice must not duplicate the table
    for (int i = 2; i <= maxn; ++i) {
        // Composite i: all of its multiples were already marked by a smaller prime.
        if (vis[i]) continue;
        primes.push_back(i);
        // Multiples below i*i have a smaller prime factor and are already marked.
        // The i <= maxn / i guard keeps i*i from overflowing int if maxn is raised.
        if (i <= maxn / i) {
            for (int j = i * i; j <= maxn; j += i) {
                vis[j] = true;
            }
        }
    }
}

// Fills `factors` with the distinct prime factors of m, in increasing order.
// Trial division uses the sieved `primes` table first. If the table is exhausted
// before reaching sqrt of the remaining m, it continues with plain trial division,
// so m with several prime factors above the table is still factored correctly.
// For m <= 1, `factors` is left empty.
void get_factors(ll m) {
    factors.clear();
    size_t idx = 0;
    for (; idx < primes.size(); ++idx) {
        ll p = primes[idx];
        // p*p > m (written without overflow): what is left of m is 1 or a prime.
        if (p > m / p) break;
        if (m % p == 0) {
            factors.push_back(p);
            while (m % p == 0) m /= p;
        }
    }
    if (idx == primes.size()) {
        // Table ran out (or was never built) without reaching sqrt(m): keep going.
        // A composite d never divides m here, because its smaller prime factors
        // have already been divided out.
        ll d = primes.empty() ? 2 : static_cast<ll>(primes.back()) + 1;
        for (; d <= m / d; ++d) {
            if (m % d == 0) {
                factors.push_back(d);
                while (m % d == 0) m /= d;
            }
        }
    }
    // Any remainder > 1 has no factor <= its square root, so it is prime.
    if (m > 1) factors.push_back(m);
}

// Returns a^b modulo `mod` by binary exponentiation; for b <= 0 it returns 1.
// The base is first reduced into [0, mod) so that a * a stays below mod^2 < 2^63
// and cannot overflow, even for large or negative a.
ll qpow(ll a, ll b) {
    a %= mod;
    if (a < 0) a += mod;
    ll ret = 1;
    while (b > 0) {  // b > 0 (not b != 0): a negative b would never reach 0 via >>=
        if (b & 1) ret = ret * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ret;
}

// Fermat's little theorem: mod is prime, so x^(mod - 2) is the inverse of x
// modulo mod. inv2 divides by 2 in sum1(); inv6 divides by 6 in sum2().
ll inv2 = qpow(2, mod - 2);
ll inv6 = qpow(2 * 3, mod - 2);

// Returns 1 + 2 + ... + n = n(n+1)/2 modulo `mod`, for n >= 0.
// n is reduced modulo mod before multiplying, so every product stays below
// mod^2 < 2^63. The unreduced n*(n+1) overflowed once n exceeded ~3e9.
ll sum1(ll n) {
    ll a = n % mod;
    ll b = (a + 1) % mod;  // (n + 1) mod p, without risking n + 1 overflow
    return a * b % mod * inv2 % mod;
}

// Returns 1^2 + 2^2 + ... + n^2 = n(n+1)(2n+1)/6 modulo `mod`, for n >= 0.
// Each factor is reduced modulo mod first, so no intermediate product exceeds
// mod^2 < 2^63. The unreduced n*(n+1) and (2n+1) terms overflowed for large n.
ll sum2(ll n) {
    ll a = n % mod;
    ll b = (a + 1) % mod;      // (n + 1) mod p
    ll c = (2 * a + 1) % mod;  // (2n + 1) mod p
    return a * b % mod * c % mod * inv6 % mod;
}

// Returns the sum of (i + i^2) over all i in [1, n] that share no prime factor
// with the primes in `factors` (the distinct prime factors of m, as filled by
// get_factors()), modulo `mod`.
//
// Inclusion-exclusion over non-empty subsets S of `factors` with product d:
// the multiples of d in [1, n] are d*k for k = 1..n/d, and their (i + i^2)
// total is d*sum1(n/d) + d^2*sum2(n/d). Odd-size subsets are subtracted and
// even-size subsets are added back.
ll solve(ll n) {
    ll ret = (sum1(n) + sum2(n)) % mod;
    int sz = static_cast<int>(factors.size());
    for (int mask = 1; mask < (1 << sz); ++mask) {
        ll prud = 1;  // product of the chosen distinct primes; never exceeds m
        int bits = 0;
        for (int j = 0; j < sz; ++j) {
            if (mask & (1 << j)) {
                ++bits;
                prud *= factors[j];
            }
        }
        ll q = n / prud;
        if (q == 0) continue;  // no multiple of prud in [1, n]: contributes 0
        // Reduce prud before multiplying. Unreduced, prud * prud overflows ll once
        // prud exceeds ~3e9, and so does prud * sum1(q) once prud exceeds ~9e9.
        ll d = prud % mod;
        ll cur = (d * sum1(q) % mod + d * d % mod * sum2(q) % mod) % mod;
        if (bits & 1) ret = (ret - cur + mod) % mod;
        else ret = (ret + cur) % mod;
    }
    return ret;
}

int main() {
    init();
    while(cin >> n >> m) {
        get_factors(m);
        cout << solve(n) << endl;
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/chcnsn/article/details/80085370
今日推荐