【NOI 2018】冒泡排序(组合计数 + 动态规划)

题目链接:【NOI 2018】冒泡排序

题目大意:给定一个排列 p ,求字典序严格大于 p 的,最长下降子序列长度不超过 3 的排列个数 mod 998244353 的值。

题目等价于:能划分成两个上升子序列的序列个数。

假设在前 i 个位置中,最大值是 k ,我们发现在余下的数中, > k 的可以随便放置,而 < k 的数就只能从小到大依次放置,构成另一个上升子序列。

先不考虑字典序的问题。设 f i , j 表示放了 i 个数,剩下的数中有 j 个大于当前最大值的方案数。如果新加进来的数小于最大值,则 j 不变;如果新加进来的数大于最大值,设它是第 k 个大于最大值的元素,则 j 会减少 k

于是: f i , j = f i 1 , 0 + k = 1 j f i 1 , j k = k = 0 j f i 1 , j k ( i > 0 ) f 0 , 0 = 1
我们发现这个式子就是个前缀和: g i , j = g i 1 , j + g i , j 1 ( i , j > 0 ) g i , 0 = 1 , g i , 1 = i
之后不难推得: g i , j = C i + j 1 j = C i + j 1 j 2

接下来考虑字典序限制。考虑到第 i 位,填入的数字是 a i ,后面有 c n t 个数 > 之前的最大值。 p i 后面有 b i 个数比它大, p i 前面有 c i 给数比它小,这两个数组可以用树状数组方便地计算出来。

我们先计算 a i > p i 的情况。首先 c n t m i n ( c n t , b i ) ,因为填完第 i 位后 c n t 不可能 > b i 。然后此时如果 c n t = 0 ,就代表最大的数被填进去了,这意味着我们后面将只能按顺序依次填入,而这个排列的字典序是严格不大于 p 的,因此就可以退出了。否则,我们就用 g n i + 1 , c n t 1 更新答案。

接下来考虑 a i 是否可以等于 p i 。如果刚刚 b i 更新了 c n t ,说明 p i 本身 > 最大值,当然合法;如果 c i = p i 1 ,就说明 p i 是剩下元素中最小的,填这个数也合法;否则就是乱序填入剩下的数,不合法,直接退出。

时间复杂度 Θ ( n log n )

#include <cstdio>
#include <cstring>
const int m = 1200000;
const int maxn = 600005;
const int maxm = 1200005;
const int mod = 998244353;
int T, n, a[maxn], b[maxn], c[maxn], bit[maxn];
int fac[maxm], inv[maxm];
void add(int x) {
    for (; x <= n; x += x & -x) bit[x]++;
}
int sum(int x) {
    int y = 0;
    for (; x; x -= x & -x)  y += bit[x];
    return y;
}
int mpow(int x, int y) {
    int z = 1;
    for (; y; y >>= 1) {
        if (y & 1)  z = 1ll * z * x % mod;
        x = 1ll * x * x % mod;
    }
    return z;
}
void prework() {
    fac[0] = 1;
    for (int i = 1; i <= m; i++)    fac[i] = 1ll * i * fac[i - 1] % mod;
    inv[m] = mpow(fac[m], mod - 2);
    for (int i = m - 1; ~i; i--)   inv[i] = 1ll * (i + 1) * inv[i + 1] % mod;
}
int comb(int n, int m) {
    return 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod;
}
int solve(int n, int m) {
    return !m ? 1 : !(m ^ 1) ? n : (comb(n + m - 1, m) - comb(n + m - 1, m - 2) + mod) % mod;
}
int main() {
    prework();
    for (scanf("%d", &T); T--; ) {
        scanf("%d", &n);
        memset(bit, 0, sizeof(bit));
        for (int i = 1; i <= n; i++) {
            scanf("%d", a + i), add(a[i]);
            c[i] = sum(a[i] - 1);
            b[i] = n - a[i] - (i - 1 - c[i]);
        }
        int ans = 0, cnt = n;
        for (int i = 1; i <= n; i++) {
            bool flag = b[i] < cnt;
            if (flag)   cnt = b[i];
            if (!cnt)   break;
            ans = (ans + solve(n - i + 1, cnt - 1)) % mod;
            if (!flag && c[i] != a[i] - 1)  break;
        }
        printf("%d\n", ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_42068627/article/details/81114823