JZOJ4196. 二分图计数(容斥+计数)

[JZOJ4196. 二分图计数
  • 题意:
    • 给定左边一排\(n\)个点,右边一排\(m\)个点的二分图,且左边第\(i\)个点向右边除了第\(a_i\)个点外都连一条边.

    • 要求每次在左边选一集合\(S\),记\(F(S)\)表示在点集\(S\)所连向的边中选一些边使得两两配对,\(G(S)=\sum_{i\in S}2^i\),求\[\sum_{S}F(S)G(S)\]

    • \(n\le 16, m\le 10^9, 0\le a_i\lt m\)

  • 首先发现\(G(S)\)没有任何用,我们算出\(F(S)\)之后直接乘一下即可.

  • 假设没有非法边,那么\(F(S)\)显然是\[m*(m-1)*\cdots *(m-n+1)\]

  • 有非法边,则考虑容斥,有\[F(S) = \sum_{S_1\in S} (-1)^{|S_1|}H(S,S_1)\]

  • 其中\(H(S,S_1)\)表示\(S_1\)点集移向非法边的点集(非法边就是一个点唯一没有连向的那条边),且\(\{S-S_1\}\)中的点随意连边的方案.

  • 值得注意的一个地方是,我们选子集的时候,一个非法点不能被连两次,否则答案会算多.

  • 如果遇到已经被选的非法点,则直接跳过选下一个点.

  • 最后注意随意连边那一段可以用个逆元搞一下.

  • 时间复杂度\(O(3^n)\)(每个点只有三种状态:①不选,②选了且在\(S_1\)里,③选了不在\(S_1\)里)

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>

#define ll long long
#define L register ll
#define F(i, a, b) for (L i = a; i <= b; i ++)
#define N 30

using namespace std;

const int mo = 1e9 + 7;

ll n, m, ans, answer, fir, a[N], d[N], jc[N], ny[N];
struct node{ int num, v; } num[N];
bool bz[N];

void Go(L k, L step, L sum, L n, L tot, L God) {
    if (k > n) {
        answer = (answer + ((jc[m - sum - fir] * ny[m - (n - tot) - sum - fir]) % mo) * step * God) % mo;
        return;
    }
    if (!bz[a[d[k]]]) bz[a[d[k]]] = 1, Go(k + 1, step * (- 1), sum + 1, n, tot + 1, God), bz[a[d[k]]] = 0;
    Go(k + 1, step, sum, n, tot, God);
}

void dfs(L k, L step, L sum) {
    if (k > n) {
        Go(1, 1, 0, step - 1, 0, sum);
        return;
    }
    d[step] = k, dfs(k + 1, step + 1, (sum + (1 << (k - 1)) % mo)), d[step] = 0, dfs(k + 1, step, sum);
}

bool cmp(node x, node y) { return x.v < y.v; }

ll ksm(L x, L y) {
    L ans = 1;
    while (y) {
        if (y & 1) ans = (ans * x) % mo;
        x = (x * x) % mo, y >>= 1;
    }
    return ans;
}

int main() {
    freopen("bipartite.in", "r", stdin);
    freopen("bipartite.out", "w", stdout);
    scanf("%lld%lld", &n, &m), fir = max(m - 2 * n, 0LL), ny[0] = jc[0] = 1;
    F(i, 1, n) scanf("%lld", &a[i]), num[i] = {i, a[i]};
    F(i, fir + 1, m) jc[i - fir] = (jc[i - fir - 1] * i) % mo, ny[i - fir] = ksm(jc[i - fir], mo - 2);
    sort(num + 1, num + n + 1, cmp), num[0].v = ans = -1;
    F(i, 1, n) ans += (num[i].v != num[i - 1].v), a[num[i].num] = ans;
    dfs(1, 1, 0), printf("%lld\n", answer);
}

猜你喜欢

转载自www.cnblogs.com/Pro-king/p/9383528.html