【题目链接】
【思路要点】
- 首先考虑如何解决子任务2,也就是求解可行的方案数。
- 假设点\(A\)指向了其下方的点\(B\),那么其右方的点\(C\)的前驱就只能其上方的点。
- 由类似的一系列推理,我们发现矩阵的每一条副对角线上的元素方向应当相同。
- 考虑子任务2中\(N=M\)的情况。
- 在每一条副对角线上,我们无论是向下还是向右,我们都会走到下一条副对角线上。
- 因此我们实际上是确定了一个循环的向右/向下的操作序列,使得机器人能够遍历整个网格后回到起点。
- 假设有\(x\)个向下的操作,\(N-x\)个向右操作,不难发现当且仅当\(x\perp N\),操作序列合法。
- 因此答案为\(\sum_{i=1}^{N}[i\perp N]*\binom{N}{i}\)。
- 再来考虑\(N\ne M\)的情况,令\(d=gcd(N,M)\)。
- 我们发现每个\(d*d\)的正方形内部的方向排布是一样的。
- 通过分析(或打表),我们发现答案为\(\sum_{i=1}^{d}[i\perp d]*[i\perp N]*[(d-i)\perp M]*\binom{d}{i}\)。
- 接下来我们考虑原题。
- 由上文,每个\(d*d\)的正方形内部的方向排布是一样的,我们考虑枚举其中向下的个数\(i\),记向右的个数\(j=d-i\),那么首先应当有\([i\perp d]*[i\perp N]*[j\perp M]=1\)。
- 那么如果无视障碍,从\((x,y)\)出发,经过\(d\)步后,我们必然会走到\((i+x,j+y)\)。
- 有一点重要的转换:碰到障碍前走的路程=碰到每个障碍时走的路程的最小值。
- 我们记格子\((x,y)(1≤x≤i+1,1≤y≤j+1)\)的权值为走到有障碍的格子\((x+k*i,y+k*j)(k\in N)\)的最小步数。
- 那么问题就转化为了:分别求出\((1,1)\)到\((i+1,j+1)\)的每条只向右、向下走的路径上的最小值,并求和。
- 简单DP即可解决,令\(dp_{i,j,k}\)表示当前走到\((i,j)\),路径上最小值为\(k\),转移显然。
- 单次DP时间复杂度为\(O(N^4)\),总时间复杂度\(O(TN^5)\)。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 55; const int P = 998244353; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } int n, m, d, num[MAXN][MAXN]; int dp[MAXN][MAXN][MAXN * MAXN]; char mp[MAXN][MAXN]; void update(int &x, int y) {x = (x + y) % P; } int solve(int tn, int tm) { memset(dp, 0, sizeof(dp)); dp[1][1][num[1][1]] = 1; for (int i = 1; i <= tn; i++) for (int j = 1; j <= tm; j++) for (int k = 1; k <= n * m; k++) { int tmp = dp[i][j][k]; if (tmp == 0) continue; if (i < tn) update(dp[i + 1][j][min(k, num[i + 1][j])], tmp); if (j < tm) update(dp[i][j + 1][min(k, num[i][j + 1])], tmp); } int ans = 0; for (int i = 1; i <= n * m; i++) update(ans, 1ll * i * dp[tn][tm][i] % P); return ans; } int gcd(int x, int y) { if (y == 0) return x; else return gcd(y, x % y); } int main() { int T; read(T); while (T--) { read(n), read(m); for (int i = 1; i <= n; i++) scanf("\n%s", mp[i] + 1); d = gcd(n, m); int ans = 0; for (int i = 1; i <= d; i++) { int j = d - i; if (gcd(i, d) == 1 && gcd(i, n) == 1 && gcd(j, m) == 1) { for (int si = 1; si <= i + 1; si++) for (int sj = 1; sj <= j + 1; sj++) { int pi = si, pj = sj, stp = si + sj - 2; num[si][sj] = n * m; while (stp == si + sj - 2 || pi != si || pj != sj) { if (mp[pi][pj] == '1') { num[si][sj] = stp; break; } pi += i; if (pi > n) pi -= n; pj += j; if (pj > m) pj -= m; stp += d; } } ans = (ans + solve(i + 1, j + 1)) % P; } } writeln(ans); } return 0; }