HDU 5977 Garden of Eden——点分治

版权声明:欢迎大家转载,转载请注明出处 https://blog.csdn.net/hao_zong_yin/article/details/83148121

上来按照dp的思想没什么头绪,因为5e4*(1<<10)有点大,所以往暴力上想了,树上暴力的话一般是往点分治上想,稍加思考发现这题只要枚举子集就可以在n(log(n))^2内解决,注意root是全局变量会改变,要存一下,因为这个直接自闭

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 5e4 + 10;
const int INF = 0x3f3f3f3f;
typedef long long LL;
int N, K, all, a[maxn], vis[maxn];
LL ans;
vector<int> G[maxn];
int sz[maxn], dp[maxn], root, SZ;
vector<int> sta;
LL num[1500];
void getroot(int f, int u) {
    sz[u] = 1, dp[u] = 0;
    for (int i = 0; i < G[u].size(); i++) {
        int v = G[u][i];
        if (v == f || vis[v]) continue;
        getroot(u, v);
        sz[u] += sz[v];
        dp[u] = max(dp[u], sz[v]);
    }
    dp[u] = max(dp[u], SZ - sz[u]);
    if (dp[u] < dp[root]) root = u;
}
void getsta(int f, int u, int s) {
    sta.push_back(s);
    for (int i = 0; i < G[u].size(); i++) {
        int v = G[u][i];
        if (v == f || vis[v]) continue;
        getsta(u, v, (s|(1<<a[v])));
    }
}
LL getans(int u, int s) {
    sta.clear();
    getsta(0, u, s);
    LL res = 0;
    for (int i = 0; i <= all; i++) num[i] = 0;
    for (int i = 0; i < sta.size(); i++) num[sta[i]]++;
    for (int i = 0; i < sta.size(); i++) {
        num[sta[i]]--;
        res += num[all];
        for (int s0 = sta[i]; s0; s0 = ((s0-1)&sta[i])) {
            res += num[all^s0];
        }
        num[sta[i]]++;
    }
    return res;
}
void solve(int u) {
    dp[0] = INF, root = 0;
    getroot(0, u);
    int s = (1<<a[root]);
    ans += getans(root, s);
    vis[root] = 1;
    int rt = root;
    for (int i = 0; i < G[rt].size(); i++) {
        int v = G[rt][i];
        if (vis[v]) continue;
        ans -= getans(v, (s|(1<<a[v])));
        SZ = sz[v];
        solve(v);
    }
}
int main() {
    while (~scanf("%d%d", &N, &K)) {
        for (int i = 0; i <= N; i++) G[i].clear();
        for (int i = 0; i <= N; i++) vis[i] = 0;
        all = (1<<K)-1;
        for (int i = 1; i <= N; i++) {
            scanf("%d", &a[i]);
            a[i] -= 1;
        }
        for (int i = 1; i < N; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        if (K == 1) {
            printf("%lld\n", 1LL*N*N);
            continue;
        }
        ans = 0;
        SZ = N;
        solve(1);
        printf("%lld\n", ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/hao_zong_yin/article/details/83148121