2020牛客寒假算法基础集训营2 J-求函数 (线段树维护矩阵乘法)

题目链接:https://ac.nowcoder.com/acm/contest/3003/J

思路:

方法①

f1(1)=k1+b1=(k1)+(b1)

f2(f1(1))=k2(f1(1))+b2=k2k1+k2b1+b2=(k2k1)+(k2b1+b2)

f3(f2(f1(1)))=(k3k2k1)+(k3k2b1+k3b2+b3)

通过上面的展开,我们可以发现一个式子可以分成两部分:∏Ki  与  ∑ri=l(bi*∏rj=i+1Kj)

分别用线段树维护这两部分即可,现在考虑如果合并区[l, r] 与 [r1+1 ,r]

假设左区间的第一部分为 N第二部分为 M1

  右区间的第一部分为 N2 第二部分为 M2

合并后区间的第一部分为N1*N2,第二部分为N2 * M1 + M2

#include<iostream>
#include<algorithm>
#include<cstring>
 using namespace std;
 typedef long long ll;
 const int mod=1e9+7;
 const int maxn=2e5+10;
 struct node{
     ll l,r,k,b;
 }tree[maxn<<2];
 ll k[maxn],b[maxn],n,m,op,l1,r1,po,k1,b1;
 void pushup(int rt)
 {
     tree[rt].k=(tree[rt<<1].k*tree[rt<<1|1].k)%mod;
     tree[rt].b=((tree[rt<<1|1].k*tree[rt<<1].b)%mod+tree[rt<<1|1].b)%mod;
 }
 void build(ll rt,ll l,ll r)
 {
     tree[rt].l=l;
     tree[rt].r=r;
     if(l==r){
         tree[rt].k=k[l],tree[rt].b=b[l];
         return;
     }
    ll mid=(l+r)>>1;
    build(rt<<1,l,mid);
    build(rt<<1|1,mid+1,r);
    pushup(rt);
 }
 void update(ll rt,ll pos)
 {
     if(tree[rt].l==pos&&tree[rt].r==pos){
         tree[rt].k=k[pos],tree[rt].b=b[pos];
         return;
     }
    ll mid=(tree[rt].l+tree[rt].r)>>1;
    if(pos<=mid)    update(rt<<1,pos);
    else update(rt<<1|1,pos);
    pushup(rt);
  }
  typedef pair<ll,ll> p;
  p query(int rt,int l,int r,int ll,int rr)
{
    if(ll>r || rr<l) return p(-1,-1);
    if(l>=ll && r<=rr) return p(tree[rt].k,tree[rt].b);
    int mid=(l+r)>>1;
    p p1=query(rt<<1,l,mid,ll,rr);
    p p2=query(rt<<1|1,mid+1,r,ll,rr);
    if(p1.first==-1) return p2;
    if(p2.first==-1) return p1;
    int k1=p1.first,b1=p1.second;
    int k2=p2.first,b2=p2.second;
    return p(1ll*k1*k2%mod,(1ll*b1*k2+b2)%mod);
}
 int main()
 {
     scanf("%lld%lld",&n,&m);
     for(int i=1;i<=n;i++) scanf("%d",&k[i]);
     for(int i=1;i<=n;i++) scanf("%d",&b[i]);
     build(1,1,n);
     for(int i=1;i<=m;i++){
         scanf("%lld",&op);
         if(op==1){
             scanf("%lld%lld%lld",&po,&k1,&b1);
            k[po]=k1;
            b[po]=b1;
            update(1,po);
         } 
         else{
             scanf("%lld%lld",&l1,&r1);
             p p1=query(1,1,n,l1,r1);
             int k=p1.first,b=p1.second;
            printf("%d\n",((k+b)%mod+mod)%mod);
         } 
     }
    return 0;
 }

猜你喜欢

转载自www.cnblogs.com/overrate-wsj/p/12274408.html