Tsinsen D488 盒子

传送门
题目大意

对于一个 w × h 的矩形格子,从左下角向斜上方 45 射入一束光线,光线只会从格子的顶点射出。设 f ( w , h ) 表示 w × h 射入光线后光线的碰壁次数(不包含射出的那次),求:

i = 1 n i = 1 m f ( i , j )

对于 n 100 的数据

照题意模拟,时间复杂度 O ( n 4 ) ,得到 0 分。

对于 n 5000 的数据

打表发现 f i , j = i + j gcd ( i , j ) 2 (带入发现当 i = j f i , j = 0 ,符合题意),我们要求的是:

i = 1 n j = i + 1 n i + j gcd ( i , j ) 2

乘以 2 就是答案。时间复杂度 O ( n 2 log n ) ,由于除以了 2 ,并且这个过程常数极小,所以可以得到 20 分。


但是为什么是这个答案呢?我们考虑这么做:不把它看成反射,而是在平面内铺满这种格子,然后直射,则碰到的第一个顶点就是光线射出去的位置,与边缘碰到的次数就等于横着经过的格子数加上竖着经过的格子数,由于不算开始和结束的那次,因此还要减去 2 。考虑这个的答案是多少。由于我们是右上 45 发射的光线,因此一定有:

w x = h y

w h 分别表示格子的宽度和高度。显然上式的最小解是 x = h gcd ( w , h ) y = w gcd ( w , h ) ,证毕。

对于 n 10 7 的数据

现在我们要求:

i = 1 n j = 1 n i + j gcd ( i , j )

枚举 gcd
g = 1 n i = 1 n g j = 1 n g [ gcd ( i , j ) = 1 ] × ( i + j )

莫比乌斯反演:
g = 1 n i = 1 n g j = 1 n g d gcd ( i , j ) μ ( d ) × ( i + j )

m = n g ,抛开最外侧求和,我们求的是:
i = 1 m j = 1 m d gcd ( i , j ) μ ( d ) × ( i + j )

枚举 d
d = 1 m μ ( d ) d i = 1 m d j = 1 m d i + j

化简:
d = 1 m μ ( d ) d ( m d + 1 ) m d 2

时间复杂度 O ( n n )


把枚举 g 放回:

g = 1 n d = 1 n g μ ( d ) d ( n g d + 1 ) n g d 2

t = g d ,枚举 t
t = 1 n ( n t + 1 ) n t 2 d t μ ( d ) d

用数论分块 ,设:
f ( n ) = d n μ ( d ) d

即:
f ( n ) = ( μ i d ) 1

如果能求出 f ( n ) 的前缀和,问题就解决了。


显然 f ( n ) 是一个积性函数,可以用线性筛,时间复杂度 O ( n + n )

f ( p ) = 1 p

f ( p k ) = f ( p k 1 )

可以得到 60 分。

对于 n 10 9 的数据

考虑之前这个式子:

g = 1 n d = 1 n g μ ( d ) d ( n g d + 1 ) n g d 2

我们枚举 d
d = 1 n μ ( d ) d g = 1 n d ( n g d + 1 ) n g d 2

我们使用数论分块,左边使用杜教筛,右边再用一次数论分块。时间复杂度 O ( n 2 3 )


f ( n ) = μ ( n ) n 。设 g = f i d ,那么:

g ( n ) = d n μ ( d ) d n d = n [ n = 1 ]

S ( n ) 表示 g 的前缀和,枚举因数,有:
S ( n ) = d = 1 n d i = 1 n d f ( i )

F ( n ) 表示 f 的前缀和,代入上式,有:
S ( n ) = d = 1 n d F ( n d )

F ( n ) = 1 d = 2 n d F ( n d )

直接使用杜教筛即可。

参考代码
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cassert>
#include <cctype>
#include <climits>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <bitset>
#include <list>
#include <functional>
using LL = long long;
using ULL = unsigned long long;
using std::cin;
using std::cout;
using std::endl;
using INT_PUT = int;
INT_PUT readIn()
{
    INT_PUT a = 0;
    bool positive = true;
    char ch = getchar();
    while (!(std::isdigit(ch) || ch == '-')) ch = getchar();
    if (ch == '-')
    {
        positive = false;
        ch = getchar();
    }
    while (std::isdigit(ch))
    {
        (a *= 10) -= ch - '0';
        ch = getchar();
    }
    return positive ? -a : a;
}
void printOut(INT_PUT x)
{
    char buffer[20];
    int length = 0;
    if (x < 0) putchar('-');
    else x = -x;
    do buffer[length++] = -(x % 10) + '0'; while (x /= 10);
    do putchar(buffer[--length]); while (length);
    putchar('\n');
}

const int mod = int(1e9) + 7;
int n;

int gcd(int a, int b)
{
    return b ? gcd(b, a % b) : a;
}

#define RunInstance(x) delete new x
struct brute1
{
    brute1()
    {
        register int ans = 0;
        for (register int i = 1; i <= n; i++)
            for (register int j = i + 1; j <= n; j++)
            {
                register int t;
                register int g = gcd(i, j);
                if (g == 1)
                    ans = (t = ans + i + j) >= mod ? t - mod : t;
                else
                    ans = (t = ans + (i + j) / g) >= mod ? t - mod : t;
            }
        ans = (ans - ((LL)n * n - n) % mod + mod) % mod;

        printOut((ans << 1) % mod);
    }
};
struct brute2
{
    static const int maxN = int(1e7) + 5;
    bool isntPrime[maxN]{};
    int prime[664580];
    int f[maxN];

