2020 CCPC 威海 - G Caesar Cipher (线段树 + hash)

链接 : G Caesar Cipher

题意 :
给定一个数组 ,范围为 [0,65536),有以下两种操作:

  1. 给出 x , y 把 [x , y] 内的每个数 + 1 同时对 65536 取模。
  2. 给出 x,y,L , 查询区间 [x , x + L - 1] 和区间 [y , y + L - 1]是否完全相同。

思路 :

  1. 思路就是 线段树维护 hash ,有区间修改和查询 判断两段 hash值是否相同就可以了。
  2. 首先考虑一下区间合并(也就是pushup),线段树的每个节点表示这一段的 hash 值,在区间合并的操作时 大区间的 hash 值就是 左区间的 hash值 * base ^ len (len表示右区间的长度) + 右区间的 hash值 。Hash[rt] = (Hash[rt << 1] * poww[r - mid] + Hash[rt << 1 | 1])
  3. 然后是区间更新 ,把这个区间的值全部 + 1, hash 的变化 就是 base的前缀和 ,例如 某一个区间的hash值为 ∑ i = 0 n \sum_{i=0}^n i=0na[i] * base ^ i (n 为区间长度 - 1),那如果现在把每个 a[i] 都 + 1 , 那hash值的变化就是 ∑ i = 0 n \sum_{i=0}^n i=0n base ^ i , 这里用个前缀和记录一下 ,就可以很好的用 lazy维护。
  4. 查询操作和普通的查询不一样 , 因为在合并两个区间时 ,合并后的 hash 值 不是两个 hash的 简单相加(参考上面的pushup) , 也就是左区间的 hash值要先乘上 base ^ len(len为右区间长度) 再加右区间。
  5. 最后就要考虑一下溢出的问题了,如果在更新过程 某个数 >= 65536 , 就要对 65536 取模了 ,直接在更新操作里判断=肯定不好写 ,所以我们在每次更新后都找一下有没有数 大于 65536 ,这里怎么找呢 ,肯定不能暴力扫一遍 。 可以利用线段树进行一个类似二分的过程 ,维护一下每个区间的最大值 ,如果左区间最大值大于 65536 ,继续更新左区间,这样一直下去找到那个值为止 ,复杂度 log(n) ,不用担心超时。

代码:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<stack>
#include<set>
#define iss ios::sync_with_stdio(false)
using namespace std;
typedef long long ll;
const int mod = 65536;
const int mod1 = 1e9 + 7;
const int mod2 = 998244353;
const int r = 31;
const int maxn = 2e6 + 7;
ll Hash[maxn],ma[maxn],la[maxn];
ll pre[maxn],poww[maxn];   // base 的幂次     前缀和
int a[maxn] ,n,q,x,y,op,L;
void pushup(int l , int r, int rt){
    
    
    int mid = (l + r) / 2;
    Hash[rt] = (Hash[rt << 1] * poww[r - mid] % mod1 + Hash[rt << 1 | 1]) % mod1;
    ma[rt] = max(ma[rt << 1] , ma[rt << 1 | 1]);
}
void pushdown(int l,int r,int rt){
    
    
    if(la[rt] == 0) return ;
    int  mid = (l + r) / 2;
    
    Hash[rt << 1] = (Hash[rt << 1] + la[rt] *  pre[mid - l] % mod1) % mod1;     //加上前缀和的 幂次
    Hash[rt << 1 | 1] = (Hash[rt << 1 | 1] + la[rt] * pre[r - mid - 1] % mod1) % mod1;

    ma[rt << 1] += la[rt];
    ma[rt << 1|1] += la[rt];


    la[rt<<1] += la[rt];
    la[rt<<1|1] += la[rt];
    la[rt] = 0;
}
void update(int L,int R,int l,int r,int rt){
    
    
    if(L <= l && R >= r){
    
    
        Hash[rt] = (Hash[rt] + pre[r - l]) % mod1;
        la[rt] ++;
        ma[rt] ++;
        return ;
	}
	pushdown(l , r, rt);
	int mid = (l + r) / 2;
	if(R > mid) update(L , R ,mid + 1 ,r ,rt << 1 | 1);
	if(L <= mid) update(L ,R ,l , mid , rt << 1);
	pushup(l ,r , rt);
}
void update_mod(int l,int r,int rt){
    
      //考虑溢出
    if(ma[rt] < mod){
    
                     //没有超过 mod的 直接退出
        return ;
    }
    if(l == r){
    
    
        ma[rt] -= mod;
        Hash[rt] -= mod;
        return ;
    }
    pushdown(l , r, rt);
    int mid = (l + r) / 2;
    if(ma[rt << 1] >= mod) update_mod( l , mid ,rt << 1);
    if(ma[rt << 1 | 1] >= mod) update_mod(mid + 1 , r, rt << 1 | 1);
    pushup(l , r, rt);

}
ll query(int L,int R,int l,int r,int rt){
    
    
	ll s = 0;
	if(L <= l && R >= r){
    
    
        return Hash[rt];
	}
	pushdown(l , r, rt);
	int mid = (l + r) / 2;
	if(R > mid)  s = (s + query(L,R,mid + 1,r,rt<<1|1) ) % mod1;
	if(L <= mid) s = (s + poww[max(0,min(R , r)- mid)] * query(L,R,l,mid,rt<<1) % mod1) % mod1;
	return s;
}
void build(int l,int r ,int rt){
    
    
    if(l == r){
    
    
        Hash[rt] = a[l];
        ma[rt] = a[l];
        return ;
    }
    int mid = (l + r) / 2;
    build(l , mid ,rt << 1);
    build(mid + 1 , r,rt << 1 | 1);
    pushup(l,r, rt);
}
int main (){
    
    
    poww[0] = pre[0] = 1;
    for(int i = 1; i <= 5e5 ; i ++){
    
    
        poww[i] = poww[i-1] * r % mod1;
    }
    for(int i = 1; i <= 5e5 ; i ++){
    
    
     pre[i] =  (pre[i-1] + poww[i]) % mod1;
    }
    scanf("%d%d",&n,&q);
    for(int i = 1; i<= n; i ++){
    
    
        scanf("%d",&a[i]);
    }
    build(1 , n,1);
    while(q--){
    
    
        scanf("%d",&op);
        if(op == 1){
    
    
            scanf("%d%d",&x,&y);
            update(x,y,1,n,1);
            update_mod(1 , n , 1);
        }
        if(op == 2){
    
    
            scanf("%d%d%d",&x,&y,&L);
            ll h1 = query(x , x + L - 1 ,1 , n, 1);
            ll h2 = query(y , y + L - 1 ,1 , n, 1);
            if(h1 == h2 ) printf ("yes\n");
            else printf ("no\n");
        }
    }
}

猜你喜欢

转载自blog.csdn.net/hddddh/article/details/109352989