CodeForces - 1312D. Count the Arrays 组合数学

CodeForces - 1312D. Count the Arrays

原题地址:

http://codeforces.com/contest/1312/problem/D

基本题意:

计算满足下列条件的数组的数量:

  1. 数组有n个元素
  2. 数组中元素的值在1 - m 之间
  3. 对于每个数组有且仅有一对相等的元素(我个英语智障,这个exactly一直没搞懂)
  4. 数组满足存在一个位置 i 在 i 左边严格递增,右边严格递减

将答案对 998244353 取模;

基本思路:

  1. 我们要找能同时满足条件 (3)(4)的数组,那么我们只要在 [1 , m] 中挑找 n - 1 个数将它们严格递增排列;
  2. 然后为了满足条件(3)要再在除了最高的那个之外即剩下的 n - 2 个中挑一个作为重复的那一个放在右边;
  3. 最后在剩下的 n-3个里可以任意挑出 i 个放在右边(i = 1,2,3,… ,n-3)。

所以综上答案为 :

   mod_comb(m, n-1,mod) * mod_comb(n - 2, 1,mod) * 2 ^ (n-3) % mod

参考代码:

#include <bits/stdc++.h>
using namespace std;
#define IO std::ios::sync_with_stdio(false)
#define int long long
#define INF 0x3f3f3f3f

const int maxn = 2e5+10;
int fact[maxn];
inline int qsm(int x,int n,int mod) {
    int res = 1;
    while (n > 0) {
        if (n & 1) res = res * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return res;
}
inline int extgcd(int a,int b,int &x,int &y) {
    int d = a;
    if (b != 0) {
        d = extgcd(b, a % b, y, x);
        y -= (a / b) * x;
    } else {
        x = 1;
        y = 0;
    }
    return d;
}
inline int mod_inverse(int a,int m) {
    int x, y;
    extgcd(a, m, x, y);
    return (m + x % m) % m;
}
inline int mod_fact(int n,int p,int &e) {
    e = 0;
    if (n == 0) return 1;
    int res = mod_fact(n / p, p, e);
    e += n / p;

    if (n / p % 2 != 0) return res * (p - fact[n % p] % p);
    return res * fact[n % p] % p;
}
inline int mod_comb(int n,int k,int p) {
    if (n < 0 || k < 0 || n < k) return 0;
    int e1, e2, e3;
    int a1 = mod_fact(n, p, e1), a2 = mod_fact(k, p, e2), a3 = mod_fact(n - k, p, e3);
    if (e1 > e2 + e3) return 0;
    return a1 * mod_inverse(a2 * a3 % p, p) % p;
}
const int mod = 998244353;
int n,m;
signed main() {
    IO;
    fact[0] = 0, fact[1] = 1;
    for (int i = 2; i <= maxn; i++) fact[i] = fact[i - 1] * i % mod;
    cin >> n >> m;
    if (n == 2) {
        cout << 0 << endl;
        return 0;
    }
    int ans = mod_comb(m, n - 1, mod);
    ans %= mod;
    ans *= n - 2;
    ans %= mod;
    ans *= qsm(2, n - 3, mod);
    ans %= mod;
    cout << ans << endl;
    return 0;
}

发布了12 篇原创文章 · 获赞 5 · 访问量 485

猜你喜欢

转载自blog.csdn.net/weixin_44164153/article/details/104794987