「ZJOI2019」Minimax搜索-动态DP

Description

链接

Solution

考虑差分,求出 w ( S ) k w(S)\leq k 的集合数。

  • 显然, w w 的值域是连续的。所以 w w 只会变成 w + 1 w+1 w 1 w-1
  • 显然, w ( S ) 1 w(S)\leq 1 的集合数为 2 m 1 2^{m-1} ,所以接下来讨论的时候默认 w w 节点不会修改。
  • 可以发现 w ( s ) k w(s)\leq k 等价于把 S S 中所有 < w < w 的叶子的权值改为 w + 1 w+1 或把所有 > w > w 的权值改为 w + 1 w+1 (如果能改的话)后,根节点的权值改变。

考虑方案数不好合并,用概率替代。

f i f_i 表示 i i 节点的元素 < w < w 的概率, g i g_i 表示 i i 节点的元素 > w >w 的概率。

最终答案为 2 m 1 ( 1 f 1 ) ( 1 g 1 ) 2^{m-1}(1-f_1)(1-g_1)

转移如下

{ f u = v s o n u ( 1 f v ) f u = [ d e p u   m o d   2 = 0 ] ( 1 f u ) [ d e p u   m o d   2 = 1 ] f u \begin{cases} f_u'=\prod_{v \in son_u}(1-f_v') \\f_u'=[dep_u \ mod \ 2 =0](1-f_u)[dep_u \ mod \ 2 =1]f_u\end{cases}

{ g u = v s o n u ( 1 g v ) g u = [ d e p u   m o d   2 = 1 ] ( 1 g u ) [ d e p u   m o d   2 = 0 ] g u \begin{cases} g_u'=\prod_{v \in son_u}(1-g_v') \\g_u'=[dep_u \ mod \ 2 =1](1-g_u)[dep_u \ mod \ 2 =0]g_u\end{cases}

这样除了初值可以避免奇偶性的讨论,正确性感性理解就好了。

这样可以写 70 70 (不是我的)。

可以发现随着 R R 的增大,每次最多只会有两个叶子节点发生改变。所以维护 f i f_i' 之积,动态 D P DP 即可。

tips:由于会出现除 0 0 0 0 的情况,所以用数对 ( x , y ) (x,y) 表示原 D P DP x 0 y x\cdot 0^y 即可。

#include <bits/stdc++.h>
using namespace std;

