[HAOI2018]染色——二项式反演

题面

  LOJ2527

解析

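  设$f_i$为钦定$i$种颜色各恰好出现$S$次、其余$n-iS$个位置用剩下$m-i$种颜色任意染色的方案数（常说成“至少有$i$种颜色出现$S$次”，但注意恰有$j$种颜色出现$S$次的方案会在$f_i$中被计$\binom{j}{i}$次）。先从$m$种颜色中选出这$i$种，再安排它们的位置，剩下的位置只能用其余$m-i$种颜色：$f_i=\binom{m}{i}\cdot\frac{n!}{(S!)^i(n-iS)!}\cdot(m-i)^{n-iS}$，其中$\frac{n!}{(S!)^i(n-iS)!}$表示在$n$个位置中选出$iS$个位置、让这$i$种颜色每种恰好填$S$次的方案数。$i$的取值范围为$0\le i\le\min(m,\lfloor n/S\rfloor)$。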

  设$g_i$表示恰好有$i$种颜色出现$S$次的方案数，记$L=\min(m,\lfloor n/S\rfloor)$。每个恰有$j$种颜色出现$S$次的方案在$f_i$中被计了$\binom{j}{i}$次，故$f_i=\sum_{j=i}^{L}\binom{j}{i}g_j$。

  于是由二项式反演可得：$$\begin{aligned}g_i&=\sum_{j=i}^{L}(-1)^{j-i}\binom{j}{i}f_j\\&=\frac{1}{i!}\sum_{j=i}^{L}\frac{(-1)^{j-i}}{(j-i)!}\cdot j!\,f_j\end{aligned}$$

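  令$A_j=j!\,f_j$，$B_k=\frac{(-1)^k}{k!}$，则$g_i=\frac{1}{i!}\sum_{j}A_jB_{j-i}$是差卷积形式。把$B$翻转成$B'_{L-k}=B_k$后做一次卷积，得$g_i=\frac{1}{i!}(A*B')_{i+L}$，用 NTT 计算即可。最终答案为$\sum_{i=0}^{L}w_ig_i$，其中$w_i$为题目给出的权值。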

  时间复杂度$O(N+M\log M)$：预处理阶乘为$O(N)$，NTT 为$O(M\log M)$。

 代码:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
// Size bound for the polynomial buffers (coefficients 0..min(m, n/s)).
const int maxn = 200005;
// NTT-friendly prime: 1004535809 = 479 * 2^21 + 1.
const int mod = 1004535809;
// Root used by NTT (and inverted for the inverse transform).
const int g = 3;

// Reads one decimal integer from stdin.
// Any non-digit characters before the number are skipped; if a '-' appears
// among them the result is negated. Returns 0 if EOF is reached before any
// digit (the previous version looped forever in that case, because a `char`
// holding EOF is nonzero and never a digit).
inline int read()
{
    int sign = 1;
    int c = getchar(); // int, not char, so EOF stays distinguishable
    while(c != EOF && (c < '0' || c > '9'))
    {
        if(c == '-')
            sign = -1;
        c = getchar();
    }
    if(c == EOF)
        return 0;
    int ret = 0;
    while(c >= '0' && c <= '9')
    {
        ret = ret * 10 + (c - '0');
        c = getchar();
    }
    return ret * sign;
}

// Modular addition for operands already reduced into [0, mod).
int add(int x, int y)
{
    int sum = x + y;
    if(sum >= mod)
        sum -= mod;
    return sum;
}

// Modular subtraction (x - y) for operands already reduced into [0, mod).
int rdc(int x, int y)
{
    int diff = x - y;
    if(diff < 0)
        diff += mod;
    return diff;
}

// Binary exponentiation: returns x^y modulo mod (x^0 == 1).
ll qpow(ll x, int y)
{
    ll result = 1;
    for(ll base = x; y != 0; y >>= 1)
    {
        if(y & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}

// Input parameters n, m, s.
int n, m, s;
// Current NTT length (lim == 1 << bit) and its bit-reversal permutation.
int lim, bit;
int rev[maxn << 1];
// Factorials / inverse factorials (filled by init), and the weights a[0..m].
int fac[10000005], fnv[10000005];
int a[maxn];
// Inverse of the NTT root, and the two NTT working buffers.
ll ginv;
ll f[maxn << 1], h[maxn << 1];

// Precomputes ginv and factorials / inverse factorials for 0..max(n, m).
void init()
{
    const int top = max(n, m);
    ginv = qpow(g, mod - 2);
    fac[0] = 1;
    for(int i = 1; i <= top; ++i)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    // One modular inverse at the top, then walk downward: 1/(i-1)! = i / i!.
    fnv[top] = qpow(fac[top], mod - 2);
    for(int i = top; i > 0; --i)
        fnv[i - 1] = 1LL * fnv[i] * i % mod;
}

int comb(int x, int y)
{
    if(x < y || y < 0)    return 0;
    return (1LL * fac[x] * fnv[y] % mod) * fnv[x-y] % mod;
}

// Sets lim to the smallest power of two strictly greater than x (bit = log2(lim))
// and builds the bit-reversal table rev[0..lim).
void NTT_init(int x)
{
    for(lim = 1, bit = 0; lim <= x; lim <<= 1)
        ++bit;
    for(int i = 1; i < lim; ++i)
    {
        const int low = i & 1;
        rev[i] = (rev[i >> 1] >> 1) | (low << (bit - 1));
    }
}

// In-place iterative number-theoretic transform of x[0..lim).
// y == 1: forward transform with root g; otherwise the root ginv is used, and
// for y == -1 the result is also scaled by 1/lim (inverse transform).
void NTT(ll *x, int y)
{
    for(int i = 1; i < lim; ++i)
        if(rev[i] > i)
            swap(x[i], x[rev[i]]);
    for(int half = 1; half < lim; half <<= 1)
    {
        const int span = half << 1;
        const ll root = (y == 1) ? g : ginv;
        const ll step = qpow(root, (mod - 1) / span);
        for(int start = 0; start < lim; start += span)
        {
            ll *lo = x + start;
            ll *hi = lo + half;
            ll w = 1;
            for(int k = 0; k < half; ++k)
            {
                const ll u = lo[k];
                const ll v = hi[k] * w % mod;
                lo[k] = add(u, v);
                hi[k] = rdc(u, v);
                w = w * step % mod;
            }
        }
    }
    if(y == -1)
    {
        const ll inv_len = qpow(lim, mod - 2);
        for(int i = 0; i < lim; ++i)
            x[i] = x[i] * inv_len % mod;
    }
}

int main()
{
    n = read(); m = read(); s = read();
    for(int i = 0; i <= m; ++i)
        a[i] = read();
    int sj = min(m, n / s);
    init();
    for(int i = 0; i <= sj; ++i)
        f[i] = (((1LL * comb(m, i) * fac[n] % mod) * qpow(fnv[s], i) % mod) * fnv[n-i*s] % mod) * qpow(m - i, n - i * s) % mod;
    for(int i = 0; i <= sj; ++i)
    {
        f[i] = f[i] * fac[i] % mod;
        h[sj-i] = ((i & 1)? rdc(0, fnv[i]): fnv[i]);
    }
    NTT_init(sj << 1);
    NTT(f, 1);
    NTT(h, 1);
    for(int i = 0; i < lim; ++i)
        f[i] = f[i] * h[i] % mod;
    NTT(f, -1);
    int ans = 0;
    for(int i = 0; i <= sj; ++i)
        ans = add(ans, (a[i] * f[i+sj] % mod) * fnv[i] % mod);
    printf("%d", ans);
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/Joker-Yza/p/12676726.html