题解【洛谷P5658】[CSP-S 2019]括号树

题面

一道简单的栈与\(\text{DP}\)的结合。

首先介绍一下序列上的括号匹配问题,也就是此题在序列上的做法:

  • \(dp_i\) 表示以 \(i\) 结尾的合法的括号序列个数, \(ss_i\) 表示 \(1\)\(i\) 合法的括号序列字串个数。
  • 维护一个栈,左括号 \(\text{push}\) 它的位置到栈中,右括号取出栈顶 \(dp_i = dp_{sta[top] - 1} + 1\) , 然后 \(ss_i=ss_{i-1}+dp_{i}\)
  • 答案即为 \((1\times ss_1) \oplus (2 \times ss_2) \oplus \dots \oplus (n \times ss_n)\) ,其中 \(\oplus\) 为异或。

考虑将这个问题转移到树上,只需要一个可回退的栈即可。

这题真的不难,我考场上为什么没想出来啊

我太菜了

代码:

#include <bits/stdc++.h>
#define itn int
#define gI gi
#define int long long

using namespace std;

inline int gi()
{
    int f = 1, x = 0; char c = getchar();
    while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return f * x;
}

const int maxn = 500003;

int n, ans, topp, ss[maxn], dp[maxn], sta[maxn], sum[maxn];
int fa[maxn], tot, head[maxn], ver[maxn * 2], nxt[maxn * 2];
char s[maxn];

inline void add(int u, int v) {ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;}

void dfs(int u, int f)
{
    int fl = -1;
    if (s[u] == '(') sta[++topp] = u; //左括号加入栈
    else if (topp > 0) //右括号且栈中有对应的左括号
    {
        fl = sta[topp--]; //栈顶元素
        dp[u] = dp[fa[fl]] + 1; //dp 数组记得 +1
    }
    sum[u] = sum[fa[u]] + dp[u]; //sum[u] 表示 1 到 u 的路径上合法括号序列的个数
    for (int i = head[u]; i; i = nxt[i])
    {
        int v = ver[i];
        if (v == f) continue;
        dfs(v, u);
    }
    //将栈还原到访问节点 u 之前的状态
    if (s[u] == '(') --topp; 
    else if (fl != -1)
    {
        sta[++topp] = fl;
    }
}

signed main()
{
    //freopen(".in", "r", stdin);
    //freopen(".out", "w", stdout);
    n = gi();
    scanf("%s", s + 1);
    bool fl = true;
    for (int i = 2; i <= n; i+=1) 
    {
        fa[i] = gi();
        if (fa[i] != i - 1) fl = false;
        add(fa[i], i), add(i, fa[i]);
    }
    if (fl) //序列上的做法
    {
        for (int i = 1; i <= n; i+=1)
        {
            if (s[i] == '(') sta[++topp] = i;
            else if (topp) dp[i] = dp[sta[topp--] - 1] + 1;
            ss[i] = ss[i - 1] + dp[i];
            ans ^= (i * ss[i]); 
        }
        printf("%lld\n", ans);
        return 0;
    }
    dfs(1, 0);
    for (int i = 1; i <= n; i+=1) ans ^= (i * sum[i]);
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/xsl19/p/12283223.html
今日推荐