分析
首先我们可以发现,题目中所说的“最优策略”,实际上就是每次不停地选尽量大的连通块,直到连通块中有一个节点上地果子被取完。此时这个节点相当于从树上被删去了,并且将它所在的连通块分割成了若干个小连通块。所以一个连通块能完整存在的时间,取决于连通块中 a i a_i ai的最小值。
于是我们先来考虑一种最简单的情况:如果一开始一条边都没有,我们所需的操作次数即为 ∑ i = 1 n a i \sum\limits_{i=1}^na_i i=1∑nai。
然后我们考虑加上一条边,假设此时这条边连接了 i i i和 j j j,那么根据上面的分析,这条边能为我们节省的操作次数应该为 m i n ( a i , a j ) min(a_i, a_j) min(ai,aj)。
然后根据题目中的连边方式,我们发现我们可以很方便地统计出一个点能连到的点的信息。具体来说,如果一个点 i i i向前连边,连到值比它大的,那么这条边的贡献即为 a i a_i ai;否则如果连到一个值比它小的 j j j,那么贡献为 a j a_j aj。由于每条边被连的概率相同,我们可以直接统计出总贡献,再乘上相应的概率。
那么问题就转换为一个数据结构题了:统计序列连续的一段中比给定数字大(或小)的数的和以及个数。那么我们可以用一棵权值线段树,通过动态加入和删除数字,保证区间范围符合题意即可。
代码
线段树操作有点多,常数较大,容易T。
#include <bits/stdc++.h>
#define P 998244353
#define ll long long
#define MAX 4000005
#define lc(x) (x<<1)
#define rc(x) (x<<1|1)
#define mid ((l+r)>>1)
using namespace std;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
template<typename T>
void read(T &n){
n = 0;
T f = 1;
char c = getchar();
while(!isdigit(c) && c != '-') c = getchar();
if(c == '-') f = -1, c = getchar();
while(isdigit(c)) n = n*10+c-'0', c = getchar();
n *= f;
}
template<typename T>
void write(T n){
if(n < 0) putchar('-'), n = -n;
if(n > 9) write(n/10);
putchar(n%10+'0');
}
void add(ll &x, ll y){
x += y;
if(x >= P) x -= P;
if(x < 0) x += P;
}
int n, m;
ll a[MAX], b[MAX];
ll s[MAX], t[MAX];
void update(int p, int l, int r, int u, ll k){
if(l == r){
add(s[p], b[l]*k);
add(t[p], k);
return;
}
if(mid >= u) update(lc(p), l, mid, u, k);
else update(rc(p), mid+1, r, u, k);
s[p] = (s[lc(p)]+s[rc(p)])%P;
t[p] = (t[lc(p)]+t[rc(p)])%P;
}
ll query1(int p, int l, int r, int ul, int ur){
if(l >= ul && r <= ur) return t[p];
ll res = 0;
if(mid >= ul) add(res, query1(lc(p), l, mid, ul, ur));
if(mid < ur) add(res, query1(rc(p), mid+1, r, ul, ur));
return res;
}
ll query2(int p, int l, int r, int ul, int ur){
if(l >= ul && r <= ur) return s[p];
ll res = 0;
if(mid >= ul) add(res, query2(lc(p), l, mid, ul, ur));
if(mid < ur) add(res, query2(rc(p), mid+1, r, ul, ur));
return res;
}
ll ans;
ll inv[MAX];
void init(){
inv[1] = 1;
for(int i = 2; i <= m; i++) inv[i] = (P-P/i)*inv[P%i]%P;
}
int main()
{
cin >> n >> m;
init();
for(int i = 1; i <= n; i++) read(a[i]), b[i] = a[i], ans += a[i];
sort(b+1, b+n+1);
int len = unique(b+1, b+n+1)-b-1;
for(int i = 1; i <= n; i++){
a[i] = lower_bound(b+1, b+len+1, a[i])-b;
}
update(1, 1, len+1, a[1], 1);
for(int i = 2; i <= n; i++){
int l = min(i-1, m);
ll x = query2(1, 1, len+1, 1, a[i]), y = query1(1, 1, len+1, a[i]+1, len+1);
ll sum = (x+y*b[a[i]]%P)%P;
add(ans, -sum*inv[l]%P);
update(1, 1, len+1, a[i], 1);
if(i-m > 0) update(1, 1, len+1, a[i-m], -1);
}
cout << ans << endl;
return 0;
}