    void init()
    {
        const int to = int(1e7);
        prime[0] = 0;
        isntPrime[1] = true;
        f[0] = 0;
        f[1] = 1;
        for (int i = 2; i <= to; i++)
        {
            if (!isntPrime[i])
            {
                prime[++prime[0]] = i;
                f[i] = 1 - i;
            }

            for (int j = 1, p = prime[j], s = i * p;
                j <= prime[0] && s <= to; j++, p = prime[j], s = i * p)
            {
                isntPrime[s] = true;
                if (i % p)
                {
                    f[s] = (LL)f[i] * f[p] % mod;
                }
                else
                {
                    f[s] = f[i];
                    break;
                }
            }
        }

        for (int i = 2; i <= to; i++)
            f[i] = ((LL)f[i - 1] + f[i] + mod) % mod;
    }

    brute2()
    {
        init();
        LL ans = (LL)-2 * n * n;
        ans = (ans % mod + mod) % mod;
        for (int i = 1, t; i <= n; i = t + 1)
        {
            t = n / (n / i);
            int Div = n / i;
            ans = (ans + (LL)(Div + 1) * Div % mod * Div % mod *
                (f[t] - f[i - 1]) % mod + mod) % mod;
        }
        printOut(ans);
    }
};
struct work
{
    static const int maxN = int(1e7) + 5;
    bool isntPrime[maxN]{};
    int prime[664580];
    int f[maxN];

    void init()
    {
        const int to = int(1e7);
        prime[0] = 0;
        isntPrime[1] = true;
        f[0] = 0;
        f[1] = 1;
        for (int i = 2; i <= to; i++)
        {
            if (!isntPrime[i])
            {
                prime[++prime[0]] = i;
                f[i] = -i;
            }

            for (int j = 1, p = prime[j], s = i * p;
                j <= prime[0] && s <= to; j++, p = prime[j], s = i * p)
            {
                isntPrime[s] = true;
                if (i % p)
                {
                    f[s] = (LL)f[i] * f[p] % mod;
                }
                else
                {
                    f[s] = 0;
                    break;
                }
            }
        }

        for (int i = 2; i <= to; i++)
            f[i] = ((LL)f[i - 1] + f[i] + mod) % mod;
    }
    struct HashTable
    {
        struct Node
        {
            int key;
            int val;
            int next;
            Node() = default;
            Node(int key) : key(key), val(), next(-1) {}
        };
        std::vector<Node> nodes;
        static const int size = int(1e6) + 7;
        int head[size];
        HashTable() { std::memset(head, -1, sizeof(head)); }
        int query(int key)
        {
            int cnt = head[key % size];
            while (~cnt)
            {
                if (nodes[cnt].key == key)
                    return nodes[cnt].val;
                cnt = nodes[cnt].next;
            }
            return -1;
        }
        void insert(int key, int val)
        {
            int cnt = head[key % size];
            if (~cnt)
            {
                while (~nodes[cnt].next)
                    cnt = nodes[cnt].next;
                nodes[cnt].next = nodes.size();
                nodes.push_back(Node(key));
                nodes.back().val = val;
            }
            else
            {
                head[key % size] = nodes.size();
                nodes.push_back(Node(key));
                nodes.back().val = val;
            }
        }
    } table;
    static inline int calc(int from, int to)
    {
        int x = from + to;
        int y = to - from + 1;
        if (!(x & 1)) x >>= 1;
        if (!(y & 1)) y >>= 1;
        return (LL)x * y % mod;
    }
    int F(int x)
    {
        if (x <= int(1e7))
            return f[x];
        int ret = table.query(x);
        if (~ret) return ret;
        ret = 1;
        for (int i = 2, t; i <= x; i = t + 1)
        {
            t = x / (x / i);
            ret = (ret - (LL)F(x / i) * calc(i, t)) % mod;
        }
        ret = (ret + mod) % mod;
        table.insert(x, ret);
        return ret;
    }
    int calcRight(int x)
    {
        int ret = 0;
        for (int i = 1, t; i <= x; i = t + 1)
        {
            t = x / (x / i);
            int val = x / i;
            ret = (ret + (LL)(t - i + 1) * (val + 1) % mod * val % mod * val) % mod;
        }
        return ret;
    }
    work()
    {
        init();
        LL ans = (LL)-2 * n * n;
        ans = (ans % mod + mod) % mod;
        for (int i = 1, t; i <= n; i = t + 1)
        {
            t = n / (n / i);
            int left = (F(t) - F(i - 1) + mod) % mod;
            int right = calcRight(n / i);
            ans = (ans + (LL)left * right) % mod;
        }
        printOut(ans);
    }
};

void run()
{
    n = readIn();

    RunInstance(work);
}

int main()
{
#ifndef LOCAL
    freopen("box.in", "r", stdin);
    freopen("box.out", "w", stdout);
#endif
    run();
    return 0;
}
总结

有时候化简式子有不止一种方法,走不通时看看有没有别的路。

猜你喜欢

转载自blog.csdn.net/lycheng1215/article/details/80816690