【THUPC 2017】小L的计算题

Problem

Description

现有一个长度为 \(n\) 的非负整数数组 \(\{a_i\}\)。小 L 定义了一种神奇变换:

\[f_k=\sum_{i=1}^{n}{a_i}^k\pmod{ 998244353 } \]

小 L 计划用变换生成的序列 \(f\) 做一些有趣的事情,但是他并不擅长算乘法,所以来找你帮忙,希望你能帮他尽快计算出 \(f_1\sim f_n\)(输出它们的异或和 \(f_1\oplus f_2\oplus\cdots\oplus f_n\))。

总共有 \(T\) 组数据。

Range

\(n\le 2\times10^5, T\le20, \sum n\le 4\times10^5, a_i\le 10^9\)

Algorithm

生成函数,多项式

Mentality

写出生成函数 \(F\) 的表达式:

\[F=\sum_{k\ge 0} f_kx^k\\ =\sum_{k\ge 0} \sum_{i=1}^n a_i^kx^k\\ =\sum_{i=1}^n \sum_{k\ge 0} (a_ix)^k\\ =\sum_{i=1}^n \frac{1}{1-a_ix}\\ =n-x\sum_{i=1}^n \frac{-a_i}{1-a_ix} \]

然后发现 \((\ln(1-a_ix))'=\frac{-a_i}{1-a_ix}\),直接代入:

\[F=n-x\sum_{i=1}^n (\ln(1-a_ix))'\\ =n-x(\sum_{i=1}^n \ln(1-a_ix))'\\ =n-x(\ln(\prod_{i=1}^n (1-a_ix)))' \]

用分治 NTT 计算 \(P=\prod_{i=1}^n (1-a_ix)\),再求 \((\ln P)'=P'/P\)(只需一次多项式求逆和一次乘法,不必再积分),于是对 \(k\ge 1\) 有 \(f_k=-[x^{k-1}](\ln P)'\),就完事了。

Code

#include <cmath>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
// 64-bit integer type used for intermediate modular products.
using LL = long long;
// Adjacency-list iteration helper: walks edge ids i starting at G.hd[x],
// following G.nx[i], with v = G.to[i] (presumably an edge-list graph).
// NOTE(review): not used anywhere in this file.
#define go(G, x, i, v) \
  for (int i = G.hd[x], v = G.to[i]; i; v = G.to[i = G.nx[i]])
// `inline` is intentionally NOT #defined here: a translation unit must not
// #define a name lexically identical to a keyword ([macro.names]). The plain
// `inline` keyword is sufficient for small helpers; the optimizer decides.
// Reads the next integer from stdin. Every character before the first digit
// is skipped; if any of the skipped characters is '-', the result is negated.
// Digits are consumed until the first non-digit character.
// NOTE(review): there is no EOF check — on exhausted input the skip loop
// never terminates. Callers must supply well-formed input.
inline LL read() {
  LL sign = 1;
  char c;
  for (c = getchar(); !isdigit(c); c = getchar())
    if (c == '-') sign = -1;
  LL value = 0;
  for (; isdigit(c); c = getchar()) value = value * 10 + (c - '0');
  return value * sign;
}

// Max_n bounds the input array `a` and the NTT bit-reversal table (Poly::rev);
// mod is the prime modulus of the problem, also used as the NTT modulus.
const int Max_n = 4e6 + 5, mod = 998244353;
int T;     // number of test cases
bool fl;   // set in Solve::main; never read anywhere else in this file
// n: length of the current test's array; a[1..n]: its values;
// cnt: declared but unused in this file.
int n, cnt, a[Max_n];
// f[o]: coefficients of the product polynomial for divide-and-conquer node o
// (children 2o and 2o+1); ans: output of Poly::Ln for the whole array.
vector<int> f[1000000], ans;

namespace Input {
// Reads one test case into the globals: first the length n, then the
// values a[1], ..., a[n] (1-indexed).
void main() {
  n = read();
  for (int i = 0; i < n; ++i) a[i + 1] = read();
}
}  // namespace Input

namespace Poly {
int len, bit, rev[Max_n];
int ksm(int a, int b = mod - 2) {
  int res = 1;
  for (; b; b >>= 1, a = (LL)a * a % mod)
    if (b & 1) res = (LL)res * a % mod;
  return res;
}
void init(int n) {
  len = 1 << (bit = log2(n) + 1);
  for (int i = 0; i < len; i++)
    rev[i] = rev[i >> 1] >> 1 | ((i & 1) << bit - 1);
}
void dft(vector<int> &f, bool t) {
  for (int i = 0; i < len; i++)
    if (rev[i] > i) swap(f[i], f[rev[i]]);
  for (int l = 1; l < len; l <<= 1) {
    int Wn = ksm(3, (mod - 1) / (l << 1));
    if (t) Wn = ksm(Wn);
    for (int i = 0; i < len; i += l << 1) {
      int Wnk = 1;
      for (int j = i; j < i + l; j++, Wnk = (LL)Wnk * Wn % mod) {
        int x = f[j], y = (LL)f[j + l] * Wnk % mod;
        f[j] = (x + y) % mod, f[j + l] = (x - y + mod) % mod;
      }
    }
  }
  if (t)
    for (int i = 0, Inv = ksm(len); i < len; i++) f[i] = (LL)f[i] * Inv % mod;
}
void Resize(vector<int> &f, int len) {
  f.resize(len);
  for (int i = 0; i < len; i++) f[i] = 0;
}
void Mul(vector<int> f, vector<int> &g, vector<int> &res, int N) {
  init(N);
  static vector<int> G;
  Resize(res, len), Resize(G, len);
  for (int i = 0; i < min((int)f.size(), len); i++) res[i] = f[i];
  for (int i = 0; i < min((int)g.size(), len); i++) G[i] = g[i];
  dft(res, 0), dft(G, 0);
  for (int i = 0; i < len; i++) res[i] = (LL)res[i] * G[i] % mod;
  dft(res, 1);
}
void Inv(vector<int> &f, vector<int> &res, int N) {
  init(N * 6);
  Resize(res, len);
  static vector<int> F;
  Resize(F, len);
  res[0] = ksm(f[0]);
  for (int deg = 2; deg < (N << 1); deg <<= 1) {
    init(deg * 3);
    for (int i = 0; i < min(deg, (int)f.size()); i++) F[i] = f[i];
    for (int i = min(deg, (int)f.size()); i < len; i++) F[i] = 0;
    dft(F, 0), dft(res, 0);
    for (int i = 0; i < len; i++)
      res[i] = (2ll * res[i] % mod + mod - (LL)res[i] * res[i] % mod * F[i] % mod) % mod;
    dft(res, 1);
    for (int i = deg; i < len; i++) res[i] = 0;
  }
}
void Ln(vector<int> &f, vector<int> &res, int N) {
  static vector<int> inv;
  res = f;
  for (int i = 0; i < N; i++) res[i] = (LL)res[i + 1] * (i + 1) % mod;
  res[N] = 0, Inv(f, inv, N);
  Mul(res, inv, res, N + N);
}
}  // namespace Poly
using namespace Poly;

namespace Solve {
// Divide and conquer over a[l..r]: stores in f[o] the coefficients of
// prod_{i=l}^{r} (1 - a_i x). Nodes are indexed like a segment tree
// (children 2o and 2o+1).
void Solve(int o, int l, int r) {
  if (l == r) {
    f[o].resize(2);
    f[o][0] = 1, f[o][1] = (-a[l] % mod + mod) % mod;
    return;
  }
  int mid = l + r >> 1;
  Solve(o << 1, l, mid), Solve(o << 1 | 1, mid + 1, r);
  // The product has degree r - l + 1; a length greater than r - l + 2 avoids
  // cyclic wrap-around.
  Mul(f[o << 1], f[o << 1 | 1], f[o], r - l + 2);
}
// Per-test driver: builds P = prod (1 - a_i x), takes its logarithmic
// derivative and prints f_1 xor f_2 xor ... xor f_n, using
// f_k = -[x^{k-1}] (ln P)'.
void main() {
  Solve(1, 1, n);
  // Clear any coefficients above degree n. The bound is the vector's own size:
  // when n == 1 no Mul runs, so the global transform length `len` is left over
  // from the previous test case, and f[1] has only 2 entries.
  for (int i = n + 1; i < (int)f[1].size(); i++) f[1][i] = 0;
  fl = 1;
  Ln(f[1], ans, n + 1);
  int Ans = 0;
  for (int i = 0; i < n; i++) Ans ^= (-ans[i] + mod) % mod;
  cout << Ans << endl;
}
}  // namespace Solve

// Entry point. Outside the judge (ONLINE_JUDGE undefined), stdin and stdout
// are redirected to local files. Then T test cases are read and solved in turn.
int main() {
#ifndef ONLINE_JUDGE
  freopen("2409.in", "r", stdin);
  freopen("2409.out", "w", stdout);
#endif
  for (T = read(); T--;) {
    Input::main();  // load n and a[1..n]
    Solve::main();  // compute and print the answer for this case
  }
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/luoshuitianyi/p/12891662.html