传送门
题目大意
有一个 的正方形网格。用三种颜色对每个格子染色,求有多少种染色方案使得至少一行或一列是同一种颜色。答案对 取模。
思路
正难则反,考虑求不存在一行或一列同色的方案数。结果反着也难……
正着做,考虑容斥。这里有行和列两个参数,如何容斥呢?我们定义
表示第
行的颜色都相同的染色方案,用
表示第
行的颜色都相同的染色方案,那么我们要求的是:
把上面这个式子具体地写出来,你就知道该如何容斥了:设至少有 行的颜色一样,至少有 列的颜色一样,我们枚举 来容斥。这种具体写出并集式子的方法可以了解一下。
那么现在的问题是如何计算选了至少 行和 列时的方案数。当 时,显然为 ;当 和 其中之一为 时,先枚举是哪些行(或者列,下同),再枚举这些行是什么颜色,再枚举剩下的颜色,那么显然是 ;当 时,注意到那些同色的行和列颜色都相同,因此答案为 。
不妨设:
那么最终答案是:
注意上式 的指数。于是我们立刻得到一个 的做法。
我们不妨把
或者
的情况单独拿出来计算。由于总共只有
项,因此这一步的时间复杂度为
。剩下的是:
很自然想到按照
和
分开,把只与
有关的乘数移到外面去。在此之前,需要先把
拆开。
把能提出去的都提出去:
要是没有那个 ,这道题就做完啦!这个 怎么解决呢?
观察右边的和式,整理一下得:
这这么像二项式定理,就直接套用二项式定理咯。只不过 ,所以要减去 的情况:
现在的答案是:
使用快速幂,时间复杂度 。
参考代码
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cassert>
#include <cctype>
#include <climits>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <bitset>
#include <list>
#include <functional>
typedef long long LL;
typedef unsigned long long ULL;
using std::cin;
using std::cout;
using std::endl;
typedef int INT_PUT;
INT_PUT readIn()
{
INT_PUT a = 0; bool positive = true;
char ch = getchar();
while (!(ch == '-' || std::isdigit(ch))) ch = getchar();
if (ch == '-') { positive = false; ch = getchar(); }
while (std::isdigit(ch)) { a = a * 10 - (ch - '0'); ch = getchar(); }
return positive ? -a : a;
}
void printOut(INT_PUT x)
{
char buffer[20]; int length = 0;
if (x < 0) putchar('-'); else x = -x;
do buffer[length++] = -(x % 10) + '0'; while (x /= 10);
do putchar(buffer[--length]); while (length);
}
const int mod = 998244353;
LL power(LL x, int y)
{
LL ret = 1;
while (y)
{
if (y & 1) ret = ret * x % mod;
x = x * x % mod;
y >>= 1;
}
return ret;
}
const int inv3 = power(3, mod - 2);
const int maxn = int(1e6) + 5;
int n;
int fac[maxn];
int invFac[maxn];
int power1[maxn];
int power2[maxn];
int invPower1[maxn];
int invPower2[maxn];
void init()
{
fac[0] = 1;
for (int i = 1; i <= n; i++)
fac[i] = (LL)fac[i - 1] * i % mod;
invFac[n] = power(fac[n], mod - 2) % mod;
for (int i = n - 1; ~i; i--)
invFac[i] = (LL)invFac[i + 1] * (i + 1) % mod;
power1[0] = 1;
for (int i = 1; i <= n; i++)
power1[i] = (LL)power1[i - 1] * 3 % mod;
power2[0] = 1;
for (int i = 1; i <= n; i++)
power2[i] = (LL)power2[i - 1] * power1[n] % mod;
invPower1[0] = 1;
for (int i = 1; i <= n; i++)
invPower1[i] = (LL)invPower1[i - 1] * inv3 % mod;
invPower2[0] = 1;
for (int i = 1; i <= n; i++)
invPower2[i] = (LL)invPower2[i - 1] * invPower1[n] % mod;
}
inline LL C(int down, int up)
{
return down < up ? 0 : (LL)fac[down] * invFac[up] % mod * invFac[down - up] % mod;
}
void run()
{
n = readIn();
init();
LL ans = 0;
for (int i = 1, sig = 1; i <= n; i++, sig = -sig)
ans = (ans + (LL)C(n, i) * power1[i] % mod * power2[n - i] * sig) % mod;
ans = (ans * 2) % mod;
LL base = (LL)-power2[n] * 3 % mod;
base = (base + mod) % mod;
LL sum = 0;
for (int i = 1, sig = -1; i <= n; i++, sig = -sig)
{
sum = (sum + (LL)C(n, i) * invPower2[i] * sig % mod *
(power(1 - invPower1[n - i], n) - 1)) % mod;
}
ans = (ans + base * sum) % mod;
ans = (ans + mod) % mod;
printOut(ans);
}
int main()
{
run();
return 0;
}
总结
这道题本身不难,但还是看题解才做出来的,有几个原因:
- 出发点反了,没有想容斥,而是正难则反去了;
- 后面没有想到二项式定理。
也就是说这道题就是个容斥原理和二项式定理。二项式定理没什么好说的,注意它的结构。容斥原理可以把要算的内容具体的写出来,如果能写成并集形式就可以容斥了。