day22T3改错记

题目描述

有一个环上有\(n\)颗珠子,你要把其中\(m\)颗染成金色,要求连续金色段的长度不超过\(k\),两种方案如果能通过旋转变成一样的,认为这两种方案本质相同

求本质不同的满足条件的方案数,答案对\(998244353\)取模

多组数据,数据组数为\(T\)

\(T \le 5, 0 \le k \le m \le n, 1 \le n \le 1e6\)

解析

由循环同构想到先套个\(Burnside\)

置换就是旋转\(1\)~\(n\)个位置,显然第\(i\)个置换成立的条件是\(\frac{n}{\gcd(i, n)} | m\),否则不可能做到正好把某些环染色

那么就有:
\[ ans = \frac{1}{n} \sum_{i = 1}^{n} [\gcd(i, n) | m] \cdot f(\gcd(i, n), \frac{m}{\frac{n}{\gcd(i, n)}}) \]
其中\(f(n, m)\)表示长度为\(n\)的序列,选\(m\)个位置染色,首尾相接后连续一段被染色的长度不超过\(k\)的方案数

然后大力推式子:
\[ \begin{align} ans & = \frac{1}{n} \sum_{i = 1}^{n} [\gcd(i, n) | m] \cdot f(\gcd(i, n), \frac{m}{\frac{n}{\gcd(i, n)}}) \\ & = \frac{1}{n} \sum_{d | n} [\gcd(d, n) == 1][\frac{n}{d} | m] f(d, \frac{md}{n}) \\ & = \frac{1}{n} \sum_{d | n} [\frac{n}{d} | m]f(d, \frac{md}{n}) \phi(\frac{n}{d}) \\ & = \frac{1}{n} \sum_{d | \gcd(n, m)} f(\frac{n}{d}, \frac{m}{d}) \phi(d) \end{align} \]
然后考虑怎么求\(f(n, m)\)

等价于把\(m\)个染色球放进首尾和\(n - m - 1\)个空隙中

先枚举首尾一共放了\(i\)个球,分成首尾两截就有\(i + 1\)种方案,然后把剩下的\(m - i\)个球插入\(n - m - 1\)个空隙

发现连续一段不超过\(k\)的限制很烦,不妨先枚举至少哪些位置超过了,没有限制的就隔板法解决,容斥一下就好了:
\[ f(n, m) = \sum_{i = 0}^{\min(m, k)} (i + 1) \cdot \sum_{j = 0}^{n - m - 1} {(-1) ^ j} {n - m - 1 \choose j} {m - i - j(k + 1) + n - m - 2 \choose n - m - 2} \]
后面的部分不会枚举完(最后一个组合数上面变负数了就可以停了),这样求一次\(f(n, m)\)\(O(k \cdot \frac{n}{k}) = O(n)\)

总复杂度就是\(O(\sigma(\gcd(n, m)))\)的(\(\sigma(n)\)\(n\)的约数和)

代码

#include <cstdio>
#include <iostream>
#include <cstring>
#define MAXN 1000005

typedef long long LL;
const int mod = 998244353;

char gc();
int read();
void prework();
int comb(int, int);
int calc(int, int);
int qpower(int, int);

int fact[MAXN], ifact[MAXN], inv[MAXN], phi[MAXN];
int T, N, M, K;

inline void inc(int &x, int y) { x += y; if (x >= mod) x -= mod; }
inline void dec(int &x, int y) { x -= y; if (x < 0) x += mod; }
inline int add(int x, int y) { x += y; return x >= mod ? x - mod : x; }
inline int sub(int x, int y) { x -= y; return x < 0 ? x + mod : x; }
inline int mul(int x, int y) { return (LL)x * y % mod; }
inline int gcd(int x, int y) { while (y ^= x ^= y ^= x %= y); return x; }

int main() {
    freopen("gift.in", "r", stdin);
    freopen("gift.out", "w", stdout);

    prework();
    T = read();
    while (T--) {
        int ans = 0, d;
        N = read(), M = read(), K = read();
        if (M == 0) { puts("1"); continue; }
        else if (N == M) {
            printf("%d\n", (int)(M <= K));
            continue;
        }
        d = gcd(N, M);
        for (int i = 1; i * i <= d; ++i) {
            if (d % i) continue;
            inc(ans, mul(calc(N / i, M / i), phi[i]));
            if ((i * i) ^ d) {
                int k = d / i;
                inc(ans, mul(calc(N / k, M / k), phi[k]));
            }
        }
        printf("%d\n", mul(ans, inv[N]));
    }

    return 0;
}
inline char gc() {
    static char buf[1000000], *p1, *p2;
    if (p1 == p2) p1 = (p2 = buf) + fread(buf, 1, 1000000, stdin);
    return p1 == p2 ? EOF : *p2++;
}
inline int read() {
    int res = 0; char ch = gc();
    while (ch < '0' || ch > '9') ch = gc();
    while (ch >= '0' && ch <= '9') res = (res << 1) + (res << 3) + ch - '0', ch = gc();
    return res;
}
void prework() {
    static bool isn_prime[MAXN];
    static int prime[MAXN], tot;
    phi[1] = 1;
    for (int i = 2; i < MAXN; ++i) {
        if (!isn_prime[i]) prime[tot++] = i, phi[i] = i - 1;
        for (int j = 0; j < tot && i * prime[j] < MAXN; ++j) {
            isn_prime[i * prime[j]] = 1;
            if (i % prime[j]) phi[i * prime[j]] = phi[i] * (prime[j] - 1);
            else { phi[i * prime[j]] = phi[i] * prime[j]; break; }
        }
    }
    fact[0] = ifact[0] = fact[1] = ifact[1] = inv[1] = 1;
    for (int i = 2; i < MAXN; ++i) {
        fact[i] = mul(fact[i - 1], i);
        inv[i] = sub(0, mul(mod / i, inv[mod % i]));
        ifact[i] = mul(ifact[i - 1], inv[i]);
    }
}
int comb(int n, int m) {
    if (m > n || n < 0 || m < 0) return 0;
    return mul(fact[n], mul(ifact[m], ifact[n - m]));
}
int qpower(int x, int y) {
    int res = 1;
    while (y) {
        if (y & 1) res = mul(res, x);
        x = mul(x, x), y >>= 1;
    }
    return res;
}
int calc(int n, int m) {
    if (n == m) return m <= K;
    if (n - m == 1) return (m <= K) * n;
    int res = 0, top = std::min(m, K);
    for (int i = 0; i <= top; ++i) {
        int tmp = 0;
        for (int j = 0; j < n - m; ++j) {
            if (n - i - j * (K + 1) - 2 < 0) break;
            if (j & 1) dec(tmp, mul(comb(n - m - 1, j), comb(n - i - j * (K + 1) - 2, n - m - 2)));
            else inc(tmp, mul(comb(n - m - 1, j), comb(n - i - j * (K + 1) - 2, n - m - 2)));
        }
        inc(res, mul(tmp, i + 1));
    }
    return res;
}
//Rhein_E 100pts

猜你喜欢

转载自www.cnblogs.com/Rhein-E/p/10603659.html