[ZJOI 2010] 排列计数

[题目链接]

          https://www.lydsy.com/JudgeOnline/problem.php?id=2111

[算法]

        一种比较好的理解方式是将该序列看成是一棵堆式存储的二叉树

        那么问题转化为求有多少个堆

        考虑dp , 用fi表示以i为根的子树能构成多少个堆

        根结点显然是最小的数 , 我们要在剩余的(sizei - 1)个数中选出size(2i)个数 , 然后分配至左右子树中

        显然 , fi = C(sizei - 1 , size(2i)) * f(2i) * f(2i + 1)

        预处理阶乘和逆元 , 用lucas定理求组合数即可

        时间复杂度 : O(N)

[代码]

        

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
const int N = 2e6 + 10;

int n , P;
int fac[N] , inv[N] , size[N];

template <typename T> inline void chkmax(T &x,T y) { x = max(x,y); }
template <typename T> inline void chkmin(T &x,T y) { x = min(x,y); }
template <typename T> inline void read(T &x)
{
    T f = 1; x = 0;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    x *= f;
}
inline int exp_mod(int a , int n)
{
        int res = 1 , b = a;
        while (n > 0)
        {
                if (n & 1) res = 1ll * res * b % P;
                b = 1ll * b * b % P;
                n >>= 1;
        }
        return res;
}
inline void init()
{
        fac[0] = 1;
        for (int i = 1; i <= min(n , P - 1); i++) fac[i] = 1ll * fac[i - 1] * i % P;
        inv[min(n , P - 1)] = exp_mod(fac[min(n , P - 1)] , P - 2);
        for (int i = min(n , P - 1) - 1; i >= 0; i--) inv[i] = 1ll * inv[i + 1] * (i + 1) % P;
}
inline int C(int x , int y)
{
        if (!y) return 1;
        if (x == y) return 1;
        if (x < y) return 0;
        return 1ll * fac[x] * inv[y] % P * inv[x - y] % P;
}
inline int lucas(int x , int y)
{
        if (!y) return 1;
        if (x < P && y < P) return C(x , y);
        return 1ll * lucas(x / P , y / P) * C(x % P , y % P) % P;
}
inline int dp(int u)
{
        if (u > n) return 1;
        return 1ll * lucas(size[u] - 1 , size[u << 1]) * dp(u << 1) % P * dp(u << 1 | 1) % P;        
}

int main()
{
        
        read(n); read(P);
        init();
        for (int i = 1; i <= n; i++) size[i] = 1;
        for (int i = n; i >= 1; i--) size[i >> 1] += size[i];
        printf("%d\n" , dp(1));
        
        return 0;
    
}

猜你喜欢

转载自www.cnblogs.com/evenbao/p/10459677.html