AGC 041 F

给你一个每列只剩下最底下 \(h_i\) 个的 \(N \times N\) 网格,求有多少种放置车的方案使得网格中每个点都被覆盖到。

\[1 \le N \le 400, 1 \le h_i \le N\]

考虑一个让原网格只剩下列的编号在 \(l\)\(r\) 之间且高度大于等于 \(h\) 的部分的子问题,如果最底层不连续,我们可以类似卷积的合并,所以只考虑底层连续的情况。

我们把列分成三类,有至少一个车的,没有车但是所有格子都被覆盖了的,以及没有车且至少一个格子未被覆盖的。

可以发现记这三个格子的数量就可以转移了,但是这样复杂度会爆炸。

观察到在高度小于 \(h\) 的行都有车时,第二类列和第一类是等价的,否则和第三类是等价的,于是我们可以只记一维,复杂度就优化到了 \(O(N ^ 3)\)

#include <bits/stdc++.h>

#define IL __inline__ __attribute__((always_inline))

#define For(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++ i)
#define FOR(i, a, b) for (int i = (a), i##end = (b); i < i##end; ++ i)
#define Rep(i, a, b) for (int i = (a), i##end = (b); i >= i##end; -- i)
#define REP(i, a, b) for (int i = (a) - 1, i##end = (b); i >= i##end; -- i)

typedef long long LL;

template <class T>
IL bool chkmax(T &a, const T &b) {
  return a < b ? ((a = b), 1) : 0;
}

template <class T>
IL bool chkmin(T &a, const T &b) {
  return a > b ? ((a = b), 1) : 0;
}

template <class T>
IL T mymax(const T &a, const T &b) {
  return a > b ? a : b;
}

template <class T>
IL T mymin(const T &a, const T &b) {
  return a < b ? a : b;
}

template <class T>
IL T myabs(const T &a) {
  return a > 0 ? a : -a;
}

const int INF = 0X3F3F3F3F;
const double EPS = 1E-8, PI = acos(-1.0);

#define DEBUG(...) fprintf(stderr, __VA_ARGS__)
#define OK DEBUG("Passing [%s] in LINE %d...\n", __FUNCTION__, __LINE__)
#define SZ(x) ((int)(x).size())

namespace Math {
const int MOD = 998244353;

IL int add(int a, int b) {
  a += b;
  return a >= MOD ? a - MOD : a;
}

template <class ...Args>
IL int add(int a, const Args &...args) {
  a += add(args...);
  return a >= MOD ? a - MOD : a;
}

IL int sub(int a, int b) {
  a -= b;
  return a < 0 ? a + MOD : a;
}

IL int mul(int a, int b) {
  return (LL)a * b % MOD;
}

template <class ...Args>
IL int mul(int a, const Args &...args) {
  return (LL)a * mul(args...) % MOD;
}

IL int quickPow(int a, int p) {
  int ret = 1;
  for (; p; p >>= 1, a = mul(a, a)) {
    if (p & 1) {
      ret = mul(ret, a);
    }
  }
  return ret;
}
}

using namespace Math;

const int MAXN = 400 + 5;

int h[MAXN], right[MAXN][MAXN], pref[MAXN][MAXN], binom[MAXN][MAXN], power[MAXN], f[MAXN][MAXN][MAXN][2];

int solve(int l, int r, int h) {
  if (pref[h][r] - pref[h][l - 1] < r - l + 1) {
    std::array<std::array<int, 2>, MAXN> g;
    For(i, 1, r - l + 1) {
      g[i][0] = g[i][1] = 0;
    }
    g[0][0] = g[0][1] = 1;
    int cur = 0;
    For(i, l, r) {
      if (::h[i] >= h && (i == l || ::h[i - 1] < h)) {
        solve(i, right[h][i], h);
        int cnt = right[h][i] - i + 1;
        Rep(j, cur + cnt, 0) {
          g[j][0] = mul(g[j][0], f[i][h][0][0]);
          g[j][1] = mul(g[j][1], f[i][h][0][1]);
          Rep(k, mymin(j, cnt), 1) {
            g[j][0] = add(g[j][0], mul(g[j - k][0], f[i][h][k][0]));
            g[j][1] = add(g[j][1], mul(g[j - k][1], f[i][h][k][1]));
          }
        }
        cur += cnt;
      }
    }
    For(i, 0, pref[h][r] - pref[h][l - 1]) {
      f[l][h][i][0] = g[i][0];
      f[l][h][i][1] = g[i][1];
    }
    return pref[h][r] - pref[h][l - 1];
  }
  int num = solve(l, r, h + 1), new_n = r - l + 1 - num;
  For(i, 0, num) {
    f[l][h][i][0] = add(f[l][h][i][0], f[l][h + 1][i][0]);
    f[l][h][i][1] = add(f[l][h][i][1], f[l][h + 1][i][0]);
    f[l][h][i][0] = add(f[l][h][i][0], mul(f[l][h + 1][i][0], sub(power[i], 1)));
    f[l][h][i + new_n][1] = add(f[l][h][i + new_n][1], mul(f[l][h + 1][i][1], sub(power[i + new_n], 1)));
    int cnt = (r - l + 1) - i;
    For(j, 1, cnt) {
      f[l][h][i + j][0] = add(f[l][h][i + j][0], mul(f[l][h + 1][i][0], binom[cnt][j], power[i]));
    }
    cnt -= new_n;
    For(j, 1, cnt) {
      f[l][h][i + j + new_n][1] = add(f[l][h][i + j + new_n][1], 
                                      mul(f[l][h + 1][i][1], binom[cnt][j], power[i + new_n]));
    }
  }
  return r - l + 1;
}

IL void init(int n) {
  binom[0][0] = 1;
  For(i, 1, n) {
    binom[i][0] = 1;
    For(j, 1, n) {
      binom[i][j] = add(binom[i - 1][j], binom[i - 1][j - 1]);
    }
  }
  power[0] = 1;
  For(i, 1, n) {
    power[i] = add(power[i - 1], power[i - 1]);
  }
}

int main() {
  int n;
  scanf("%d", &n);
  For(i, 1, n) {
    scanf("%d", &h[i]);
  }
  For(i, 1, n) {
    For(j, 1, n) {
      pref[i][j] = pref[i][j - 1] + (h[j] >= i);
    }
    right[i][n + 1] = n;
    Rep(j, n, 1) {
      if (h[j] < i) {
        right[i][j] = j - 1;
      } else {
        right[i][j] = right[i][j + 1];
      }
    }
  }
  init(n);
  solve(1, n, 1);
  printf("%d\n", f[1][1][n][1]);
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/sjkmost/p/12190191.html