「NOI2009」二叉查找树

传送门
Luogu

解题思路

看一眼题面,显然这是一颗 treap ,考虑到这棵 treap 的中序遍历总是不变的,所以我们就先把所有点按照数据值排序,求出 treap 的中序遍历,然后还可以观察到,点的权值并不直接参与答案的计算,所以我们还可以把点的权值离散化(毕竟 \(4e6\) 不是个小数字)。
然后我们就可以愉快的开始 \(\text{DP}\) 了。
由于树的中序遍历始终确定,所以我们很容易想到用区间 \(DP\) 来合并答案。
\(dp[l][r][k]\) 表示 \([l,r]\) 这段区间所有点的权值全都大于等于 \(k\) 的最小代价和。
我们可以枚举一个子树的根 \(x\) ,那么转移方程就是:
\(dp[l][r][k] = \min\left\{dp[l][x - 1][k] + dp[x + 1][r][k] + K + sum(l, r)\right\}\)
\(dp[l][r][k] = \min\left\{dp[l][x - 1][v_x] + dp[x + 1][r][v_x] + sum(l, r)\right\}\)
\(v_x\) 表示 \(x\) 的初始权值,\(sum(l, r)\) 表示 \([l, r]\) 这段区间的访问频度之和。
第一个方程表示更改 \(x\) 的权值为 \(k\) ,第二个表示不改。
由于每次向上合并答案时都会加上一遍整个区间的访问频度之和,所以就起到了乘以深度的效果。
一些初始化和小细节就不啰嗦了。

细节注意事项

  • 咕咕咕

参考代码

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <cctype>
#include <cmath>
#include <ctime>
#define rg register
using namespace std;
template < typename T > inline void read(T& s) {
    s = 0; int f = 0; char c = getchar();
    while (!isdigit(c)) f |= (c == '-'), c = getchar();
    while (isdigit(c)) s = s * 10 + (c ^ 48), c = getchar();
    s = f ? -s : s;
}

typedef long long LL;
const int _ = 70 + 2;

int n, K, X[_], sum[_]; LL dp[_][_][_];
struct node{ int d, v, a; }t[_];

inline bool cmp(const node& x, const node& y) { return x.v < y.v; }

inline bool Cmp(const node& x, const node& y) { return x.d < y.d; }

int main() {
#ifndef ONLINE_JUDGE
    freopen("in.in", "r", stdin);
#endif
    read(n), read(K);
    for (rg int i = 1; i <= n; ++i) read(t[i].d);
    for (rg int i = 1; i <= n; ++i) read(t[i].v);
    for (rg int i = 1; i <= n; ++i) read(t[i].a);
    sort(t + 1, t + n + 1, cmp);
    for (rg int i = 1; i <= n; ++i) t[i].v = i;
    sort(t + 1, t + n + 1, Cmp);
    for (rg int i = 1; i <= n; ++i)
        sum[i] = sum[i - 1] + t[i].a;
    memset(dp, 0x3f, sizeof dp);
    for (rg int k = 1; k <= n; ++k)
        for (rg int i = 0; i <= n; ++i)
            dp[i + 1][i][k] = 0;
    for (rg int i = 1; i <= n; ++i)
        for (rg int l = 1, r = l + i - 1; r <= n; ++l, ++r)
            for (rg int k = 1; k <= n; ++k)
                for (rg int x = l; x <= r; ++x) {
                    dp[l][r][k] = min(dp[l][r][k], dp[l][x - 1][k] + dp[x + 1][r][k] + K + sum[r] - sum[l - 1]);
                    int v = t[x].v;
                    if (v >= k)
                        dp[l][r][k] = min(dp[l][r][k], dp[l][x - 1][v] + dp[x + 1][r][v] + sum[r] - sum[l - 1]);
                }
    printf("%lld\n", dp[1][n][1]);
    return 0;
}

完结撒花 \(qwq\)

猜你喜欢

转载自www.cnblogs.com/zsbzsb/p/11746541.html