Codeforces1111E Tree 题解

传送门

比赛的时候满脑子虚树 就死掉了

看来真的不能先入为主

这种询问点集的,首先考虑建虚树,然后按套路设 d p ( x , i ) dp(x,i) 表示将以 x x 为根的子树中的关键点分成 i i 组的方案数,那么你很快就会发现这玩意根本不能转移,因为它不是正常的背包。

换一个思路,如果我们考虑将关键点按 dfs \text{dfs} 序排序,那么对于一个点 node i \text{node}_i ,所有不能跟它放在同一组里的点(也就是它到根的路径上的点)都会排在它前面。不妨设点 x x 到根的路径上的关键点为 f ( x ) f(x) ,设 d p ( i , j ) dp(i,j) 表示前 i i 个点分成 j j 组的方案数,那么就有一个类似于第二类斯特林数的转移:

d p ( i , j ) = max { j f ( node i ) , 0 } d p ( i 1 , j ) + d p ( i 1 , j 1 ) dp(i,j)=\max\{j-f(\text{node}_i),0\}\cdot dp(i-1,j)+dp(i-1,j-1)

f ( x ) f(x) 的处理非常简单:先把每个点在树上做个标记,然后查询 x x 到根 r r 的路径上有多少个标记就行了(不要忘了减去自己),树剖+树状数组即可解决。

事实上我不是按照 dfs \text{dfs} 序排序的,是先处理了 f f 值然后直接按照 f f 从小到大排序,道理是一样的。

处理完询问不要忘了清空

#include <cctype>
#include <cstdio>
#include <climits>
#include <algorithm>

template <typename T> inline void read(T& x) {
    int f = 0, c = getchar(); x = 0;
    while (!isdigit(c)) f |= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - 48, c = getchar();
    if (f) x = -x;
}
template <typename T, typename... Args>
inline void read(T& x, Args&... args) {
    read(x); read(args...); 
}
template <typename T> void write(T x) {
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + 48);
}
template <typename T> inline void writeln(T x) { write(x); puts(""); }
template <typename T> inline bool chkmin(T& x, const T& y) { return y < x ? (x = y, true) : false; }
template <typename T> inline bool chkmax(T& x, const T& y) { return x < y ? (x = y, true) : false; }

const int maxn = 1e5 + 207;
const int maxm = 307;
const int mod = 1e9 + 7;

int v[maxn << 1], head[maxn], next[maxn << 1], tot;
int dep[maxn], size[maxn], top[maxn], son[maxn], fa[maxn], dfn[maxn], xys;
int n, q;

inline void ae(int x, int y) {
    v[++tot] = y; next[tot] = head[x]; head[x] = tot;
    v[++tot] = x; next[tot] = head[y]; head[y] = tot;
}
void dfs(int x) {
    size[x] = 1; dep[x] = dep[fa[x]] + 1;
    for (int i = head[x]; i; i = next[i])
        if (v[i] != fa[x]) {
            fa[v[i]] = x;
            dfs(v[i]);
            size[x] += size[v[i]];
            if (size[v[i]] > size[son[x]]) son[x] = v[i];
        }
}
void dfs(int x, int t) {
    top[x] = t; dfn[x] = ++xys;
    if (son[x]) dfs(son[x], t);
    for (int i = head[x]; i; i = next[i])
        if (v[i] != son[x] && v[i] != fa[x])
            dfs(v[i], v[i]);
}

int s[maxn];
inline void add(int x, int val) { for (; x <= n; x += x & -x) s[x] += val; }
inline int sum(int l, int r) {
    int ans = 0;
    for (; r; r -= r & -r) ans += s[r];
    for (--l; l; l -= l & -l) ans -= s[l];
    return ans;
}

inline void mark(int x) { add(dfn[x], 1); }
inline void del(int x) { add(dfn[x], -1); }
inline int query(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        ans += sum(dfn[top[x]], dfn[x]);
        x = fa[top[x]];
    }
    return ans + (dep[x] < dep[y] ? sum(dfn[x], dfn[y]) : sum(dfn[y], dfn[x]));
}

int dp[maxn][maxm];
int f[maxn];
int node[maxn], K, m, r;

inline void clear() {
    for (int i = 1; i <= K; ++i)
        for (int j = 1, _end = std::min(i, m); j <= _end; ++j)
            dp[i][j] = 0;
    for (int i = 1; i <= K; ++i) {
        f[node[i]] = 0;
        del(node[i]);
    }
}

int main() {
    read(n, q);
    for (int i = 1, x, y; i < n; ++i)
        read(x, y), ae(x, y);
    dfs(1); dfs(1, 1);
    while (q--) {
        read(K, m, r);
        for (int i = 1; i <= K; ++i) {
            read(node[i]);
            mark(node[i]);
        }
        for (int i = 1; i <= K; ++i)
            f[node[i]] = query(node[i], r) - 1;
        std::sort(node + 1, node + K + 1, [](int a, int b) -> bool { return f[a] < f[b]; });
        dp[1][1] = 1;
        for (int i = 2; i <= K; ++i)
            for (int j = 1, _end = std::min(i, m); j <= _end; ++j)
                if (j - f[node[i]] >= 0)
                    dp[i][j] = (1ll * dp[i - 1][j] * (j - f[node[i]]) % mod + dp[i - 1][j - 1]) % mod;
                else
                    dp[i][j] = dp[i - 1][j - 1];
        int ans = 0;
        for (int i = 1; i <= m; ++i)
            if ((ans += dp[K][i]) >= mod) ans -= mod;
        writeln(ans);
        clear();
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_39677783/article/details/89432903