分析
まず、タイトルに記載されている「最適な戦略」は、実際には、接続されたブロック内のノード上のフルーツが取り出されるまで、毎回最大の接続されたブロックを選択し続けることであることがわかります。このとき、このノードはツリーから削除されることに相当し、ノードが含まれる接続ブロックはいくつかの小さな接続ブロックに分割されます。したがって、接続されたブロックが存在する時間は、接続されたブロックのaia_iに完全に依存しますA私最小値。
したがって、最初に最も単純なケースを考えてみましょう。最初にエッジがない場合、必要な操作の数は∑ i = 1 nai \ sum \ Limits_ {i = 1} ^ na_iです。i = 1∑n個A私。
次に、このエッジがiiに接続されていると仮定して、エッジを追加することを検討します。iとjjjの場合、上記の分析によれば、このエッジで節約できる操作の数はmin(ai、aj)min(a_i、a_j)である必要があります。m i n (a私、AJ)。
次に、タイトルの接続方法に従って、ポイントが接続できるポイントの情報を簡単に数えることができることがわかりました。具体的には、ポイントiiの場合iはエッジを前方に接続し、それより大きい値に接続します。このエッジの寄与はaia_iです。A私;それ以外の場合は、それよりも小さい値jjに接続されている場合jの場合、寄与はaj a_jAJ。各エッジは接続される確率が同じであるため、寄与の合計を直接カウントし、対応する確率を掛けることができます。
次に、問題はデータ構造の問題に変換されます。統計シーケンスの連続セグメント内の指定された数よりも大きい(または小さい)数の合計と数です。次に、加重線セグメントツリーを使用して数値を動的に追加および削除し、間隔の範囲が質問の意味を満たすようにします。
コード
ラインセグメントツリーの操作は少し多く、定数は大きく、Tは簡単です。
#include <bits/stdc++.h>
#define P 998244353
#define ll long long
#define MAX 4000005
#define lc(x) (x<<1)
#define rc(x) (x<<1|1)
#define mid ((l+r)>>1)
using namespace std;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
template<typename T>
void read(T &n){
n = 0;
T f = 1;
char c = getchar();
while(!isdigit(c) && c != '-') c = getchar();
if(c == '-') f = -1, c = getchar();
while(isdigit(c)) n = n*10+c-'0', c = getchar();
n *= f;
}
template<typename T>
void write(T n){
if(n < 0) putchar('-'), n = -n;
if(n > 9) write(n/10);
putchar(n%10+'0');
}
void add(ll &x, ll y){
x += y;
if(x >= P) x -= P;
if(x < 0) x += P;
}
int n, m;
ll a[MAX], b[MAX];
ll s[MAX], t[MAX];
void update(int p, int l, int r, int u, ll k){
if(l == r){
add(s[p], b[l]*k);
add(t[p], k);
return;
}
if(mid >= u) update(lc(p), l, mid, u, k);
else update(rc(p), mid+1, r, u, k);
s[p] = (s[lc(p)]+s[rc(p)])%P;
t[p] = (t[lc(p)]+t[rc(p)])%P;
}
ll query1(int p, int l, int r, int ul, int ur){
if(l >= ul && r <= ur) return t[p];
ll res = 0;
if(mid >= ul) add(res, query1(lc(p), l, mid, ul, ur));
if(mid < ur) add(res, query1(rc(p), mid+1, r, ul, ur));
return res;
}
ll query2(int p, int l, int r, int ul, int ur){
if(l >= ul && r <= ur) return s[p];
ll res = 0;
if(mid >= ul) add(res, query2(lc(p), l, mid, ul, ur));
if(mid < ur) add(res, query2(rc(p), mid+1, r, ul, ur));
return res;
}
ll ans;
ll inv[MAX];
void init(){
inv[1] = 1;
for(int i = 2; i <= m; i++) inv[i] = (P-P/i)*inv[P%i]%P;
}
int main()
{
cin >> n >> m;
init();
for(int i = 1; i <= n; i++) read(a[i]), b[i] = a[i], ans += a[i];
sort(b+1, b+n+1);
int len = unique(b+1, b+n+1)-b-1;
for(int i = 1; i <= n; i++){
a[i] = lower_bound(b+1, b+len+1, a[i])-b;
}
update(1, 1, len+1, a[1], 1);
for(int i = 2; i <= n; i++){
int l = min(i-1, m);
ll x = query2(1, 1, len+1, 1, a[i]), y = query1(1, 1, len+1, a[i]+1, len+1);
ll sum = (x+y*b[a[i]]%P)%P;
add(ans, -sum*inv[l]%P);
update(1, 1, len+1, a[i], 1);
if(i-m > 0) update(1, 1, len+1, a[i-m], -1);
}
cout << ans << endl;
return 0;
}