【题解】 P6834 【[Cnoi2020]梦原】

分析

首先我们可以发现,题目中所说的“最优策略”,实际上就是每次不停地选尽量大的连通块,直到连通块中有一个节点上地果子被取完。此时这个节点相当于从树上被删去了,并且将它所在的连通块分割成了若干个小连通块。所以一个连通块能完整存在的时间,取决于连通块中 a i a_i ai的最小值。

于是我们先来考虑一种最简单的情况:如果一开始一条边都没有,我们所需的操作次数即为 ∑ i = 1 n a i \sum\limits_{i=1}^na_i i=1nai

然后我们考虑加上一条边,假设此时这条边连接了 i i i j j j,那么根据上面的分析,这条边能为我们节省的操作次数应该为 m i n ( a i , a j ) min(a_i, a_j) min(ai,aj)

然后根据题目中的连边方式,我们发现我们可以很方便地统计出一个点能连到的点的信息。具体来说,如果一个点 i i i向前连边,连到值比它大的,那么这条边的贡献即为 a i a_i ai;否则如果连到一个值比它小的 j j j,那么贡献为 a j a_j aj。由于每条边被连的概率相同,我们可以直接统计出总贡献,再乘上相应的概率。

那么问题就转换为一个数据结构题了:统计序列连续的一段中比给定数字大(或小)的数的和以及个数。那么我们可以用一棵权值线段树,通过动态加入和删除数字,保证区间范围符合题意即可。

代码

线段树操作有点多,常数较大,容易T。

#include <bits/stdc++.h>
#define P 998244353
#define ll long long
#define MAX 4000005
#define lc(x) (x<<1)
#define rc(x) (x<<1|1)
#define mid ((l+r)>>1)
using namespace std;

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
template<typename T>
void read(T &n){
    
    
    n = 0;
    T f = 1;
    char c = getchar();
    while(!isdigit(c) && c != '-') c = getchar();
    if(c == '-') f = -1, c = getchar();
    while(isdigit(c)) n = n*10+c-'0', c = getchar();
    n *= f;
}
template<typename T>
void write(T n){
    
    
    if(n < 0) putchar('-'), n = -n;
    if(n > 9) write(n/10);
    putchar(n%10+'0');
}

void add(ll &x, ll y){
    
    
    x += y;
    if(x >= P) x -= P;
    if(x < 0) x += P;
}

int n, m;
ll a[MAX], b[MAX];
ll s[MAX], t[MAX];
void update(int p, int l, int r, int u, ll k){
    
    
    if(l == r){
    
    
        add(s[p], b[l]*k);
        add(t[p], k);
        return;
    }
    if(mid >= u) update(lc(p), l, mid, u, k);
    else update(rc(p), mid+1, r, u, k);
    s[p] = (s[lc(p)]+s[rc(p)])%P;
    t[p] = (t[lc(p)]+t[rc(p)])%P;
}
ll query1(int p, int l, int r, int ul, int ur){
    
    
    if(l >= ul && r <= ur) return t[p];
    ll res = 0;
    if(mid >= ul) add(res, query1(lc(p), l, mid, ul, ur));
    if(mid < ur) add(res, query1(rc(p), mid+1, r, ul, ur));
    return res;
}
ll query2(int p, int l, int r, int ul, int ur){
    
    
    if(l >= ul && r <= ur) return s[p];
    ll res = 0;
    if(mid >= ul) add(res, query2(lc(p), l, mid, ul, ur));
    if(mid < ur) add(res, query2(rc(p), mid+1, r, ul, ur));
    return res;
}

ll ans;
ll inv[MAX];

void init(){
    
    
    inv[1] = 1;
    for(int i = 2; i <= m; i++) inv[i] = (P-P/i)*inv[P%i]%P;
}

int main()
{
    
    
    cin >> n >> m;
    init();
    for(int i = 1; i <= n; i++) read(a[i]), b[i] = a[i], ans += a[i];
    sort(b+1, b+n+1);
    int len = unique(b+1, b+n+1)-b-1;
    for(int i = 1; i <= n; i++){
    
    
        a[i] = lower_bound(b+1, b+len+1, a[i])-b;
    }
    update(1, 1, len+1, a[1], 1);
    for(int i = 2; i <= n; i++){
    
    
        int l = min(i-1, m);
        ll x = query2(1, 1, len+1, 1, a[i]), y = query1(1, 1, len+1, a[i]+1, len+1);
        ll sum = (x+y*b[a[i]]%P)%P;
        add(ans, -sum*inv[l]%P);

        update(1, 1, len+1, a[i], 1);
        if(i-m > 0) update(1, 1, len+1, a[i-m], -1);
    }
    cout << ans << endl;

    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_30115697/article/details/109486417