【经典】单调栈+离线+线段树区间更新——求所有子区间的“最大值×gcd”之和（子问题：求区间内所有子区间的 gcd 之和）

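经典题，经典折磨。

思路：记 G(l,r) 为区间 [l,r] 内所有子区间的 gcd 之和。先用单调栈求出每个 a[i] 作为（最右）最大值的极大区间 [L_i,R_i]，则 a[i] 的贡献为 a[i]·(G(L_i,R_i) − G(L_i,i−1) − G(i+1,R_i))。把这些 G 询问离线，按左端点从大到小排序。左端点 p 从 n 向 1 移动时，以 p 为左端点的 gcd 至多只有 O(log A) 种不同取值，对每一段右端点区间在线段树上做区间加。这样，当 p=l 时，G(l,r) 就是线段树在 [l,r] 上的区间和。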

#include<bits/stdc++.h>
using namespace std;
#define N 200005
#define mod 1000000007
#define ll long long 

// Input: array length and the 1-indexed values a[1..n].
ll n;
ll a[N];

// Filled by prework(): [L[i], R[i]] is the widest range in which a[i] is
// the rightmost maximum.
int L[N];
int R[N];
void prework(){
    stack<int>stk;
    a[n+1]=1e10;
    for(int i=1;i<=n+1;i++){
        if(!stk.size())stk.push(i);
        else {
            while(stk.size() && a[stk.top()]<=a[i]){
                int x=stk.top();stk.pop();
                R[x]=i-1;
            }
            stk.push(i);
        }
    }
    
    while(stk.size())stk.pop();
    a[0]=1e10;
    for(int i=n;i>=0;i--){
        if(!stk.size())stk.push(i);
        else {
            while(stk.size() && a[stk.top()]<a[i]){
                int x=stk.top();stk.pop();
                L[x]=i+1;
            }
            stk.push(i);
        }
    }
}

struct Query{//区间[l,r]的所有子区间gcd和 
    int sign;ll l,r,ans,a;
}q[N<<2];
int cmp(Query a,Query b){return a.l>b.l;}
int tot;

#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
ll sum[N<<2],lazy[N<<2];
void pushdown(int l,int r,int rt){
    if(lazy[rt]){
        int m=l+r>>1;
        lazy[rt<<1]+=lazy[rt];
        lazy[rt<<1|1]+=lazy[rt];
        sum[rt<<1]+=lazy[rt]*(m-l+1)%mod;sum[rt<<1]%=mod;
        sum[rt<<1|1]+=lazy[rt]*(r-m)%mod;sum[rt<<1|1]%=mod;
        lazy[rt]=0;
    }
}
void update(int L,int R,ll v,int l,int r,int rt){
    if(L>R)return;
    if(L<=l && R>=r){
        lazy[rt]+=v;sum[rt]+=1ll*(r-l+1)*v%mod;sum[rt]%=mod;
        return;
    }
    pushdown(l,r,rt);
    int m=l+r>>1;
    if(L<=m) update(L,R,v,lson);
    if(R>m)update(L,R,v,rson);
    sum[rt]=(sum[rt<<1]+sum[rt<<1|1])%mod;
}
ll query(int L,int R,int l,int r,int rt){
    if(L>R)return 0;
    if(L<=l && R>=r)return sum[rt];
    pushdown(l,r,rt);
    int m=l+r>>1;
    ll res=0;
    if(L<=m)res+=query(L,R,lson);res%=mod;
    if(R>m)res+=query(L,R,rson);res%=mod;
    return res;
}

int main(){
    //freopen("35.in","r",stdin);
    cin>>n;
    for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
    prework();
    //for(int i=1;i<=n;i++)cout<<L[i]<<" "<<R[i]<<'\n';
    
    for(int i=1;i<=n;i++){
        Query t;
        t.l=L[i],t.r=R[i],t.sign=1;t.a=a[i];
        if(t.l<=t.r)q[++tot]=t;
        t.l=L[i],t.r=i-1,t.sign=-1;
        if(t.l<=t.r)q[++tot]=t;
        t.l=i+1,t.r=R[i],t.sign=-1;
        if(t.l<=t.r)q[++tot]=t;
    }
    sort(q+1,q+1+tot,cmp);
    
    int p=n+1;
    vector<pair<ll,ll> > last,now;
    for(int i=1;i<=tot;i++){
        while(p>q[i].l){
            --p;
            now.push_back(make_pair(a[p],p));
            for(auto x:last){
                int d=__gcd(a[p],x.first);
                if(d==now.back().first)continue;
                else now.push_back(make_pair(d,x.second));    
            }
            for(int j=0;j<now.size()-1;j++){
                int LL=now[j].second,RR=now[j+1].second-1;
                ll v=now[j].first;
                //cout<<LL<<" "<<RR<<" "<<v<<'\n';
                update(LL,RR,v,1,n,1);
            }
            int LL=now.back().second,RR=n;
            ll v=now.back().first;
            //cout<<LL<<" "<<RR<<" "<<v<<'\n';
            update(LL,RR,v,1,n,1);
            
            last=now;now.clear();
        }
        q[i].ans=query(q[i].l,q[i].r,1,n,1)*q[i].a%mod;
    }
    
    ll ans=0;
    for(int i=1;i<=tot;i++)
        ans=(ans+q[i].ans*q[i].sign+4ll*mod)%mod;
    cout<<ans<<'\n';
} 

猜你喜欢

转载自www.cnblogs.com/zsben991126/p/12907887.html