Codechef BINOMSUM

题意:(复制sunset的)有\(T\)天,每天有\(K\)个小时,第\(i\)天有\(D+i−1\)道菜,第一个小时你选择\(L\)道菜吃,接下来每个小时你可以选择吃一道菜或者选择\(A\)个活动中的一个参加,不能连续两个小时吃菜,问每天的方案数之和。\(K\),\(A\)预先给定,\(Q\)次询问,每次给\(D\),\(L\),\(T\)

题解:显然\(ans=\sum_{i=D}^{D+T-1}\binom{i}{L}F(i)\),其中\(F(i)\)是一个不超过\(k-1\)次的多项式。

把组合数暴力拆开,变为\(\sum_{i=D}^{D+T-1}\frac{i!}{L!(i-L)!}F(i)\)。因为有阶乘,所以考虑把\(F(i)\)写成上升幂多项式的形式来消掉阶乘。具体地,设\(F(x)=\sum_{i=0}^{k-1}a_i(x+1)\dots(x+i)=\sum_{i=0}^{k-1}a_i\frac{(x+i)!}{x!}\),则\(ans=\frac{1}{L!}\sum_{i=D}^{D+T-1}\sum_{j=0}^{k-1}a_j\frac{(i+j)!}{(i-L)!}\)。考虑在\(\frac{(i+j)!}{(i-L)!}\)的分母处补上\((j+L)!\)变为组合数,则\(ans=\frac{1}{L!} \sum_{j=0}^{k-1}a_j(j+L)!\sum_{i=D}^{D+T-1}\binom{i+j}{j+L}\)。后面是组合数上指标求和,可以\(O(1)\)计算。

剩下的问题是怎样求\(a\)。上升幂多项式可以考虑用连续点值来求。具体地,假设我们求出了\(F(-1),F(-2),\dots,F(-k)\),显然有式子\(F(-u)=\sum_{i=0}^{u-1}\frac{(u-1)!}{(u-1-i)!}(-1)^ia_i\)。设\(x_i=(-1)^ia_i,y_i=\frac{1}{i!},z_i=F(-(u+1))\),则\(Z=X*Y,X=\frac{Z}{Y}\)。多项式求逆即可。(其实可以不用求逆,可以发现\(Y=e^x,Y^{-1}=e^{-x}\)。)

剩下的问题是怎样求点值。设\(b_i\)为考虑了前\(i\)个小时的方案数,对于要求的点值\(x\),有递推式\(b_i=Ab_{i-1}+Axb_{i-2}\),可以用矩阵快速幂在\(O(\log k)\)的时间内求出单个点值。

//HNOIday1t1出题人nmsl
#include<bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
const int N = 1e6 + 10;
const int M = 1e7 + 1e5 + 10;
const db pi = acos(-1);

int k, a, mod, q, l, r[N], fac[M], inv[M], ifac[M], x[N], y[N], z[N];

int gi() {
  int x = 0, o = 1;
  char ch = getchar();
  while((ch < '0' || ch > '9') && ch != '-') {
    ch = getchar();
  }
  if(ch == '-') {
    o = -1, ch = getchar();
  }
  while(ch >= '0' && ch <= '9') {
    x = x * 10 + ch - '0', ch = getchar();
  }
  return x * o;
}

struct com {
  db x, y;
  com(db x = 0, db y = 0): x(x), y(y) {}
  com operator+(const com &A) const {
    return com(x + A.x, y + A.y);
  }
  com operator-(const com &A) const {
    return com(x - A.x, y - A.y);
  }
  com operator*(const com &A) const {
    return com(x * A.x - y * A.y, x * A.y + y * A.x);
  }
  com conj() {
    return com(x, -y);
  }
} w[N];

void init(int n) {
  l = 0;
  for(int i = 1; i < n; i <<= 1) {
    ++l;
  }
  for(int i = 0; i < n; i++) {
    r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)), w[i] = com(cos(pi * i / n), sin(pi * i / n));
  }
}

void FFT(com *a, int n) {
  for(int i = 0; i < n; i++) if(i < r[i]) {
      swap(a[i], a[r[i]]);
    }
  for(int i = 1; i < n; i <<= 1)
    for(int p = i << 1, j = 0; j < n; j += p)
      for(int k = 0; k < i; k++) {
        com x = a[j + k], y = w[n / i * k] * a[j + k + i];
        a[j + k] = x + y, a[j + k + i] = x - y;
      }
}

