EZOJ #77

传送门

分析

一个比较神奇的思路

我们考虑分治,对于每一个区间[le,ri]我们计算这个区间中左端点属于[le,mid],右端点属于[mid+1,ri]的情况对答案的贡献

我们求左半个区间的最大最小值的后缀信息以及右半个区间的最大最小值的前缀信息

于是我们发现在左半面最大值越来越小、最小值越来越大,右半面反之

于是我们枚举左端点,并由这个点i找到它在右半个区间对应的p和q

p表示右面最靠左的大于premax[i]的点,q表示右面最靠左的小于premin[i]的点

然后我们分p<=q和p>q两种情况统计答案即可

注意为了方便起见我们提前处理右半个区间最大值、最小值、最大值乘最小值这三个值的前缀和信息

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cctype>
#include<cmath>
#include<cstdlib>
#include<ctime>
#include<queue>
#include<vector>
#include<set>
#include<map>
#include<stack>
using namespace std;
const int mod = 998244353;
int n,a[500100],premin[500100],surmin[500100],premax[500100],surmax[500100],Ans;
int s1[500100],s2[500100],s3[500100];
inline void go(int le,int ri){
    if(le==ri){
      Ans=(Ans+(long long)a[le]*a[le]%mod)%mod;
      return;
    }
    int i,mid=(le+ri)>>1,p=mid+1,q=mid+1;
    premin[mid]=premax[mid]=surmin[mid]=surmax[mid]=a[mid];
    for(i=mid-1;i>=le;i--){
      premin[i]=min(premin[i+1],a[i]);
      premax[i]=max(premax[i+1],a[i]);
    }
    for(i=mid+1;i<=ri;i++){
      surmin[i]=min(surmin[i-1],a[i]);
      surmax[i]=max(surmax[i-1],a[i]);
    }
    s1[mid]=s2[mid]=s3[mid]=0;
    for(i=mid+1;i<=ri;i++){
      s1[i]=(s1[i-1]+surmax[i])%mod;
      s2[i]=(s2[i-1]+surmin[i])%mod;
      s3[i]=(s3[i-1]+(long long)surmin[i]*surmax[i]%mod)%mod;
    }
    for(i=mid;i>=le;i--){
      while(surmax[p]<premax[i]&&p<=ri)p++;
      while(surmin[q]>premin[i]&&q<=ri)q++;
      int tot=0;
      if(p<=q){
          tot=(long long)(p-mid-1)*premin[i]%mod*premax[i]%mod;
          tot=(tot+(long long)premin[i]*((s1[q-1]-s1[p-1])%mod+mod)%mod)%mod;
        tot=(tot+((s3[ri]-s3[q-1])%mod+mod)%mod)%mod;
      }else {
          tot=(long long)(q-mid-1)*premin[i]%mod*premax[i]%mod;
          tot=(tot+(long long)premax[i]*((s2[p-1]-s2[q-1])%mod+mod)%mod)%mod;
        tot=(tot+((s3[ri]-s3[p-1])%mod+mod)%mod)%mod;
      }
      Ans=(Ans+tot)%mod;
    }
    go(le,mid),go(mid+1,ri);
    return;
}
int main(){
    int i,j,k;
    scanf("%d",&n);
    for(i=1;i<=n;i++)scanf("%d",&a[i]);
    go(1,n);
    printf("%d\n",Ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/yzxverygood/p/9880197.html
77