CF 997C Sky Full of Stars

传送门
题目大意

有一个 n × n ( n 10 6 ) 的正方形网格。用三种颜色对每个格子染色,求有多少种染色方案使得至少一行或一列是同一种颜色。答案对 998244353 取模。

思路

正难则反,考虑求不存在一行或一列同色的方案数。结果反着也难……

正着做,考虑容斥。这里有行和列两个参数,如何容斥呢?我们定义 A i 表示第 i 行的颜色都相同的染色方案,用 B i 表示第 j 行的颜色都相同的染色方案,那么我们要求的是:

| A 1 A 2 A n B 1 B 2 B n |

把上面这个式子具体地写出来,你就知道该如何容斥了:设至少有 i 行的颜色一样,至少有 j 列的颜色一样,我们枚举 i + j 来容斥。这种具体写出并集式子的方法可以了解一下。

那么现在的问题是如何计算选了至少 i 行和 j 列时的方案数。当 i = j = 0 时,显然为 3 n 2 ;当 i j 其中之一为 0 时,先枚举是哪些行(或者列,下同),再枚举这些行是什么颜色,再枚举剩下的颜色,那么显然是 C n i 3 i 3 n ( n i ) ;当 i , j > 0 时,注意到那些同色的行和列颜色都相同,因此答案为 C n i C n j 3 3 ( n i ) ( n j )

不妨设:

f ( i , j ) = { 3 i 3 n ( n i ) i = 0 3 j 3 n ( n j ) j = 0 3 ( n i ) ( n j ) + 1 i , j > 0

那么最终答案是:
i = 0 n j = min ( 1 , i ) n C n i C n j f ( i , j ) ( 1 ) i + j + 1

注意上式 ( 1 ) 的指数。于是我们立刻得到一个 O ( n 2 ) 的做法。


我们不妨把 i = 0 或者 j = 0 的情况单独拿出来计算。由于总共只有 O ( n ) 项,因此这一步的时间复杂度为 O ( n ) 。剩下的是:

i = 1 n j = 1 n C n i C n j 3 ( n i ) ( n j ) + 1 ( 1 ) i + j + 1


很自然想到按照 i j 分开,把只与 i 有关的乘数移到外面去。在此之前,需要先把 ( n i ) ( n j ) 拆开。

i = 1 n j = 1 n C n i C n j 3 n 2 + 1 3 i n 3 j n 3 i j ( 1 ) i ( 1 ) j ( 1 )

把能提出去的都提出去:
( 1 ) 3 n 2 + 1 i = 1 n C n i 3 i n ( 1 ) i j = 1 n C n j 3 j n ( 1 ) j 3 i j

要是没有那个 3 i j ,这道题就做完啦!这个 3 i j 怎么解决呢?

观察右边的和式,整理一下得:

j = 1 n C n j ( 3 n + i ) j

这这么像二项式定理,就直接套用二项式定理咯。只不过 j 0 ,所以要减去 j = 0 的情况:
( 1 3 n + i ) n 1

现在的答案是:
( 1 ) 3 n 2 + 1 i = 1 n C n i 3 i n ( 1 ) i ( ( 1 3 n + i ) n 1 )

使用快速幂,时间复杂度 O ( n log n )

参考代码
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cassert>
#include <cctype>
#include <climits>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <bitset>
#include <list>
#include <functional>
typedef long long LL;
typedef unsigned long long ULL;
using std::cin;
using std::cout;
using std::endl;
typedef int INT_PUT;
INT_PUT readIn()
{
    INT_PUT a = 0; bool positive = true;
    char ch = getchar();
    while (!(ch == '-' || std::isdigit(ch))) ch = getchar();
    if (ch == '-') { positive = false; ch = getchar(); }
    while (std::isdigit(ch)) { a = a * 10 - (ch - '0'); ch = getchar(); }
    return positive ? -a : a;
}
void printOut(INT_PUT x)
{
    char buffer[20]; int length = 0;
    if (x < 0) putchar('-'); else x = -x;
    do buffer[length++] = -(x % 10) + '0'; while (x /= 10);
    do putchar(buffer[--length]); while (length);
}

const int mod = 998244353;
LL power(LL x, int y)
{
    LL ret = 1;
    while (y)
    {
        if (y & 1) ret = ret * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return ret;
}
const int inv3 = power(3, mod - 2);

const int maxn = int(1e6) + 5;
int n;
int fac[maxn];
int invFac[maxn];
int power1[maxn];
int power2[maxn];
int invPower1[maxn];
int invPower2[maxn];
void init()
{
    fac[0] = 1;
    for (int i = 1; i <= n; i++)
        fac[i] = (LL)fac[i - 1] * i % mod;
    invFac[n] = power(fac[n], mod - 2) % mod;
    for (int i = n - 1; ~i; i--)
        invFac[i] = (LL)invFac[i + 1] * (i + 1) % mod;
    power1[0] = 1;
    for (int i = 1; i <= n; i++)
        power1[i] = (LL)power1[i - 1] * 3 % mod;
    power2[0] = 1;
    for (int i = 1; i <= n; i++)
        power2[i] = (LL)power2[i - 1] * power1[n] % mod;
    invPower1[0] = 1;
    for (int i = 1; i <= n; i++)
        invPower1[i] = (LL)invPower1[i - 1] * inv3 % mod;
    invPower2[0] = 1;
    for (int i = 1; i <= n; i++)
        invPower2[i] = (LL)invPower2[i - 1] * invPower1[n] % mod;
}
inline LL C(int down, int up)
{
    return down < up ? 0 : (LL)fac[down] * invFac[up] % mod * invFac[down - up] % mod;
}

void run()
{
    n = readIn();
    init();

    LL ans = 0;
    for (int i = 1, sig = 1; i <= n; i++, sig = -sig)
        ans = (ans + (LL)C(n, i) * power1[i] % mod * power2[n - i] * sig) % mod;
    ans = (ans * 2) % mod;

    LL base = (LL)-power2[n] * 3 % mod;
    base = (base + mod) % mod;

    LL sum = 0;
    for (int i = 1, sig = -1; i <= n; i++, sig = -sig)
    {
        sum = (sum + (LL)C(n, i) * invPower2[i] * sig % mod *
            (power(1 - invPower1[n - i], n) - 1)) % mod;
    }
    ans = (ans + base * sum) % mod;
    ans = (ans + mod) % mod;
    printOut(ans);
}

int main()
{
    run();
    return 0;
}
总结

这道题本身不难,但还是看题解才做出来的,有几个原因:

  1. 出发点反了,没有想容斥,而是正难则反去了;
  2. 后面没有想到二项式定理。

也就是说这道题就是个容斥原理和二项式定理。二项式定理没什么好说的,注意它的结构。容斥原理可以把要算的内容具体的写出来,如果能写成并集形式就可以容斥了。

猜你喜欢

转载自blog.csdn.net/lycheng1215/article/details/80959978
sky