void mul(int *a, int *b, int *c, int n) {
  static com s1[N], s2[N], s3[N], s4[N], s5[N], s6[N];
  init(n);
  for(int i = 0; i < n; i++) {
    s1[i] = com(a[i] & 32767, a[i] >> 15);
    s2[i] = com(b[i] & 32767, b[i] >> 15);
  }
  FFT(s1, n), FFT(s2, n);
  for(int i = 0; i < n; i++) {
    int j = (n - 1) & (n - i);
    com da = (s1[i] + s1[j].conj()) * com(0.5, 0);
    com db = (s1[i] - s1[j].conj()) * com(0, -0.5);
    com dc = (s2[i] + s2[j].conj()) * com(0.5, 0);
    com dd = (s2[i] - s2[j].conj()) * com(0, -0.5);
    s3[i] = da * dc, s4[i] = da * dd, s5[i] = db * dc, s6[i] = db * dd;
  }
  for(int i = 0; i < n; i++) {
    s1[i] = s3[i] + s4[i] * com(0, 1);
    s2[i] = s5[i] + s6[i] * com(0, 1);
  }
  FFT(s1, n), FFT(s2, n);
  reverse(s1 + 1, s1 + n), reverse(s2 + 1, s2 + n);
  for(int i = 0; i < n; i++) {
    int da = (ll)(s1[i].x / n + 0.5) % mod;
    int db = (ll)(s1[i].y / n + 0.5) % mod;
    int dc = (ll)(s2[i].x / n + 0.5) % mod;
    int dd = (ll)(s2[i].y / n + 0.5) % mod;
    c[i] = (da + ((ll)(db + dc) << 15) + ((ll)dd << 30)) % mod;
  }
}

struct mat {
  int v[2][2];
  mat() {
    memset(v, 0, sizeof(v));
  }
  mat operator*(const mat &A) const {
    mat ret;
    for(int i = 0; i < 2; i++)
      for(int j = 0; j < 2; j++) {
        ull tmp = 0;
        for(int k = 0; k < 2; k++) {
          tmp += 1ll * v[i][k] * A.v[k][j];
        }
        ret.v[i][j] = tmp % mod;
      }
    return ret;
  }
} S, T;

mat qpow(mat a, int b) {
  mat ret;
  for(int i = 0; i < 2; i++) {
    ret.v[i][i] = 1;
  }
  while(b) {
    if(b & 1) {
      ret = ret * a;
    }
    a = a * a, b >>= 1;
  }
  return ret;
}

void init() {
  const int n = 1e7 + 1e5 + 1;
  fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = 1;
  for(int i = 2; i <= n; i++) {
    fac[i] = 1ll * fac[i - 1] * i % mod;
    inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
    ifac[i] = 1ll * ifac[i - 1] * inv[i] % mod;
  }
}

int C(int n, int m) {
  if(m < 0 || n < m) {
    return 0;
  }
  return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

int main() {
#ifndef ONLINE_JUDGE
  freopen("a.in", "r", stdin);
  freopen("a.out", "w", stdout);
#endif
  cin >> k >> a >> mod >> q;
  init();
  S.v[0][0] = 1, S.v[0][1] = a, T.v[1][0] = 1, T.v[1][1] = a;
  for(int i = 0; i < k; i++) {
    T.v[0][1] = 1ll * a * (mod - i - 1) % mod;
    z[i] = 1ll * (S * qpow(T, k - 1)).v[0][0] * ifac[i] % mod;
    y[i] = 1ll * ((i & 1) ? mod - 1 : 1) * ifac[i] % mod;
  }
  int N = 1;
  while(N <= 2 * k - 2) {
    N <<= 1;
  }
  mul(y, z, x, N);
  for(int i = 0; i < k; i++) {
    x[i] = 1ll * x[i] * ((i & 1) ? mod - 1 : 1) % mod;
  }
  while(q--) {
    int l = gi(), d = gi(), t = gi(), ans = 0;
    for(int i = 0; i < k; i++) {
      ans = (ans + 1ll * x[i] * fac[i + l] % mod * (C(d + t + i, i + l + 1) - C(d + i, i + l + 1) + mod)) % mod;
    }
    cout << 1ll * ans*ifac[l] % mod << '\n';
  }
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/gczdajuruo/p/10921236.html