2020 江苏省赛 A. Array(线段树 + 欧拉降幂)

链接 A. Array

题意:
给你一个数组 , 有 5 种操作:

  1. 【l , r】区间每个数 加 k 。
  2. 【l , r】区间每个数 乘 k 。
  3. 【l , r】区间每个数变成 它的 k 次方。
  4. 求【l , r】区间每个数的 k 次方的 和(答案对p取模)。
  5. 求【l , r】区间所有数的乘积(答案对p取模)。

思路:

  1. 一眼看上去,实在是太难维护了,有加法,还有k次方,但观察到 p 很小,最大是 30,所以取模后数组中的每个数最大不超过 30。所以我们可以用 30个线段树维护区间内 [0 , 30)的个数。然后我们用 lazy 记录当前区间内每个数分别会变成什么。就可以实现 pushdown 操作了。
  2. 对区间更新的时候要新开一个数组,把每个数的个数先记录下来,再重新更新(可能有多个数变成同一个数)。
  3. 查询的时候不要写 30 个 query(会 T,因为找区间的过程是重复的,只要找一次就好了) ,以前也犯过这个错误,把要的值存下来 或者直接在里面操作就好了。
  4. 用快速幂常数有点大,所以再欧拉降幂一下。

代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 4e5+7;
int sum[32][maxn];
int lz[32][maxn];
int n,mod,p;
int a[maxn],q;
int num[32],cnt[32];
int phi;
int gcd(int a, int b){
    
    
   if(b == 0) return a;
   return gcd(b , a % b);
}
int eular(int n){
    
    
    int ans = n;
    for(int i=2; i * i <= n; ++i){
    
    
        if(n%i == 0){
    
    
            ans = ans/i * (i-1);
            while(n % i == 0){
    
    
                n /= i;
            }
        }
    }
    if(n > 1) ans = ans / n * (n - 1);
    return ans;
}
int poww(int a,int b){
    
    
       ll ans=1;
       if(gcd(a,p) == 1){
    
    
          b = b % phi;
       }
       else if(b >= phi){
    
    
            b = b % phi + phi;
       }
       while(b > 0){
    
    
            if(b & 1) ans = ans * a % p;
            a = a * a % p;
            b >>= 1;
       }
       return ans;
}
void pushup(int rt){
    
    
     for(int i = 0; i < p; i ++){
    
    
        sum[i][rt] = sum[i][rt << 1] + sum[i][rt << 1 | 1];
     }
}
void pushdown(int l,int r,int rt){
    
    
     for(int i = 0; i < p; i ++){
    
    
        num[i] = sum[i][rt << 1];
        sum[i][rt << 1] = 0;
     }
     for(int i = 0; i < p; i ++){
    
    
         sum[lz[i][rt]][rt << 1] += num[i];
     }
     for(int i = 0; i < p; i ++){
    
    
        num[i] = sum[i][rt << 1 | 1];
        sum[i][rt << 1 | 1] = 0;
     }
     for(int i = 0; i < p; i ++){
    
    
         sum[lz[i][rt]][rt << 1 | 1] += num[i];
     }
     for(int i = 0; i < p; i ++){
    
                        //lz 下放要特别注意,看lz会变成什么
        lz[i][rt << 1] = lz[lz[i][rt << 1]][rt];
        lz[i][rt << 1 | 1] = lz[lz[i][rt << 1 | 1]][rt];
     }
     for(int i = 0; i < p; i ++){
    
                        //lz 初始化为本身
        lz[i][rt] = i;
     }
}
void build(int l,int r,int rt){
    
    
    if(l == r){
    
    
        scanf("%d",&a[l]);
        sum[a[l] % p][rt] = 1;
        return ;
    }
    for(int i = 0; i < p; i ++){
    
    
        lz[i][rt] = i;
    }
    int mid = (l + r) / 2;
    build(l , mid, rt << 1);
    build(mid + 1, r , rt << 1 | 1);
    pushup(rt);
}
int ex(int x,int val,int id){
    
    
    if(id == 1) return (x + val) % p;
    if(id == 2) return ((ll)x * val) % p;      //这里注意会爆 int
    if(id == 3) return poww(x , val) % p;
}
void update(int id,int L,int R,int val,int l, int r, int rt){
    
    
     if(L <= l && R >= r){
    
    
        for(int i = 0; i < p; i ++){
    
    
            num[i] = sum[i][rt];
            sum[i][rt] = 0;
        }
        for(int i = 0; i < p; i ++){
    
    
            int x = ex(i,val,id);
            lz[i][rt] = ex(lz[i][rt],val,id);
            sum[x][rt] += num[i];
        }
        return ;
     }
     pushdown(l,r,rt);
     int mid = (l + r) / 2;
     if(L <= mid) update(id,L,R,val,l,mid,rt << 1);
     if(R > mid) update(id,L,R,val,mid + 1,r, rt << 1 | 1);
     pushup(rt);
}
void query(int L,int R,int l,int r,int rt){
    
    
    if(L <= l && R >= r){
    
    
        for(int i = 0; i < p; i ++){
    
    
            cnt[i] += sum[i][rt];
        }
        return ;
    }
    pushdown(l,r,rt);
    int mid = (l + r) / 2;
    if(L <= mid)  query(L,R,l,mid ,  rt << 1);
    if(R > mid )  query(L,R,mid + 1,r,rt << 1 | 1);
}
int main (){
    
    
    scanf("%d%d",&n,&p);
    phi = eular(p);
    int flag = 1;
    build(1,n,1);
    scanf("%d",&q);
    while(q--){
    
    
        int op,l,r,k;
        scanf("%d%d%d%d",&op,&l,&r,&k);
        if(op <= 3){
    
    
            update (op,l , r, k,1,n,1);
        }
        if(op == 4){
    
    
            int ans = 0;
            memset(cnt,0,sizeof(cnt));
            query(l,r,1,n,1);
            for(int i = 0; i < p; i ++){
    
    
                ans = (ans + poww(i,k) * cnt[i]) % p;
            }
            printf ("%d\n",ans);
        }
        if(op == 5){
    
    
            int ans = 1;
            memset(cnt,0,sizeof(cnt));
            query(l,r,1,n,1);
            for(int i = 0; i < p; i ++){
    
    
                ans = (ans * poww(i,cnt[i])) % p;
            }
            printf ("%d\n",ans);
        }
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/hddddh/article/details/110520696