[NOI2010]超级钢琴(堆+ST表)

题意

给定一个长度为n的序列,找出k个连续的区间,使得和最大,区间长度在[l,r]之间。

对于100%的数据,-1000 ≤ Ai ≤ 1000,n,k<=500000.

题解

首先贪心:肯定是选取最大的k个区间,区间和可以用前缀和处理,问题是怎么找出,如果是暴力加入的话会达到$n^{2}$。

于是出现一种仙法

考虑固定左端点i,区间和就是sum[k]-sum[i],要和最大,就只要sum[k]最大,k有一个合法范围,初始时对于每个左端点找出右端点在合法区间内最大的一个放入堆。然后在选取k个区间的时候,取出堆顶的时候,把右端点的合法区间分成了两部分,再放入堆。这样就保证了一定会找到这k个区间,因为他总有左端点(滑稽),在分裂右端点区间时就保证了相同左端点的区间也可能被选到。每次选取最多会分裂成2个再放入堆,所以最多放入n+2k个。

那么考虑如何找到最大的sum[k],只需要查询区间最大就好了,线段树也可以,只是不用修改的话ST表貌似会好一点。

放入堆的元素是定义的结构体,包括左端点,右端点区间和最优的右端点。

#include<bits/stdc++.h>
using namespace std;

#define ll long long
const int maxn= 500005;
int n,k,L,R;
ll a[maxn],sum[maxn];
int st[maxn][25],lg[maxn];//取最大值的地方
struct cx{
  int start,l,r,end;//区间开始,右端点的范围,右端点内取最值的地方
  cx(int start,int l,int r,int end) : start(start),l(l),r(r),end(end) {}
  friend bool operator < (const cx &a,const cx &b){return sum[a.end]-sum[a.start-1]<sum[b.end]-sum[b.start-1];}
};

priority_queue<cx> q;

template<class T>inline void read(T &x){
  x=0;int f=0;char ch=getchar();
  while(!isdigit(ch)) {f|=(ch=='-');ch=getchar();}
  while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
  x= f ? -x : x;
}

void RMQ(){
  for(int i=1;i<=n;i++) st[i][0]=i;
  for(int j=1;(1<<j)<=n;j++)
   for(int i=1;i+(1<<j)-1<=n;i++){
     int x=st[i][j-1],y=st[i+(1<<(j-1))][j-1];
     st[i][j]= sum[x]>sum[y] ? x : y;
   }
}

int query(int l,int r){
  int o=lg[r-l+1],x=st[l][o],y=st[r-(1<<o)+1][o];
  //printf("%d %d %d %d %d %d\n",l,r,o,1<<o,x,y);
  return sum[x]>sum[y] ? x : y;
}

int main(){
  //freopen("testdata.in","r",stdin);
  read(n);read(k);read(L);read(R);
  lg[0]=-1;
  for(int i=1;i<=n;i++)
   lg[i]=lg[i>>1]+1;
  //for(int i=1;i<=n;i++) printf("%d ",lg[i]);
  for(int i=1;i<=n;i++){
    read(a[i]);
    sum[i]=sum[i-1]+a[i];
  }
  RMQ();
  for(int i=1;i+L-1<=n;i++){
    int l=i+L-1,r=min(i+R-1,n),x=query(l,r);
    //printf("%d %d %d %d \n",i,l,r,x);
    q.push(cx(i,l,r,x));
  }

  ll ans=0;
  for(int i=1;i<=k;i++){
    cx get=q.top();
    q.pop();
    ans+=sum[get.end]-sum[get.start-1];
    //printf("%d %d\n",get.start,get.end);
    if(get.end!=get.l) q.push(cx(get.start,get.l,get.end-1,query(get.l,get.end-1)));
    if(get.end!=get.r) q.push(cx(get.start,get.end+1,get.r,query(get.end+1,get.r)));
  }
  printf("%lld",ans);
}
View Code

猜你喜欢

转载自www.cnblogs.com/sto324/p/11291497.html