# 题目描述

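给定 $n$ 和 $m$，表格第 $i$ 行第 $j$ 列的数为 $lcm(i,j)$，求整张 $n \times m$ 表格中所有数之和对 20101009 取模的结果（$n, m \le 10^7$）。例如 $n = 4, m = 5$ 时表格为：

```
1    2    3    4    5
2    2    6    4    10
3    6    3    12    15
4    4    12    4    20
```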

# 题解

$\sum^{n}_{i=1}\sum^{m}_{j=1}lcm(i,j)$

$\sum^{n}_{i=1}\sum^{m}_{j=1}\frac{i\times j}{gcd(i,j)}$

$\sum^{n}_{i=1} \sum^{m}_{j=1} \sum_{d|i,\ d|j,\ gcd(\frac{i}{d},\frac{j}{d})=1} \frac{i\times j}{d}$（满足条件的 $d$ 恰好只有 $d = gcd(i,j)$ 一个）

$\sum^{n}_{d=1} d \times \sum^{\lfloor\frac{n}{d}\rfloor}_{i=1} \sum^{\lfloor\frac{m}{d}\rfloor}_{j=1}[gcd(i,j)=1] \times i \times j$（先枚举 $d$，再把 $i, j$ 都除以 $d$：$\frac{(i d)\times(j d)}{d} = d \times i \times j$）

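记 $calc(n,m) = \sum^{n}_{i=1} \sum^{m}_{j=1}[gcd(i,j) = 1] \times i \times j$，则答案为 $\sum^{\min(n,m)}_{d=1} d \times calc(\lfloor\frac{n}{d}\rfloor, \lfloor\frac{m}{d}\rfloor)$。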

利用 $[gcd(i,j)=1] = \sum_{d|gcd(i,j)} \mu(d)$ 进行莫比乌斯反演：

$calc(n,m) = \sum^{\min(n,m)}_{d=1} \mu(d) \sum^{n}_{i=1,\ d|i} \sum^{m}_{j=1,\ d|j} i \times j$
令 $i = i' \times d,\ j = j' \times d$，代入上式：
$calc(n,m) = \sum^{\min(n,m)}_{d=1} \mu(d) \times d^2 \times \sum^{\lfloor \frac{n}{d}\rfloor}_{i'=1} \sum^{\lfloor \frac{m}{d}\rfloor}_{j'=1} i' \times j'$

$calc2(n,m)=\sum^{n}_{i=1} \sum^{m}_{j=1} i \times j = \frac{n \times (n + 1)}{2} \times \frac{m\times (m+1)}{2}$

两层求和都用整除分块（数论分块）计算，再加上线性筛预处理 $\mu(d) \times d^2$ 的前缀和，总复杂度为 $O(n + n^{3/4})$。

# AC 代码

#include <bits/stdc++.h>
#define ll long long
#define ms(a, b) memset(a, b, sizeof(a))
#define inf 0x3f3f3f3f
#define mod 20101009
#define N 10000005
using namespace std;
// Fast integer reader for stdin.
// Skips every non-digit character, remembering whether a '-' was among the
// skipped characters, then accumulates consecutive decimal digits into x and
// applies the sign. The character that terminates the digit run is consumed.
// At end of input the skip loop stops instead of spinning forever, and x
// holds whatever was parsed (0 if no digits were found).
template <typename T>
inline void read(T &x) {
    x = 0;
    T fl = 1;
    // int, not char: EOF must be distinguishable from every real character.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') fl = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x *= fl;
}
// n, m: table dimensions used by main; prime_tot: number of primes stored by get_mu.
int n, m, prime_tot;
// vis[x] is set once the sieve in get_mu has crossed x out as composite.
bool vis[N];
// prime[1..prime_tot]: primes found by the sieve; mu[x]: Mobius function of x;
// sum[x]: prefix sums of k*k*mu(k) for k = 1..x, reduced modulo `mod`.
int prime[N], mu[N], sum[N];
// Linear sieve up to MAXN.
// Fills prime[1..prime_tot], flags composites in vis[], and computes the
// Mobius function mu[1..MAXN]. It then builds
//   sum[x] = (sum over k = 1..x of k*k*mu(k)) mod `mod`,
// kept as a non-negative residue by adding `mod` to mu before multiplying.
// Writes sum[MAXN], so MAXN must be below the array size N.
void get_mu(int MAXN) {
    prime_tot = 0;
    mu[1] = 1;
    for (int x = 2; x <= MAXN; ++x) {
        if (!vis[x]) {  // never crossed out, so x is prime
            prime[++prime_tot] = x;
            mu[x] = -1;
        }
        for (int k = 1; k <= prime_tot; ++k) {
            const int p = prime[k];
            if (p * x > MAXN) break;
            vis[p * x] = 1;
            // p divides x, so p*p divides p*x: mu[p*x] keeps its zero
            // initial value, and stopping here keeps the sieve linear.
            if (x % p == 0) break;
            mu[p * x] = -mu[x];
        }
    }
    for (int x = 1; x <= MAXN; ++x) {
        const long long term = 1ll * x * x % mod * (mu[x] + mod);
        sum[x] = (sum[x - 1] + term) % mod;
    }
}
// Returns (1 + 2 + ... + x) * (1 + 2 + ... + y) mod `mod`.
// This equals the sum over i <= x, j <= y of i*j.
// Each triangular number is reduced before the product, so the
// multiplication stays well inside long long.
int calc2(int x, int y) {
    const long long tri_x = 1ll * x * (x + 1) / 2 % mod;
    const long long tri_y = 1ll * y * (y + 1) / 2 % mod;
    return tri_x * tri_y % mod;
}
// calc(x, y) = sum over i <= x, j <= y with gcd(i, j) == 1 of i*j, mod `mod`.
// By Mobius inversion this is
//   sum over k of mu(k)*k*k * calc2(floor(x/k), floor(y/k)).
// k is walked in maximal ranges [lo, hi] on which both floor(x/k) and
// floor(y/k) are constant. Each range contributes
//   (sum[hi] - sum[lo-1]) * calc2(floor(x/lo), floor(y/lo)).
// Reads sum[] up to index min(x, y), so get_mu must have covered that range.
int calc(int x, int y) {
    const int limit = std::min(x, y);
    int acc = 0;
    int lo = 1;
    while (lo <= limit) {
        const int qx = x / lo;
        const int qy = y / lo;
        const int hi = std::min(x / qx, y / qy);
        // +mod keeps the difference of two residues positive.
        const long long block = sum[hi] - sum[lo - 1] + mod;
        acc = (acc + block * calc2(qx, qy) % mod) % mod;
        lo = hi + 1;
    }
    return acc;
}
// Reads n and m, then prints
//   (sum over i <= n, j <= m of lcm(i, j)) mod 20101009
// using
//   answer = sum over d of d * calc(floor(n/d), floor(m/d)).
// d is grouped into ranges [l, r] sharing the same (floor(n/d), floor(m/d)),
// so each range contributes (l + ... + r) * calc(...).
// get_mu writes sum[min(n, m) + 1], so this assumes min(n, m) + 1 < N.
int main() {
    // The original never read its input, so n == m == 0 and it always printed 0.
    if (scanf("%d%d", &n, &m) != 2) return 1;
    get_mu(std::min(n, m) + 1);
    int ans = 0;
    for (int l = 1, r; l <= std::min(n, m); l = r + 1) {
        r = std::min(n / (n / l), m / (m / l));
        // (r - l + 1) * (r + l) / 2 == l + (l + 1) + ... + r; it is at most
        // about 5e13 for n, m <= 1e7, so it fits in long long before reducing.
        ans = (ans + 1ll * (r - l + 1) * (r + l) / 2 % mod * calc(n / l, m / l) % mod) % mod;
    }
    printf("%d\n", ans);
    return 0;
}


0条评论