inline int gi()
{
    char c = getchar();
    while(c < '0' || c > '9') c = getchar();
    int sum = 0;
    while('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
    return sum;
}

typedef long long ll;
const int maxn = 200005, mod = 998244353, inv = (mod + 1) >> 1;

int n, m, L, LL, R, val[maxn];

struct edge
{
    int to, next;
} e[maxn * 2];
int h[maxn], Tot;

inline void add(int u, int v)
{
    e[++Tot] = (edge) {v, h[u]}; h[u] = Tot;
    e[++Tot] = (edge) {u, h[v]}; h[v] = Tot;
}

inline int fpow(int x, int k)
{
    int res = 1;
    while (k) {
        if (k & 1) res = (ll)res * x % mod;
        x = (ll)x * x % mod; k >>= 1;
    }
    return res;
}

int leaf[maxn], dfn[maxn], low[maxn], ord[maxn], fa[maxn], top[maxn], son[maxn], f[maxn][2], dep[maxn], siz[maxn], Time;

struct node
{
    int x, y;
    
    node() {x = y = 0;}
    node(int _x, int _y) {x = _x; y = _y;}
    
    inline node operator* (const node &a) const {
        return node((ll)x * a.x % mod, y + a.y);
    }
    inline node operator/ (const node &a) const {
        return node((ll)x * fpow(a.x, mod - 2) % mod, y - a.y);
    }
    inline int val() {return y ? 0 : x;}
} g[maxn][2];

inline node init(int a)
{
    if (a >= mod) a -= mod;
    return a ? node(a, 0) : node(1, 1);
}

void dfs1(int u)
{
    dep[u] = dep[fa[u]] + 1; siz[u] = 1; leaf[u] = 1;
    if (dep[u] & 1) val[u] = -1;
    else val[u] = n + 1;
    
    for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
        if (v != fa[u]) {
            fa[v] = u; leaf[u] = 0; dfs1(v);

            siz[u] += siz[v];
            if (siz[son[u]] < siz[v]) son[u] = v;
            if (dep[u] & 1) val[u] = max(val[u], val[v]);
            else val[u] = min(val[u], val[v]);
        }
    if (leaf[u]) val[u] = u, ++m;
}

void dfs2(int u)
{
    ord[dfn[u] = ++Time] = u;
    if (son[u]) top[son[u]] = top[u], dfs2(son[u]);
    else low[top[u]] = Time;
    for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
        if (v != fa[u] && v != son[u]) top[v] = v, dfs2(v);
    if (!leaf[u]) {
        f[u][0] = f[u][1] = 1; g[u][0] = g[u][1] = init(1);
        for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
            if (v != fa[u] && v != son[u]) {
                f[u][0] = (ll)f[u][0] * (mod + 1 - f[v][0]) % mod;
                f[u][1] = (ll)f[u][1] * (mod + 1 - f[v][1]) % mod;
                g[u][0] = g[u][0] * init(mod + 1 - f[v][0]);
                g[u][1] = g[u][1] * init(mod + 1 - f[v][1]);
            }
    } else {
        if (val[u] != val[1]) {
            if (val[u] < val[1]) f[u][0] = dep[u] & 1; 
            else if (val[u] - L < val[1]) f[u][0] = inv;
            else f[u][0] = (dep[u] & 1) ^ 1;
            if (val[u] > val[1]) f[u][1] = (dep[u] & 1) ^ 1;
            else if (val[u] + L > val[1]) f[u][1] = inv;
            else f[u][1] = dep[u] & 1;
        } else f[u][0] = (dep[u] & 1) ^ 1, f[u][1] = dep[u] & 1;
        g[u][0] = init(f[u][0]);
        g[u][1] = init(f[u][1]);
    }
    if (son[u]) {
        f[u][0] = (ll)f[u][0] * (mod + 1 - f[son[u]][0]) % mod;
        f[u][1] = (ll)f[u][1] * (mod + 1 - f[son[u]][1]) % mod;
    }
}

int tot, rt[maxn], sum[maxn << 2][2], pre[maxn << 2][2], lch[maxn << 2], rch[maxn << 2];

#define mid ((l + r) >> 1)

inline void merge(int &s, int len)
{
    sum[s][0] = (sum[lch[s]][0] + (ll)((len & 1) ? mod - pre[lch[s]][0] : pre[lch[s]][0]) * sum[rch[s]][0]) % mod;
    pre[s][0] = (ll)pre[lch[s]][0] * pre[rch[s]][0] % mod;
    sum[s][1] = (sum[lch[s]][1] + (ll)((len & 1) ? mod - pre[lch[s]][1] : pre[lch[s]][1]) * sum[rch[s]][1]) % mod;
    pre[s][1] = (ll)pre[lch[s]][1] * pre[rch[s]][1] % mod;
}

void build(int &s, int l, int r)
{
    s = ++tot;
    if (l == r) {
        sum[s][0] = pre[s][0] = g[ord[l]][0].val();
        sum[s][1] = pre[s][1] = g[ord[l]][1].val();
        return ;
    }
    build(lch[s], l, mid);
    build(rch[s], mid + 1, r);
    merge(s, mid - l + 1);
}

void modify(int &s, int l, int r, int p)
{
    if (l == r) {
        sum[s][0] = pre[s][0] = g[ord[l]][0].val();
        sum[s][1] = pre[s][1] = g[ord[l]][1].val();
        return ;
    }
    if (p <= mid) modify(lch[s], l, mid, p);
    else modify(rch[s], mid + 1, r, p);
    merge(s, mid - l + 1);
}

void modify(int u, int x)
{
    int v = top[u];
    if (fa[v]) g[fa[v]][x] = g[fa[v]][x] / init(mod + 1 - sum[rt[v]][x]);
    modify(rt[v], dfn[v], low[v], dfn[u]);
    if (fa[v]) g[fa[v]][x] = g[fa[v]][x] * init(mod + 1 - sum[rt[v]][x]), modify(fa[v], x);
}

int main()
{
    n = gi(); LL = L = gi(); R = gi();
    for (int i = 1; i < n; ++i) add(gi(), gi());

    dfs1(1);
    if (L == 1) printf("%d ", fpow(2, m - 1)), ++L, ++LL;
    else if (L > 2) --L;
    top[1] = 1; dfs2(1);

    for (int i = 1; i <= n; ++i)
        if (top[ord[i]] == ord[i])
            build(rt[ord[i]], dfn[ord[i]], low[ord[i]]);
    
    int pw = fpow(2, m - 1), lstans = 0, ans;
    lstans = (mod + 1 - (ll)(mod + 1 - sum[rt[1]][0]) * sum[rt[1]][1] % mod) * pw % mod;
    if (LL <= L && L <= R) printf("%d ", lstans);
    for (int i = L + 1; i <= R && i < n; ++i) {
        if (val[1] + i - 1 <= n && leaf[val[1] + i - 1]) {
            g[val[1] + i - 1][0] = init((mod + 1) >> 1);
            modify(val[1] + i - 1, 0);
        }
        if (val[1] - i + 1 >= 1 && leaf[val[1] - i + 1]) {
            g[val[1] - i + 1][1] = init((mod + 1) >> 1);
            modify(val[1] - i + 1, 1);
        }
        ans = (mod + 1 - (ll)(mod + 1 - sum[rt[1]][0]) * sum[rt[1]][1] % mod) * pw % mod;
        printf("%d ", (ans - lstans + mod) % mod);
        lstans = ans;
    }
    if (R == n) printf("%d ", (pw - lstans - 1 + mod) % mod);
    
    return 0;
}

猜你喜欢

转载自blog.csdn.net/DSL_HN_2002/article/details/89363763