luogu 4448 & XSY 3156

luogu 4448 & XSY 3156

题意:n个元素分为m类,求有多少合法排列,合法的定义是相邻元素不能为同类元素,\(n\leq 5\times10^3\)

XSY 3156 : 额外附加了第一个元素和最后一个元素是相邻的条件。

sol:

考虑容斥,于是变成求不合法的情况即至少有\(s\)个同种元素构成的部分,我们考虑如何对整个序列进行计数,假定每种元素有\(a_i\)个,把相邻的相同种类的元素合并为一个新元素,每种元素的新元素数量暂时定义为\(b_i\),考虑新元素的排列和容斥系数,明显对于\(a_i=b_i\)的时候合法方案被计算一次,因为系数为\((-1)^{a_i-b_i}\)
\[ \frac{(\sum_i b_i)!}{\prod_i b_i!}\prod_ia_i!\binom{a_i-1}{b_i-1}(-1)^{a_i-b_i} \]

因为容斥系数为\((-1)^{a_i-b_i}\),求积后为\(n-\sum_i b_i\),因此考虑钦点\(\sum_{i} b_i\)进行容斥

考虑对\(\sum_i b_i\)的合法方案计数,考虑\(f_{i,j}\)定义为前\(i\)种元素构成了\(j\)个新元素,则递推式为
\[ f_{i,j}=\sum_{k=1}^{\min(a_i,j)}f_{i-1,j-k}\times (-1)^{a_i-k}\frac{a_i!\binom{a_i-1}{k-1}}{k!} \]
实际上容斥系数在最后容斥中计算和在\(f\)中计算能产生同样的效果。

至于怎么做相邻的,8会。

复杂度为\(O((\sum_{i}a_i)^2)=O(n^2)\)

#include<algorithm>
#include<cstdio>
#include<cmath>
const int N = 3e2+7;
typedef long long LL;
#define R register
#define cerr(x) printf("%d\n", x)
const int p = 1e9+7;
int n;
int A[N*N], cnt = 0, k[N*N], sq;
LL fac[N*N], ifac[N*N];
bool dis[N][N];
inline int min(int a, int b) {
  return a > b ? b : a;
}
int vis[N*N];
inline LL FST(int x, int k) {
  LL res = 1;
  if (!k) return 1;
  if (!x) return 0;
  while (k) {
    if (k & 1) res = 1LL * res * x % p;
    x = 1LL * x * x % p, k >>= 1;
  } return res % p;
}
void dfs(int x, int col) {
  vis[x] = cnt;
  for (int i = 1; i <= n; i++)
    if (i != x && dis[x][i] && !vis[i])
      dfs(i, col);
}
void init() {
  scanf("%d", &n);
  for (R int i = 1; i <= n; i++)
    scanf("%d", &A[i]);
  std :: sort(A + 1, A + n + 1);
  for (R int i = 1; i <= n; i++)
    for (R int j = i + 1; j <= n; j++) {
      if (dis[i][j]) continue;
      LL nx = std :: sqrt(1LL * A[i] * A[j]);
      if (1LL * nx * nx == 1LL * A[i] * A[j])
        dis[i][j] |= 1, dis[j][i] |= 1;
    }
  for (R int i = 1; i <= n; i++)
    if (!vis[i]) dfs(i, ++cnt);
  for (R int i = 1; i <= n; i++) k[vis[i]]++;
  
  int up = 5000;
  fac[0] = 1, ifac[0] = 1;
  for (R int i = 1; i <= up; i++) fac[i] = 1LL * fac[i - 1] * i % p;
  ifac[up] = FST(fac[up], p - 2);
  for (R int i = up - 1; i >= 1; i--) ifac[i] = 1LL * ifac[i + 1] * (i + 1) % p;
  //for (R int i = 1; i <= n; i++) cerr(ifac[i]);
}
inline int cc(int n, int m) {
  return 1LL * fac[n] * ifac[n - m] % p * ifac[m] % p;
}
int f[N][N];
inline int iep(int a, int b) {
  return (((a - b) & 1) ? -1 : 1);
}
int main() {
  init();
  std :: sort(k + 1, k + cnt + 1);
  f[0][0] = 1;
  //for (R int i = 1; i <= cnt; i++) cerr(k[i]);
  for (R int i = 1, S = 0; i <= cnt; i++) {
    S += k[i];
    for (R int j = 1; j <= S; j++)
      for (R int K = 1; K <= k[i] && K <= j; K++)
        f[i][j] = (f[i][j] + (1LL * f[i - 1][j - K] * iep(k[i], K) * 
          cc(k[i] - 1, K - 1) % p * ifac[K] % p * fac[k[i]]) % p) % p;
  }
  LL ans = 0;
  for (R int i = cnt; i <= n; i++) 
    ans = (ans + 1LL * fac[i] * f[cnt][i] % p) % p;//, cerr(iep(n, i));
  printf("%lld\n", (ans % p + p) % p);
}

猜你喜欢

转载自www.cnblogs.com/cjc030205/p/11683238.html