[bzoj4504]K个串【可持久化线段树】【堆】

【题目链接】
  
【题解】
  首先记下每个点向右所控制的区间,就是它到下一个与它相同的位置-1。
  我们考虑对于每个左端点维护一棵线段树下标表示以该点为右端点的区间的答案。
  那么左端点为1的区间可以 O ( N ) 暴力求出。
  对于两个相邻的左端点 i , i + 1 ,只有 i 所控制的区间会减去 i 的值。用可持久化线段树+标记永久化即可。
  然后将每个点的对应最大值放入堆中,每次取出最大的并将该左端点的次大值放入。
  时间复杂度 O ( ( N + K ) l o g N )
【代码】

# include <bits/stdc++.h>
# define    ll      long long
# define    inf     0x3f3f3f3f
# define    N       100100
using namespace std;
int read(){
    int tmp = 0, fh = 1; char ch = getchar();
    while (ch < '0' || ch > '9') {if (ch == '-') fh = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9'){tmp = tmp * 10 + ch - '0'; ch = getchar(); }
    return tmp * fh;
}

const ll infll = 0x3f3f3f3f3f3f3f3fll;
struct Tree{
    int pl, pr, id;
    ll mx, tag;
}T[N * 100];
struct Node{
    ll sum; int belong, id;
};
bool operator < (Node x, Node y){return x.sum < y.sum; }
priority_queue <Node> hp;
map <int, int> mp;
int n, place, k, nxt[N], num[N], cnt[N], rt[N];
ll sum[N], ans;
void pushup(int p){
    ll l = T[T[p].pl].mx + T[T[p].pl].tag, r = T[T[p].pr].mx + T[T[p].pr].tag;
    if (l > r) T[p].mx = l, T[p].id = T[T[p].pl].id;
        else T[p].mx = r, T[p].id = T[T[p].pr].id;
}
void build(int &p, int l, int r){
    p = ++place;
    if (l != r){
        int mid = (l + r) / 2;
        build(T[p].pl, l, mid);
        build(T[p].pr, mid + 1, r);
        pushup(p);
    }
    else {
        T[p].mx = sum[l];
        T[p].id = l;
    }
}
void modify(int &p, int las, int ql, int qr, ll x, int l, int r){
    p = ++place;
    T[p] = T[las];
    if (ql == l && qr == r){
        T[p].tag += x;
        return;
    }
    int mid = (l + r) / 2;
    if (mid >= qr) modify(T[p].pl, T[p].pl, ql, qr, x, l, mid);
        else if (mid < ql) modify(T[p].pr, T[p].pr, ql, qr, x, mid + 1, r);
            else modify(T[p].pl, T[p].pl, ql, mid, x, l, mid), 
                modify(T[p].pr, T[p].pr, mid + 1, qr, x, mid + 1, r);
    pushup(p);
}
Node query(int p, int x){
    return (Node){T[p].mx + T[p].tag, x, T[p].id};
}
int main(){
    //freopen("A.in", "r", stdin);
    //freopen("A.out", "w", stdout);
    n = read(), k = read();
    for (int i = 1; i <= n; i++) num[i] = read();
    for (int i = n; i >= 1; i--){
        if (mp.find(num[i]) != mp.end()) nxt[i] = mp[num[i]];
            else nxt[i] = n + 1;
        mp[num[i]] = i;
    }
    for (int i = 1; i <= n; i++) cnt[i] = n - i + 1;
    mp.clear();
    for (int i = 1; i <= n; i++){
        sum[i] = sum[i - 1];
        if (mp.find(num[i]) == mp.end()) sum[i] += num[i];
        mp[num[i]] = 1;
    }
    build(rt[1], 1, n);
    Node now = query(rt[1], 1);
    hp.push(now);
    for (int i = 2; i <= n; i++){
        rt[i] = rt[i - 1];
        if (i != nxt[i - 1]) modify(rt[i], rt[i], i, nxt[i - 1] - 1, -num[i - 1], 1, n);
        modify(rt[i], rt[i], i - 1, i - 1, -infll, 1, n); 
        Node now = query(rt[i], i);
        hp.push(now);
    }
    for (int i = 1; i <= k; i++){
        Node now = hp.top(); hp.pop();
        ans = now.sum;
        cnt[now.belong]--;
        if (cnt[now.belong] != 0){
            modify(rt[now.belong], rt[now.belong], now.id, now.id, -infll, 1, n);
            now = query(rt[now.belong], now.belong);
            hp.push(now);
        }
    //printf("%lld\n", ans);
    }
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/d_vanisher/article/details/80807504